85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
import os
|
|
import argparse
|
|
from PIL import Image
|
|
from transformers import CLIPProcessor, CLIPModel
|
|
from collections import defaultdict
|
|
|
|
|
|
def classify_images(model_name, image_paths, class_names):
    """Zero-shot classify images with a CLIP model and print them grouped by label.

    Every image is scored against each entry of ``class_names`` (used as a text
    prompt). The prompt with the highest softmax probability becomes the image's
    predicted label. Images that cannot be opened or processed are reported on
    stdout and skipped, so a directory containing non-image files still works.

    Args:
        model_name: Model id passed to ``CLIPModel.from_pretrained`` and
            ``CLIPProcessor.from_pretrained``.
        image_paths: Iterable of image file paths to classify.
        class_names: Sequence of text labels/prompts to choose between.

    Returns:
        dict mapping each predicted label to the list of image paths assigned to
        it, in processing order. Labels that received no images are absent.

    Raises:
        ValueError: if ``class_names`` is empty.
    """
    import torch  # local import: only needed to disable autograd during inference

    class_names = list(class_names)
    if not class_names:
        # Without labels there is nothing to choose between; fail before the
        # (potentially large) model is downloaded/loaded.
        raise ValueError("class_names must contain at least one label")

    # Load CLIP model and processor once and reuse them for every image.
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    model.eval()

    classification_results = defaultdict(list)

    # Perform zero-shot classification on each image
    for image_path in image_paths:
        try:
            # The `with` block closes the file handle. convert("RGB") forces the
            # pixel data to load before the file is closed, and normalises
            # palette/greyscale/alpha images to three channels.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

            # Process the input image and text labels
            inputs = processor(
                text=class_names,
                images=image,
                return_tensors="pt",
                padding=True,
            )

            # Inference only: skip building the autograd graph to save memory.
            with torch.no_grad():
                outputs = model(**inputs)

            # A single image is passed in, so logits_per_image has one row with
            # one column per class name. Softmax over that axis gives the
            # probabilities, and argmax(...).item() picks the best column index.
            probs = outputs.logits_per_image.softmax(dim=1)
            pred_label = class_names[probs.argmax(dim=1).item()]

            classification_results[pred_label].append(image_path)
        except Exception as e:
            # Deliberate best-effort at this per-file boundary: report the
            # failure and continue with the remaining images.
            print(f"Skipping {image_path} due to error: {e}")

    for label, images in classification_results.items():
        print(f"{label}:")
        for image_path in images:
            print(f" {image_path}")

    return dict(classification_results)
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments and classify the given images.

    ``--model list`` prints the known model ids and returns without needing any
    paths. Otherwise, each positional path may be a file or a directory.
    Directories are scanned non-recursively for regular files, in sorted order.
    All collected files are passed to ``classify_images`` together with the
    chosen categories.
    """
    available_models = [
        "openai/clip-vit-large-patch14",
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-base-patch16",
    ]
    default_categories = ["image is safe for work", "image is not safe for work"]

    parser = argparse.ArgumentParser(description="CLIP-based Image Classifier")
    parser.add_argument("--model", type=str, default="openai/clip-vit-base-patch16",
                        help="Model name to use for classification (default: openai/clip-vit-base-patch16)")
    # default=None on purpose: with action="append", a list default is *extended*
    # by each -c value instead of being replaced. The defaults are applied below,
    # and only when no -c was given.
    parser.add_argument("-c", "--category", action="append", default=None,
                        help="Add a classification category (e.g., 'man', 'woman', 'child', 'animal'). If not specified, the default categories will be 'safe for work' and 'not safe for work'.")
    # nargs="*" so that `--model list` works without a dummy path. The
    # "at least one path" requirement is enforced manually below.
    parser.add_argument("paths", metavar="path", type=str, nargs="*",
                        help="List of image file paths or directories")

    args = parser.parse_args()

    if args.model.lower() == "list":
        print("Available models:")
        for model in available_models:
            print(f" {model}")
        return

    if not args.paths:
        parser.error("at least one path is required")

    categories = args.category if args.category else default_categories

    image_paths = []
    for path in args.paths:
        if os.path.isdir(path):
            # Sorted for deterministic output; subdirectories are skipped
            # (no recursion) rather than being handed to the image loader.
            for name in sorted(os.listdir(path)):
                full_path = os.path.join(path, name)
                if os.path.isfile(full_path):
                    image_paths.append(full_path)
        elif os.path.isfile(path):
            image_paths.append(path)
        else:
            print(f"Skipping {path}, not a valid file or directory")

    if not image_paths:
        # Avoid downloading/loading a model when there is nothing to classify.
        print("No image files found")
        return

    classify_images(args.model, image_paths, categories)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()
|