diff --git a/codegrab/vqa3.py b/codegrab/vqa3.py
new file mode 100644
index 0000000..25d38ba
--- /dev/null
+++ b/codegrab/vqa3.py
@@ -0,0 +1,93 @@
+import os
+import argparse
+from PIL import Image
+from transformers import CLIPProcessor, CLIPModel
+from collections import defaultdict
+
+
+def classify_images(model_name, image_paths, class_names):
+    # Load CLIP model and processor
+    model = CLIPModel.from_pretrained(model_name)
+    processor = CLIPProcessor.from_pretrained(model_name)
+
+    classification_results = defaultdict(list)
+
+    # Perform zero-shot classification on each image
+    for image_path in image_paths:
+        try:
+            image = Image.open(image_path)
+
+            # Process the input image and text labels
+            inputs = processor(
+                text=class_names,
+                images=image,
+                return_tensors="pt",
+                padding=True
+            )
+
+            # Run the model and get logits
+            outputs = model(**inputs)
+            logits_per_image = outputs.logits_per_image
+
+            # Calculate probabilities
+            probs = logits_per_image.softmax(dim=1)
+
+            # Get the predicted label
+            pred_label = class_names[probs.argmax(dim=1).item()]
+
+            classification_results[pred_label].append(image_path)
+
+        except Exception as e:
+            print(f"Skipping {image_path} due to error: {e}")
+
+    for label, images in classification_results.items():
+        print(f"{label}:")
+        for image_path in images:
+            print(f"  {image_path}")
+
+
+def main():
+    available_models = [
+        "openai/clip-vit-large-patch14",
+        "openai/clip-vit-base-patch32",
+        "openai/clip-vit-base-patch16"
+    ]
+
+    parser = argparse.ArgumentParser(description="CLIP-based Image Classifier")
+    parser.add_argument("--model", type=str, default="openai/clip-vit-base-patch16",
+                        help="Model name to use for classification (default: openai/clip-vit-base-patch16); "
+                             "pass 'list' to print the available models")
+    parser.add_argument("-c", "--category", action="append", default=None,
+                        help="Add a classification category (e.g., 'man', 'woman', 'child', 'animal'). "
+                             "May be given multiple times. If not specified, the default categories "
+                             "'image is safe for work' and 'image is not safe for work' are used.")
+    parser.add_argument("paths", metavar="path", type=str, nargs="+",
+                        help="List of image file paths or directories")
+
+    args = parser.parse_args()
+
+    # With action="append", a non-None default would be extended rather than
+    # replaced, so the default categories are applied only when none are given.
+    if args.category is None:
+        args.category = ["image is safe for work", "image is not safe for work"]
+
+    if args.model.lower() == "list":
+        print("Available models:")
+        for model in available_models:
+            print(f"  {model}")
+        return
+
+    image_paths = []
+    for path in args.paths:
+        if os.path.isdir(path):
+            image_paths.extend([os.path.join(path, file) for file in os.listdir(path)])
+        elif os.path.isfile(path):
+            image_paths.append(path)
+        else:
+            print(f"Skipping {path}, not a valid file or directory")
+
+    classify_images(args.model, image_paths, args.category)
+
+
+if __name__ == "__main__":
+    main()
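
For context, a minimal usage sketch of the new module, assuming codegrab is importable as a package; the image paths and category prompts below are placeholders, not part of this change:

    # Hypothetical example: call classify_images directly instead of via the CLI.
    # "images/cat.jpg", "images/dog.jpg" and the prompts are placeholder values.
    from codegrab.vqa3 import classify_images

    classify_images(
        model_name="openai/clip-vit-base-patch16",
        image_paths=["images/cat.jpg", "images/dog.jpg"],
        class_names=["a photo of a cat", "a photo of a dog"],
    )

The function prints each predicted category followed by the image paths grouped under it and returns nothing.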