@zucchini-nlp
Last active November 25, 2024 09:57
Update BLIP-2 model for new version
# Load your model and processor, then run the following to update your BLIP-2 model.
# It updates the files in your repo by adding new args to the config and resizing the embedding layer.
# After that you'll be able to run BLIP-2 without warnings/errors.
from transformers import AddedToken

# Tell the processor how many query tokens the model uses so it can expand
# the image placeholder during processing
processor.num_query_tokens = model.config.num_query_tokens

# Register "<image>" as a special token and grow the embedding matrix to match
image_token = AddedToken("<image>", normalized=False, special=True)
processor.tokenizer.add_tokens([image_token], special_tokens=True)
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)  # pad for efficient computation

# The newly added "<image>" token is the last entry in the vocabulary
model.config.image_token_index = len(processor.tokenizer) - 1

# Push the updated model and processor back to your repo on the Hub
model.push_to_hub("YOUR-REPO")
processor.push_to_hub("YOUR-REPO")
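
For reference, a minimal sketch of the loading step the snippet above assumes, using the public Salesforce/blip2-opt-2.7b checkpoint as an example (substitute your own model and repo):

from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Example checkpoint only; any BLIP-2 model/processor pair works the same way
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")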
@pspdada

pspdada commented Oct 12, 2024

I tried to use the model Salesforce/blip2-itm-vit-g, but encountered a warning. After adding the code mentioned here, I received a NotImplementedError when calling the function resize_token_embeddings. What should I do?

@zucchini-nlp
Author

hmm weird, lemme check quickly

@pspdada

pspdada commented Oct 17, 2024

I've encountered a new issue related to this, please take a look:
huggingface/transformers#34223

@pspdada

pspdada commented Oct 28, 2024

I found that after using this code, a warning also appears in resize_token_embeddings: "The new embeddings will be initialized from a multivariate normal distribution that has the old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use mean_resizing=False."
Do you know how to handle it?

@zucchini-nlp
Author

Yes, we added two different modes for resizing, and initializing from the mean is exactly what we want so we don't mess up the distribution of the embedding matrix and cause gibberish logits/generations. We don't need to handle it or set mean_resizing=False; I guess the warning is there simply to say that the resizing technique has changed compared to the previous release.
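
For illustration, the two modes are selected via the mean_resizing flag of resize_token_embeddings (the flag named in the warning above, default True in recent releases); a minimal sketch:

# Default: new embedding rows are sampled from a normal distribution fitted to
# the mean and covariance of the existing rows, which keeps logits well-behaved
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64, mean_resizing=True)

# Old behavior: new rows use the model's standard random init
# model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64, mean_resizing=False)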

@SHYuanBest

SHYuanBest commented Nov 2, 2024

So what modification is needed to make this correct, for transformers==4.46.1?

@zucchini-nlp
Author

@SHYuanBest the gist should be correct with the latest transformers. Does it throw errors for you?

@SHYuanBest

Yes, a warning still pops up here.

"Expanding inputs for image tokens in BLIP-2 should be done in processing. Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. Using processors without these attributes in the config is deprecated and will throw an error in v4.47."

@zucchini-nlp
Author

@SHYuanBest I just tried the following script with the latest main branch and it gives no warnings anymore. Can you share your reproduction script and the version of transformers you are using?

from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AddedToken
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
)


processor.num_query_tokens = model.config.num_query_tokens
image_token = AddedToken("<image>", normalized=False, special=True)
processor.tokenizer.add_tokens([image_token], special_tokens=True)

model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64) # pad for efficient computation
model.config.image_token_index = len(processor.tokenizer) - 1

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "Question: how many cats are there? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device, dtype=torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
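
As a quick sanity check (my own addition, not part of the original script): after the update the processor expands the image placeholder itself, so the count of image-token ids in the inputs should equal the model's num_query_tokens:

num_image_tokens = (inputs.input_ids == model.config.image_token_index).sum().item()
print(num_image_tokens, model.config.num_query_tokens)  # the two numbers should match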

@SHYuanBest

SHYuanBest commented Nov 4, 2024

transformers == 4.46.1

import torch
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForImageTextRetrieval, AddedToken

blip_model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g", torch_dtype=torch.float16)
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-itm-vit-g")

blip_processor.num_query_tokens = blip_model.config.num_query_tokens
image_token = AddedToken("<image>", normalized=False, special=True)
blip_processor.tokenizer.add_tokens([image_token], special_tokens=True)

blip_model.resize_token_embeddings(len(blip_processor.tokenizer), pad_to_multiple_of=64) # pad for efficient computation
blip_model.config.image_token_index = len(blip_processor.tokenizer) - 1
Traceback (most recent call last):
  File "get_clip_and_blip_score.py", line 56, in <module>
    blip_model.resize_token_embeddings(len(blip_processor.tokenizer), pad_to_multiple_of=64) # pad for efficient computation
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/transformers/modeling_utils.py", line 2117, in resize_token_embeddings
    model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/transformers/modeling_utils.py", line 2141, in _resize_token_embeddings
    old_embeddings = self.get_input_embeddings()
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/transformers/modeling_utils.py", line 1873, in get_input_embeddings
    raise NotImplementedError
NotImplementedError

@SHYuanBest

SHYuanBest commented Nov 4, 2024

Oh, I installed the latest main-branch transformers with pip install -e . and the problem was solved.

@zucchini-nlp
Author

@SHYuanBest no, you should not. That warning is not related to these changes; I guess it appears only because the new resizing uses a different method than before.

@SHYuanBest

SHYuanBest commented Nov 4, 2024 via email
