added imagezeroshotcategory
This commit is contained in:
84
codegrab/vqa3.py
Normal file
84
codegrab/vqa3.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
import argparse
|
||||
from PIL import Image
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def classify_images(model_name, image_paths, class_names):
|
||||
# Load CLIP model and processor
|
||||
model = CLIPModel.from_pretrained(model_name)
|
||||
processor = CLIPProcessor.from_pretrained(model_name)
|
||||
|
||||
classification_results = defaultdict(list)
|
||||
|
||||
# Perform zero-shot classification on each image
|
||||
for image_path in image_paths:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
|
||||
# Process the input image and text labels
|
||||
inputs = processor(
|
||||
text=class_names,
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
padding=True
|
||||
)
|
||||
|
||||
# Run the model and get logits
|
||||
outputs = model(**inputs)
|
||||
logits_per_image = outputs.logits_per_image
|
||||
|
||||
# Calculate probabilities
|
||||
probs = logits_per_image.softmax(dim=1)
|
||||
|
||||
# Get the predicted label
|
||||
pred_label = class_names[probs.argmax(dim=1).item()]
|
||||
|
||||
classification_results[pred_label].append(image_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Skipping {image_path} due to error: {e}")
|
||||
|
||||
for label, images in classification_results.items():
|
||||
print(f"{label}:")
|
||||
for image_path in images:
|
||||
print(f" {image_path}")
|
||||
|
||||
|
||||
def main():
|
||||
available_models = [
|
||||
"openai/clip-vit-large-patch14",
|
||||
"openai/clip-vit-base-patch32",
|
||||
"openai/clip-vit-base-patch16"
|
||||
]
|
||||
|
||||
parser = argparse.ArgumentParser(description="CLIP-based Image Classifier")
|
||||
parser.add_argument("--model", type=str, default="openai/clip-vit-base-patch16",
|
||||
help="Model name to use for classification (default: openai/clip-vit-base-patch16)")
|
||||
parser.add_argument("-c", "--category", action="append",default=["image is safe for work", "image is not safe for work"],help="Add a classification category (e.g., 'man', 'woman', 'child', 'animal'). If not specified, the default categories will be 'safe for work' and 'not safe for work'.")
|
||||
parser.add_argument("paths", metavar="path", type=str, nargs="+",
|
||||
help="List of image file paths or directories")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model.lower() == "list":
|
||||
print("Available models:")
|
||||
for model in available_models:
|
||||
print(f" {model}")
|
||||
return
|
||||
|
||||
image_paths = []
|
||||
for path in args.paths:
|
||||
if os.path.isdir(path):
|
||||
image_paths.extend([os.path.join(path, file) for file in os.listdir(path)])
|
||||
elif os.path.isfile(path):
|
||||
image_paths.append(path)
|
||||
else:
|
||||
print(f"Skipping {path}, not a valid file or directory")
|
||||
|
||||
classify_images(args.model, image_paths, args.category)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user