CLIP is an open model released by OpenAI that can perform zero-shot image classification.
import sys
!{sys.executable} -m pip install git+https://github.com/openai/CLIP.git
Check for a GPU and load the CLIP model:
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
np.set_printoptions(precision=2, suppress=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
We use the cat images from the Oxford-IIIT Pet dataset: https://www.robots.ox.ac.uk/~vgg/data/pets/
!wget https://mslearntensorflowlp.blob.core.windows.net/data/oxcats.tar.gz
!tar xfz oxcats.tar.gz
!rm oxcats.tar.gz
Zero-shot image classification:
The main thing CLIP can do is match an image against text prompts. If we take an image of, say, a cat and try to match it with the text prompts "a cat", "a penguin", "a bear", the first one is likely to receive the highest probability, so we can conclude that we are dealing with a cat. We do not need to train a model, because CLIP has already been pre-trained on a huge dataset; that is why this approach is called zero-shot.
image = preprocess(Image.open("oxcats/Maine_Coon_1.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a penguin", "a bear", "a cat"]).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs)
Output: Label probs: [[0. 0. 1.]]
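Under the hood, model(image, text) returns scaled cosine similarities between the image and text embeddings. As a minimal sketch (not in the original notebook; it reuses the image_features and text_features computed above and assumes the model exposes its learned temperature as model.logit_scale, as in the OpenAI implementation), we can reproduce the same probabilities by hand:
# Sketch: recompute the label probabilities from the embeddings directly.
# Assumes model.logit_scale holds the learned temperature (OpenAI CLIP).
with torch.no_grad():
    img_emb = image_features / image_features.norm(dim=-1, keepdim=True)
    txt_emb = text_features / text_features.norm(dim=-1, keepdim=True)
    sims = img_emb @ txt_emb.T                            # cosine similarities, shape [1, 3]
    manual_probs = (model.logit_scale.exp() * sims).softmax(dim=-1)
print("Manual probs:", manual_probs.cpu().numpy())        # should match the output above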
Intelligent Image Search
In the previous example, there was one image and three text prompts. We can also use CLIP the other way around: for example, we can take many images of cats and then select the image that best matches a textual description:
cats_img = [ Image.open(os.path.join("oxcats",x)) for x in os.listdir("oxcats") ]
cats = torch.cat([ preprocess(i).unsqueeze(0) for i in cats_img ]).to(device)
text = clip.tokenize(["a very fat gray cat"]).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(cats, text)
    res = logits_per_text.softmax(dim=-1).argmax().cpu().numpy()
print("Img Index:", res)
plt.imshow(cats_img[res])
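For a larger collection it is cheaper to encode all the images once and then rank them against any text query by cosine similarity. Below is a small sketch under that assumption; the search helper and the choice of k are hypothetical, not part of the original notebook:
# Sketch: precompute image embeddings once, then rank all images against an
# arbitrary text query and show the top-k matches.
with torch.no_grad():
    img_index = model.encode_image(cats)
    img_index = img_index / img_index.norm(dim=-1, keepdim=True)

def search(query, k=3):
    with torch.no_grad():
        txt_emb = model.encode_text(clip.tokenize([query]).to(device))
        txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
        scores = (img_index @ txt_emb.T).squeeze(1)       # one similarity score per image
        return scores.topk(k).indices.cpu().numpy()

for i in search("a very fat gray cat"):
    plt.figure()
    plt.imshow(cats_img[i])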