Multimodal - CLIP in Practice

edwin99
2024-02-09 23:19

CLIP, released publicly by OpenAI, can perform zero-shot image classification.

 

import sys

!{sys.executable} -m pip install git+https://github.com/openai/CLIP.git

 

Check whether a GPU is available and load the CLIP model:

import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os

np.set_printoptions(precision=2, suppress=True)

# Use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

We'll use cat images from the Oxford-IIIT Pet Dataset: https://www.robots.ox.ac.uk/~vgg/data/pets/

!wget https://mslearntensorflowlp.blob.core.windows.net/data/oxcats.tar.gz

!tar xfz oxcats.tar.gz

!rm oxcats.tar.gz
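The images appear to be named after their breed, as in `Maine_Coon_1.jpg`. Assuming every file follows that `Breed_Name_N.jpg` pattern (an assumption based only on the one filename used below), a small helper can turn filenames into breed labels and text prompts for zero-shot breed classification. The helper names and the prompt template `"a photo of a {breed} cat"` are illustrative choices, not part of the original notebook.

```python
import re

def breed_from_filename(filename: str) -> str:
    """Turn 'Maine_Coon_1.jpg' into 'Maine Coon' (assumes Breed_Name_N.ext)."""
    stem = filename.rsplit(".", 1)[0]   # drop the extension
    name = re.sub(r"_\d+$", "", stem)    # drop the trailing _N index
    return name.replace("_", " ")

def build_prompts(filenames):
    """Collect the unique breeds (sorted) and build one text prompt per breed."""
    breeds = sorted({breed_from_filename(f) for f in filenames})
    prompts = [f"a photo of a {b} cat" for b in breeds]
    return breeds, prompts

# Hypothetical filenames standing in for os.listdir("oxcats")
files = ["Maine_Coon_1.jpg", "Maine_Coon_2.jpg", "Siamese_10.jpg", "Russian_Blue_3.jpg"]
breeds, prompts = build_prompts(files)
print(breeds)      # ['Maine Coon', 'Russian Blue', 'Siamese']
print(prompts[0])  # a photo of a Maine Coon cat
```

The resulting `prompts` list could be passed to `clip.tokenize` in place of the three hand-written prompts used for the classification example.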

 

Zero-shot image classification:

The main thing CLIP can do is match an image with a text prompt. If we take an image of a cat, for example, and try to match it against the text prompts "a cat", "a penguin", and "a bear", the first prompt is likely to receive the highest probability, so we can conclude that the image shows a cat. No training is needed because CLIP has already been pre-trained on a huge dataset of image-text pairs; this is why the approach is called zero-shot.
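Under the hood, CLIP encodes the image and each prompt into vectors in a shared embedding space, and the match score is the cosine similarity between them. The sketch below uses tiny made-up 3-dimensional vectors (real ViT-B/32 embeddings have 512 dimensions) purely to show the ranking step; the numbers are hypothetical, not CLIP outputs.

```python
import math

def normalize(v):
    """Scale a vector to unit length before comparing embeddings."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cosine_similarity(a, b):
    """Dot product of the two unit-length vectors."""
    a, b = normalize(a), normalize(b)
    return sum(x * y for x, y in zip(a, b))

# Hypothetical toy embeddings, NOT real CLIP features
image_emb = [0.9, 0.1, 0.4]
text_embs = {
    "a penguin": [0.1, 0.9, 0.2],
    "a bear": [0.3, 0.2, 0.9],
    "a cat": [0.8, 0.2, 0.5],
}

# Score every prompt against the image and keep the best-matching label
scores = {label: cosine_similarity(image_emb, t) for label, t in text_embs.items()}
best = max(scores, key=scores.get)
print(best)  # a cat
```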

 

image = preprocess(Image.open("oxcats/Maine_Coon_1.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a penguin", "a bear", "a cat"]).to(device)

 

with torch.no_grad():
    # Embeddings (not used below; model(image, text) computes them internally)
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

 

print("Label probs:", probs)

Output: Label probs: [[0. 0. 1.]]
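The probabilities print as exactly 0 and 1 partly because of `np.set_printoptions(precision=2, suppress=True)`, and partly because CLIP multiplies cosine similarities by a learned temperature (`model.logit_scale.exp()`, roughly 100 for the released checkpoints) before the softmax, so small differences in similarity become large gaps in probability. The similarities below are hypothetical values chosen for illustration, and the scale of 100 is an assumed value.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical cosine similarities for ["a penguin", "a bear", "a cat"]
similarities = [0.20, 0.22, 0.30]
logit_scale = 100.0  # assumed value of model.logit_scale.exp()

logits = [logit_scale * s for s in similarities]
probs = softmax(logits)
print([round(p, 2) for p in probs])  # [0.0, 0.0, 1.0]

# Without the scale, the three probabilities stay close to 1/3 each
print(softmax(similarities))
```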

 

Intelligent Image Search

In the previous example there was one image and three text prompts. We can also use CLIP the other way around: take many cat images and select the one that best matches a textual description:

 

cats_img = [Image.open(os.path.join("oxcats", x)) for x in os.listdir("oxcats")]
cats = torch.cat([preprocess(i).unsqueeze(0) for i in cats_img]).to(device)
text = clip.tokenize(["a very fat gray cat"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(cats, text)
    # logits_per_text has shape (1, number_of_images); take the best-matching image
    res = logits_per_text.softmax(dim=-1).argmax().item()

 

print("Img Index:", res)

 

plt.imshow(cats_img[res])
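Here the softmax runs over the images rather than over the prompts (`logits_per_text` has one row per prompt and one column per image), so `argmax` returns the single best image. A natural extension is to return the top-k matches. The sketch assumes the image and text embeddings have already been computed (for example with `model.encode_image` and `model.encode_text`) and uses hypothetical 2-D vectors; `top_k_images` is a hypothetical helper, not part of the `clip` package.

```python
import math

def top_k_images(text_emb, image_embs, k=3):
    """Return indices of the k images whose embeddings are most similar to text_emb."""
    def unit(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v]

    t = unit(text_emb)
    # Cosine similarity between the query and every image embedding
    sims = [sum(a * b for a, b in zip(t, unit(img))) for img in image_embs]
    # Sort image indices by similarity, highest first
    ranked = sorted(range(len(image_embs)), key=lambda i: sims[i], reverse=True)
    return ranked[:k]

# Hypothetical embeddings: one text query and four "images"
query = [1.0, 0.0]
images = [[0.0, 1.0], [0.9, 0.1], [0.5, 0.5], [1.0, 0.05]]
print(top_k_images(query, images, k=2))  # [3, 1]
```

On the real tensors, `torch.topk(logits_per_text[0], k)` should give the same ranking directly, since the softmax does not change the order.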

 

 
