CLIP模型 - 图文对齐与检索
掌握CLIP模型,实现图文检索与零样本分类
前置知识:需要先掌握 Transformer基础
本文重点:CLIP原理与应用
一、CLIP概述
1.1 核心思想
CLIP (Contrastive Language-Image Pre-training)
核心思想:
- 图像和文本映射到同一向量空间
- 匹配的图文对(正样本)在向量空间中距离近、余弦相似度高
- 不匹配的图文对(负样本)在向量空间中距离远、余弦相似度低
训练方式:
- 对比学习 (Contrastive Learning)
- 4亿图文对预训练
- 零样本迁移能力强
1.2 使用HuggingFace
from transformers import CLIPModel, CLIPProcessor
import torch
from PIL import Image
# Load the pretrained CLIP model and its matching processor
# (the processor wraps both the image preprocessor and the text tokenizer).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Image encoding.
# Open the file in a context manager so the handle is released; convert()
# forces the pixel data to load and returns an independent in-memory image.
with Image.open("image.jpg") as img:
    image = img.convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    image_features = model.get_image_features(**inputs)
# L2-normalize so that the dot product below is a cosine similarity.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)

# Text encoding.
text = "a photo of a cat"
inputs = processor(text=[text], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = model.get_text_features(**inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Similarity between the single image and the single text prompt.
# .item() is valid here because the product holds exactly one element.
similarity = (image_features @ text_features.T).item()
print(f"相似度: {similarity:.4f}")
二、零样本分类
def zero_shot_classification(image, class_names):
    """Classify an image against candidate labels without any fine-tuning.

    Each label is turned into the prompt ``"a photo of a {name}"`` and scored
    against the image with the module-level CLIP ``model`` / ``processor``.

    Args:
        image: Image passed to the processor as ``images=``. The result
            describes the first image only (row 0 of ``logits_per_image``).
        class_names: Candidate label strings; must not be empty.

    Returns:
        dict with
        ``"class"``: the label with the highest probability, and
        ``"probabilities"``: list of softmax probabilities, one per label,
        in the same order as ``class_names``.

    Raises:
        ValueError: If ``class_names`` is empty.
    """
    class_names = list(class_names)
    if not class_names:
        # Fail early with a clear message instead of an opaque tokenizer error.
        raise ValueError("class_names must contain at least one label")

    text_prompts = [f"a photo of a {name}" for name in class_names]
    inputs = processor(
        text=text_prompts,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over dim=1 (the label axis) turns similarity logits into
    # a probability distribution per image.
    probs = outputs.logits_per_image.softmax(dim=1)
    # Index row 0 before argmax so the predicted label and the returned
    # probabilities always refer to the same image.
    first_image_probs = probs[0]
    best = first_image_probs.argmax().item()
    return {
        "class": class_names[best],
        "probabilities": first_image_probs.tolist(),
    }
# Usage example: score a local photo against four candidate labels
# and print the most likely one.
image = Image.open("photo.jpg")
class_names = ["cat", "dog", "bird", "fish"]
result = zero_shot_classification(image=image, class_names=class_names)
print(f"预测类别: {result['class']}")
三、图文检索
import faiss
import numpy as np
class ImageTextRetriever:
    """Text-to-image retrieval built on CLIP embeddings and a FAISS index.

    Images are embedded once with ``build_index``; text queries are embedded
    on demand and matched by inner product. Both embeddings are
    L2-normalized, so the inner product equals cosine similarity.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model and processor.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``
                for both the model and the processor.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.index = None  # FAISS index; created by build_index()
        self.image_ids = []  # image_ids[i] identifies row i of the index

    def encode_images(self, images):
        """Return L2-normalized CLIP image embeddings as a NumPy array."""
        inputs = self.processor(images=images, return_tensors="pt", padding=True)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        features = features / features.norm(dim=-1, keepdim=True)
        # .cpu() is a no-op on CPU tensors and keeps .numpy() valid otherwise.
        return features.cpu().numpy()

    def encode_text(self, text):
        """Return the L2-normalized CLIP embedding of one text string.

        The string is wrapped in a one-element batch, so the result has a
        leading batch axis of length 1.
        """
        inputs = self.processor(text=[text], return_tensors="pt", padding=True)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy()

    def build_index(self, images, image_ids):
        """Embed ``images`` and (re)build the inner-product index.

        Args:
            images: Images to index, in the same order as ``image_ids``.
            image_ids: One identifier per image.

        Raises:
            ValueError: If the number of ids differs from the number of
                embedded images.
        """
        features = self.encode_images(images).astype('float32')
        image_ids = list(image_ids)  # copy so later caller mutations don't desync
        if features.shape[0] != len(image_ids):
            raise ValueError(
                f"got {features.shape[0]} image embeddings but "
                f"{len(image_ids)} image_ids"
            )
        # Build into a local first so a failure leaves the old state intact.
        index = faiss.IndexFlatIP(features.shape[1])
        index.add(features)
        self.index = index
        self.image_ids = image_ids

    def search_by_text(self, query, top_k=5):
        """Return up to ``top_k`` ``(image_id, score)`` pairs for a text query.

        Results come in the order returned by the index search.

        Raises:
            RuntimeError: If ``build_index`` has not been called yet.
        """
        if self.index is None:
            raise RuntimeError("index is empty; call build_index() first")
        # Never ask for more neighbours than are indexed: FAISS pads missing
        # results with -1, which would silently select image_ids[-1].
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        text_features = self.encode_text(query).astype('float32')
        scores, indices = self.index.search(text_features, k)
        return [
            (self.image_ids[i], scores[0][j])
            for j, i in enumerate(indices[0])
            if i >= 0  # defensive: skip any -1 padding
        ]
参考资源
返回:多模态大模型 最后更新: 2026年4月20日
讨论与反馈