ailearn

CLIP模型 - 图文对齐与检索

掌握CLIP模型,实现图文检索与零样本分类

访问-- -- --

前置知识:需要先掌握 Transformer基础

本文重点:CLIP原理与应用


一、CLIP概述

1.1 核心思想

CLIP (Contrastive Language-Image Pre-training)
核心思想:
- 图像和文本映射到同一向量空间
- 匹配的图文对(一张图像与描述它的文本)在向量空间中距离近
- 不匹配的图文对距离远
训练方式:
- 对比学习 (Contrastive Learning):在一个批次的 N 个图文对中,最大化 N 个正确配对的余弦相似度,同时最小化其余 N² − N 个错误配对的相似度
- 在约4亿个从互联网收集的图文对(WIT 数据集)上预训练
- 零样本迁移能力强

1.2 使用HuggingFace

from transformers import CLIPModel, CLIPProcessor
import torch
from PIL import Image
# Load the pretrained CLIP checkpoint and the preprocessor that matches it
# (tokenizer for text, resize/normalize transforms for images).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

text = "a photo of a cat"
image = Image.open("image.jpg")

with torch.no_grad():
    # Image tower: embed the picture, then scale it to unit L2 length.
    inputs = processor(images=image, return_tensors="pt")
    image_features = model.get_image_features(**inputs)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Text tower: same treatment for the single caption.
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    text_features = model.get_text_features(**inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Both embeddings are unit length, so their dot product is the cosine
# similarity; one image and one caption give a 1x1 result, read out as a float.
similarity = (image_features @ text_features.T).item()
print(f"相似度: {similarity:.4f}")

二、零样本分类

def zero_shot_classification(image, class_names, prompt_template="a photo of a {}"):
    """Classify an image zero-shot by scoring it against one text prompt per class.

    Each class name is inserted into ``prompt_template``. The module-level CLIP
    ``model`` scores the image against every prompt, and a softmax over those
    scores gives one probability per class.

    Args:
        image: A single image accepted by the module-level ``processor``
            (for example a PIL image).
        class_names: Candidate labels (any sequence/iterable of str); the
            returned probabilities follow this order.
        prompt_template: ``str.format`` template with one ``{}`` slot for the
            class name. The default reproduces the original prompt wording.

    Returns:
        dict with ``"class"`` (the highest-probability label) and
        ``"probabilities"`` (list of floats aligned with ``class_names``).

    Raises:
        ValueError: if ``class_names`` is empty.
    """
    labels = list(class_names)
    if not labels:
        raise ValueError("class_names must contain at least one label")

    text_prompts = [prompt_template.format(name) for name in labels]

    inputs = processor(
        text=text_prompts,
        images=image,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)
        # logits_per_image: one row per image, one column per text prompt.
        probs = outputs.logits_per_image.softmax(dim=1)

    # Report the first image's row, and take the argmax on that same row so
    # "class" and "probabilities" always describe the same image. A flattened
    # argmax over a multi-image batch could index past class_names.
    image_probs = probs[0]
    return {
        "class": labels[int(image_probs.argmax())],
        "probabilities": image_probs.tolist(),
    }
# Usage example: classify a local photo against a handful of candidate labels.
class_names = ["cat", "dog", "bird", "fish"]
image = Image.open("photo.jpg")

result = zero_shot_classification(image=image, class_names=class_names)
predicted_label = result["class"]
print(f"预测类别: {predicted_label}")

三、图文检索

import faiss
import numpy as np
class ImageTextRetriever:
    """Text-to-image retrieval over CLIP embeddings stored in a FAISS index.

    Images are embedded with CLIP's image encoder and L2-normalized. They are
    stored in a flat inner-product index (``faiss.IndexFlatIP``). Because the
    vectors have unit length, inner product equals cosine similarity. A text
    query is embedded and normalized the same way, and the most similar images
    are returned.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model and processor.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``;
                the default keeps the original checkpoint.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.index = None  # set by build_index()
        self.image_ids = []  # image_ids[i] labels row i of the index

    def encode_images(self, images):
        """Return L2-normalized CLIP image embeddings as a NumPy array (one row per image)."""
        # `padding` is a tokenizer option; image inputs need none, so it is not passed.
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
            features = features / features.norm(dim=-1, keepdim=True)
        # .cpu() keeps this working if the model has been moved to a GPU.
        return features.cpu().numpy()

    def encode_text(self, text):
        """Return the L2-normalized CLIP embedding of one text string as a (1, dim) NumPy array."""
        inputs = self.processor(text=[text], return_tensors="pt", padding=True)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
            features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy()

    def build_index(self, images, image_ids):
        """Embed ``images`` and (re)build the search index.

        Args:
            images: Images accepted by the CLIP processor.
            image_ids: One identifier per image, in the same order.

        Raises:
            ValueError: if the number of ids differs from the number of
                embedded images. The existing index is then left untouched.
        """
        ids = list(image_ids)  # copy so later caller mutations don't desync the index
        # FAISS requires C-contiguous float32 input.
        features = np.ascontiguousarray(self.encode_images(images), dtype=np.float32)
        if features.shape[0] != len(ids):
            raise ValueError(
                f"got {features.shape[0]} image embeddings but {len(ids)} image_ids"
            )
        index = faiss.IndexFlatIP(features.shape[1])
        index.add(features)
        self.index = index
        self.image_ids = ids

    def search_by_text(self, query, top_k=5):
        """Return up to ``top_k`` ``(image_id, score)`` pairs most similar to ``query``.

        Results are ordered by decreasing cosine similarity, and scores are
        Python floats. Fewer than ``top_k`` pairs are returned when the index
        holds fewer images. ``top_k <= 0`` returns an empty list.

        Raises:
            RuntimeError: if ``build_index`` has not been called yet.
        """
        if self.index is None:
            raise RuntimeError("build_index() must be called before search_by_text()")
        # Without clamping, FAISS pads missing results with index -1, and
        # image_ids[-1] would silently report the last image as a hit.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        text_features = np.ascontiguousarray(self.encode_text(query), dtype=np.float32)
        scores, indices = self.index.search(text_features, k)
        return [
            (self.image_ids[int(idx)], float(score))
            for score, idx in zip(scores[0], indices[0])
            if idx != -1  # defensive: -1 marks "no result" in FAISS
        ]

参考资源


返回多模态大模型 最后更新: 2026年4月20日

访问 --

讨论与反馈