"""Zero-shot image classification from the command line using CLIP."""

import argparse
import os
from collections import defaultdict

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Labels used when the user supplies no -c/--category option.
DEFAULT_CATEGORIES = ["image is safe for work", "image is not safe for work"]

# Model names printed by `--model list`.
AVAILABLE_MODELS = [
    "openai/clip-vit-large-patch14",
    "openai/clip-vit-base-patch32",
    "openai/clip-vit-base-patch16",
]


def classify_images(model_name, image_paths, class_names):
    """Assign each image its most likely label and print the grouped result.

    Args:
        model_name: Model identifier passed to ``CLIPModel.from_pretrained``
            and ``CLIPProcessor.from_pretrained``.
        image_paths: Iterable of image file paths to classify.
        class_names: List of text labels to choose between.

    Returns:
        A dict mapping each predicted label to the list of image paths that
        received it. Images that could not be processed are reported on
        stdout and left out of the result.
    """
    image_paths = list(image_paths)
    classification_results = defaultdict(list)

    # Nothing to classify: skip loading the model entirely.
    if not image_paths:
        return dict(classification_results)

    # Load CLIP model and processor
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)

    # Perform zero-shot classification on each image
    for image_path in image_paths:
        try:
            # The context manager closes the underlying file once the
            # processor has turned the image into tensors.
            with Image.open(image_path) as image:
                inputs = processor(
                    text=class_names,
                    images=image,
                    return_tensors="pt",
                    padding=True,
                )

            # Inference only: no gradients are needed.
            with torch.no_grad():
                outputs = model(**inputs)

            # Convert the image-to-text logits to probabilities and take
            # the best-scoring label.
            probs = outputs.logits_per_image.softmax(dim=1)
            pred_label = class_names[probs.argmax(dim=1).item()]
        except Exception as e:
            # Deliberately broad: one unreadable or unsupported file must
            # not abort the whole batch.
            print(f"Skipping {image_path} due to error: {e}")
            continue

        classification_results[pred_label].append(image_path)

    for label, images in classification_results.items():
        print(f"{label}:")
        for image_path in images:
            print(f"  {image_path}")

    return dict(classification_results)


def main():
    """Parse command-line arguments and classify the requested images."""
    parser = argparse.ArgumentParser(description="CLIP-based Image Classifier")
    parser.add_argument(
        "--model",
        type=str,
        default="openai/clip-vit-base-patch16",
        help="Model name to use for classification, or 'list' to print the "
             "available models (default: openai/clip-vit-base-patch16)",
    )
    # default=None on purpose: with action="append", argparse appends user
    # values to a list default instead of replacing it, so the fallback
    # labels are applied after parsing.
    parser.add_argument(
        "-c",
        "--category",
        action="append",
        default=None,
        help="Add a classification category (e.g., 'man', 'woman', 'child', "
             "'animal'). If not specified, the default categories will be "
             "'image is safe for work' and 'image is not safe for work'.",
    )
    # nargs="*" so that `--model list` works without any path; the missing
    # path case is rejected explicitly below.
    parser.add_argument(
        "paths",
        metavar="path",
        type=str,
        nargs="*",
        help="List of image file paths or directories",
    )

    args = parser.parse_args()

    if args.model.lower() == "list":
        print("Available models:")
        for model in AVAILABLE_MODELS:
            print(f"  {model}")
        return

    if not args.paths:
        parser.error("the following arguments are required: path")

    categories = args.category or list(DEFAULT_CATEGORIES)

    image_paths = []
    for path in args.paths:
        if os.path.isdir(path):
            # Sorted for a reproducible order; sub-directories are not
            # descended into, so only regular files are kept.
            entries = (
                os.path.join(path, name) for name in sorted(os.listdir(path))
            )
            image_paths.extend(
                entry for entry in entries if os.path.isfile(entry)
            )
        elif os.path.isfile(path):
            image_paths.append(path)
        else:
            print(f"Skipping {path}, not a valid file or directory")

    classify_images(args.model, image_paths, categories)


if __name__ == "__main__":
    main()