# Files
# gists/codegrab/vqa3.py
# 2023-03-21 16:18:50 +01:00
#
# 85 lines
# 2.8 KiB
# Python

import os
import argparse
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from collections import defaultdict
def classify_images(model_name, image_paths, class_names):
    """Zero-shot classify images with CLIP and print them grouped by label.

    Every image is scored against all entries of ``class_names`` and filed
    under the label with the highest softmax probability. An image that
    cannot be opened or processed is reported on stdout and skipped, so one
    bad file does not abort the whole run.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            CLIP model and its processor.
        image_paths: Iterable of image file paths to classify.
        class_names: List of text labels to choose between.

    Returns:
        A dict mapping each predicted label to the list of image paths
        assigned to it. Labels that received no image are absent. An empty
        dict is returned when ``image_paths`` is empty.
    """
    image_paths = list(image_paths)
    if not image_paths:
        # Nothing to classify: skip loading the model entirely.
        return {}

    # Local import: the file's import block does not bring torch into scope,
    # and the processor is already asked for PyTorch tensors ("pt") below.
    import torch

    # Load CLIP model and processor
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    classification_results = defaultdict(list)

    # Perform zero-shot classification on each image
    for image_path in image_paths:
        try:
            # The context manager closes the underlying file handle once the
            # processor has turned the pixels into tensors.
            with Image.open(image_path) as image:
                inputs = processor(
                    text=class_names,
                    images=image,
                    return_tensors="pt",
                    padding=True,
                )
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                outputs = model(**inputs)
            # One row per image, one column per label.
            probs = outputs.logits_per_image.softmax(dim=1)
            pred_label = class_names[probs.argmax(dim=1).item()]
        except Exception as e:
            # Deliberate best-effort: report the failure and keep going.
            print(f"Skipping {image_path} due to error: {e}")
            continue
        classification_results[pred_label].append(image_path)

    for label, images in classification_results.items():
        print(f"{label}:")
        for image_path in images:
            print(f" {image_path}")

    return dict(classification_results)
def main(argv=None):
    """Parse command-line arguments and classify the requested images.

    Directories given on the command line are expanded to the regular files
    directly inside them (not recursive); sub-directories are ignored.
    Passing ``--model list`` prints the known model names and exits without
    requiring any path.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.
    """
    available_models = [
        "openai/clip-vit-large-patch14",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-base-patch16",
    ]
    default_categories = ["image is safe for work", "image is not safe for work"]

    parser = argparse.ArgumentParser(description="CLIP-based Image Classifier")
    parser.add_argument(
        "--model",
        type=str,
        default="openai/clip-vit-base-patch16",
        help="Model name to use for classification (default: openai/clip-vit-base-patch16)",
    )
    # default=None on purpose: with action="append", argparse appends the
    # user's values to a non-empty default list instead of replacing it, so
    # the fallback categories are applied after parsing.
    parser.add_argument(
        "-c",
        "--category",
        action="append",
        default=None,
        help="Add a classification category (e.g., 'man', 'woman', 'child', 'animal'). If not specified, the default categories will be 'safe for work' and 'not safe for work'.",
    )
    # nargs="*" so that "--model list" works without a path; the "at least
    # one path" requirement is enforced below for real classification runs.
    parser.add_argument(
        "paths",
        metavar="path",
        type=str,
        nargs="*",
        help="List of image file paths or directories",
    )
    args = parser.parse_args(argv)

    if args.model.lower() == "list":
        print("Available models:")
        for model in available_models:
            print(f" {model}")
        return

    if not args.paths:
        # Prints usage and exits with status 2, like a missing positional.
        parser.error("at least one image file or directory path is required")

    categories = args.category or default_categories

    image_paths = []
    for path in args.paths:
        if os.path.isdir(path):
            # Sorted for a reproducible order; keep regular files only.
            entries = (os.path.join(path, name) for name in sorted(os.listdir(path)))
            image_paths.extend(entry for entry in entries if os.path.isfile(entry))
        elif os.path.isfile(path):
            image_paths.append(path)
        else:
            print(f"Skipping {path}, not a valid file or directory")

    if not image_paths:
        print("No image files to classify.")
        return

    classify_images(args.model, image_paths, categories)
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()