#!/usr/bin/env python3

"""
`what` - README-driven repository search using Ollama only.

Usage:
    what <query>              # Find tools matching a natural-language query
    what -l                   # List catalogued tools
    what --model <model> ...  # Override the default Ollama model
"""

from __future__ import annotations

import argparse
import json
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

# Directory containing the real script file. resolve() runs BEFORE .parent so that
# invoking the script through a symlink (e.g. one placed in a directory on PATH)
# still locates the repository that holds the actual file and its README.
REPO_ROOT = Path(__file__).resolve().parent
README_PATH = REPO_ROOT / "README.md"
# Model used when --model is not given; overridable via the WHAT_OLLAMA_MODEL env var.
DEFAULT_MODEL = os.environ.get("WHAT_OLLAMA_MODEL", "gemma4")
# Exact heading line (after trailing-whitespace strip) that opens the catalog section.
CATALOG_HEADING = "## Tool Catalog"
# One catalog entry per line: - `path` | goal: <goal> | usage: <usage>
ENTRY_RE = re.compile(
    r"^- `([^`]+)` \| goal: (.*?) \| usage: (.*)$"
)
# Tokens are runs of lowercase ASCII letters, digits and `_ . + -` (callers lowercase first).
TOKEN_RE = re.compile(r"[a-z0-9_.+-]+")


class WhatError(Exception):
    """Expected, user-facing failure of the `what` tool.

    Raised for a missing README or catalog, Ollama problems and unparseable
    model output; ``main`` reports it as ``Error: <message>`` on stderr and
    exits with status 1.
    """


def load_readme(path: Path | str | None = None) -> str:
    """Read the README that holds the tool catalog.

    Args:
        path: README file to read. Defaults to ``README_PATH`` (``README.md``
            next to this script).

    Returns:
        The file contents decoded as UTF-8.

    Raises:
        WhatError: if the file does not exist, cannot be read (e.g. it is a
            directory or permission is denied), or is not valid UTF-8.
    """
    readme = README_PATH if path is None else Path(path)
    # EAFP: a separate exists() check races with the read and misses
    # directories, permission errors and decoding errors.
    try:
        return readme.read_text(encoding="utf-8")
    except FileNotFoundError:
        raise WhatError(f"README not found at {readme}") from None
    except (OSError, UnicodeDecodeError) as exc:
        raise WhatError(f"Could not read README at {readme}: {exc}") from exc


def extract_catalog(readme_text: str) -> list[dict[str, str]]:
    """Parse tool entries from the README's catalog section.

    The section starts at the line equal to ``CATALOG_HEADING`` (trailing
    whitespace ignored). It ends at the next level-1 or level-2 heading.
    Deeper subheadings (``### ...``), prose and blank lines inside the section
    are skipped. Only lines matching ``ENTRY_RE`` become entries.

    Args:
        readme_text: Full README contents.

    Returns:
        Entries in README order, each ``{"path", "goal", "usage"}`` with goal
        and usage stripped of surrounding whitespace.

    Raises:
        WhatError: if the section is absent or contains no matching entries.
    """
    in_catalog = False
    entries: list[dict[str, str]] = []

    for raw_line in readme_text.splitlines():
        line = raw_line.rstrip()

        if line == CATALOG_HEADING:
            in_catalog = True
            continue

        if not in_catalog:
            continue

        # A heading of the same or higher level closes the section. "# " must
        # count too; otherwise entry-shaped lines under a later top-level
        # section would be treated as catalog entries. "### " matches neither
        # prefix, so subsections stay inside the catalog.
        if line.startswith(("# ", "## ")):
            break

        match = ENTRY_RE.match(line)
        if match is None:
            continue

        path, goal, usage = match.groups()
        entries.append(
            {
                "path": path,
                "goal": goal.strip(),
                "usage": usage.strip(),
            }
        )

    if not entries:
        raise WhatError(
            "No tool catalog entries found in README. "
            f"Expected entries under '{CATALOG_HEADING}'."
        )

    return entries


def ensure_ollama_available(model: str) -> None:
    """Check that the `ollama` CLI works and that *model* is installed locally.

    Args:
        model: Model name as it will be passed to ``ollama run`` (an untagged
            name means ``<name>:latest``). Compared case-insensitively.

    Raises:
        WhatError: if `ollama` is not on PATH, cannot be started, times out
            (10 s), exits non-zero, or does not list *model*.
    """
    if not shutil_which("ollama"):
        raise WhatError("`ollama` is not installed or not in PATH.")

    try:
        result = subprocess.run(
            ["ollama", "list"],
            capture_output=True,
            text=True,
            # Decode explicitly; the locale default may not be UTF-8.
            encoding="utf-8",
            errors="replace",
            timeout=10,
            check=False,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        # OSError covers the binary vanishing or not being executable;
        # SubprocessError covers TimeoutExpired.
        raise WhatError(f"Failed to talk to Ollama: {exc}") from exc

    if result.returncode != 0:
        stderr = result.stderr.strip() or "unknown error"
        raise WhatError(f"Ollama is unavailable: {stderr}")

    if not _model_is_listed(model, result.stdout):
        raise WhatError(
            f"Model '{model}' is not available locally. "
            "Pull it first with `ollama pull ...`."
        )


def _model_is_listed(model: str, listing: str) -> bool:
    """Return True if *model* exactly names a model in `ollama list` output."""
    wanted = model.lower()
    if not wanted:
        return False
    candidates = {wanted}
    # Ollama resolves an untagged name to ":latest". The tag lives in the last
    # "/" segment, so a registry host:port prefix is not mistaken for a tag.
    if ":" not in wanted.rsplit("/", 1)[-1]:
        candidates.add(f"{wanted}:latest")
    # Assumes a "NAME ..." header row followed by one model per line with the
    # name in the first column (current Ollama CLI layout) -- verify if it changes.
    for line in listing.splitlines():
        fields = line.split()
        if not fields or fields[0].lower() == "name":
            continue
        if fields[0].lower() in candidates:
            return True
    return False


def shutil_which(binary: str) -> str | None:
    """Return the full path of executable *binary* found on PATH, or None.

    Delegates to :func:`shutil.which`. Unlike a plain scan of PATH entries,
    it honours PATHEXT on Windows (so ``ollama`` resolves to ``ollama.exe``),
    skips directories, and returns None when PATH is empty. The old scan
    instead treated an empty PATH as the current directory.
    """
    return shutil.which(binary)


def build_prompt(query: str, entries: list[dict[str, str]]) -> str:
    """Assemble the selection prompt sent to the model.

    The prompt states the required JSON schema and selection rules, followed
    by the user's query and one ``- path | goal: ... | usage: ...`` line per
    catalog entry, in the order given.
    """
    catalog = "\n".join(
        f'- {item["path"]} | goal: {item["goal"]} | usage: {item["usage"]}'
        for item in entries
    )

    return f"""You are selecting tools from a repository catalog.
Use only the catalog below. Prefer direct matches. Use archived tools only if they clearly fit the request.

Return strict JSON matching this schema exactly:
{{
  "results": [
    {{
      "path": "exact catalog path",
      "reason": "one short sentence explaining why this tool matches"
    }}
  ]
}}

Constraints:
- The "results" array must contain up to 8 objects.
- Do not invent paths.
- Prefer the entry whose action best matches the query: compare beats hash for comparison queries, open beats convert for opening queries, and mount beats inspect for mount queries.

Query: {query}

Catalog:
{catalog}
"""


def tokenize(text: str) -> set[str]:
    """Return the set of lowercase search tokens in *text*.

    Tokens are ``TOKEN_RE`` matches on the lowercased text (ASCII letters,
    digits and ``_ . + -``). Leading and trailing dots are stripped so that
    sentence punctuation ("convert to png.") still matches a catalog token
    ("png"). Interior dots (``foo.py``) are kept, and tokens that were only
    dots are dropped.
    """
    return {token.strip(".") for token in TOKEN_RE.findall(text.lower())} - {""}


def shortlist_entries(query: str, entries: list[dict[str, str]], limit: int = 100) -> list[dict[str, str]]:
    """Pre-rank catalog entries by lexical overlap with *query*.

    Each entry's "path goal usage" text (lowercased) is scored as follows:
    5 points per query token that is also an entry token, plus 1 per query
    token occurring anywhere as a substring, minus 1 for paths under
    ``archive/``. Entries scoring above zero are returned best first (ties
    keep catalog order), capped at *limit*. If the query has no tokens or
    nothing scores, the first *limit* catalog entries are returned instead.
    """
    fallback = entries[:limit]
    wanted = tokenize(query)
    if not wanted:
        return fallback

    def score_of(entry: dict[str, str]) -> int:
        text = f'{entry["path"]} {entry["goal"]} {entry["usage"]}'.lower()
        shared = len(wanted & tokenize(text))
        partial = sum(token in text for token in wanted)
        penalty = 1 if entry["path"].startswith("archive/") else 0
        return 5 * shared + partial - penalty

    # sorted() is stable even with reverse=True, so equal scores keep catalog order.
    ranked = sorted(
        ((score_of(entry), entry) for entry in entries),
        key=lambda pair: pair[0],
        reverse=True,
    )
    positive = [entry for points, entry in ranked if points > 0]
    return positive[:limit] or fallback


def extract_json_array(output: str) -> list[dict[str, str]]:
    """Parse the model's reply into a list of ``{"path", "reason"}`` dicts.

    Accepted shapes:
      * ``{"results": [...]}`` (the schema requested by ``build_prompt``),
        optionally surrounded by stray text, thinking output or markdown fences;
      * a bare JSON array of result objects, when it is the entire reply.

    Items that are not objects, or whose path is missing, null or blank, are
    dropped. A null or missing reason becomes "". Runs of whitespace in the
    reason (including newlines) collapse to single spaces.

    Raises:
        WhatError: if no JSON can be decoded, the root is neither an object
            nor an array, or ``results`` is missing/null or not an array.
    """
    data = _decode_model_json(output)

    if isinstance(data, list):
        results_list = data
    else:
        if not isinstance(data, dict):
            raise WhatError("Model output must be a root JSON object.")
        results_list = data.get("results")
        if results_list is None:
            raise WhatError("Missing 'results' key in model JSON response.")
        if not isinstance(results_list, list):
            raise WhatError("The 'results' property must be a JSON array.")

    normalized: list[dict[str, str]] = []
    for item in results_list:
        if not isinstance(item, dict):
            continue
        path = _json_text(item.get("path")).strip()
        if not path:
            continue
        reason = " ".join(_json_text(item.get("reason")).split())
        normalized.append({"path": path, "reason": reason})

    return normalized


def _json_text(value: object) -> str:
    """Coerce a decoded JSON value to str, mapping JSON null (None) to ""."""
    return "" if value is None else str(value)


def _decode_model_json(output: str) -> object:
    """Decode the first usable JSON value from raw model output.

    First tries the whole stripped reply. If that fails, each "{" is tried in
    turn as the start of an object, so prose or brace-containing text before
    the payload is skipped. Trailing junk after the object is ignored.
    ``strict=False`` tolerates literal control characters (e.g. raw newlines)
    inside JSON strings.
    """
    decoder = json.JSONDecoder(strict=False)
    text = output.strip()
    try:
        return decoder.decode(text)
    except json.JSONDecodeError as exc:
        first_error = exc

    start = text.find("{")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(value, dict):
                return value
        start = text.find("{", start + 1)

    raise WhatError(f"Failed to parse model output as JSON: {first_error}") from first_error


def run_ollama_once(prompt: str, model: str) -> str:
    """Run a single non-interactive ``ollama run`` and return its stripped stdout.

    Passes ``--format json`` so the model is constrained to JSON output, and
    ``--hidethinking`` (presumably suppresses any reasoning trace the model
    emits -- confirm against the installed Ollama version).

    Raises:
        WhatError: if the process cannot be started, exceeds the 60 s timeout,
            or exits non-zero (the message carries its stderr).
    """
    try:
        result = subprocess.run(
            ["ollama", "run", "--format", "json", "--hidethinking", model, prompt],
            capture_output=True,
            text=True,
            # Decode model output as UTF-8 regardless of the locale, replacing
            # undecodable bytes instead of crashing.
            encoding="utf-8",
            errors="replace",
            timeout=60,
            check=False,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        # OSError: binary missing/not executable; SubprocessError: TimeoutExpired.
        raise WhatError(f"Ollama run failed: {exc}") from exc

    if result.returncode != 0:
        stderr = result.stderr.strip() or "unknown error"
        raise WhatError(f"Ollama run failed: {stderr}")

    return result.stdout.strip()


def run_ollama(prompt: str, model: str) -> list[dict[str, str]]:
    """Query the model and parse its reply, with one self-repair attempt.

    If the first reply cannot be parsed, the model is asked once to rewrite
    that reply as the target JSON schema, and the rewrite is parsed instead.

    Raises:
        WhatError: if an Ollama call fails, or if the repaired reply still
            cannot be parsed (the message includes that raw reply).
    """
    recoverable = (json.JSONDecodeError, WhatError)
    raw_reply = run_ollama_once(prompt, model)
    try:
        return extract_json_array(raw_reply)
    except recoverable:
        fix_request = "\n".join(
            [
                "Rewrite the following response as strict JSON matching the target schema.",
                "Target format:",
                "{",
                '  "results": [{"path":"exact catalog path","reason":"short reason"}]',
                "}",
                "Do not add markdown or commentary.",
                "",
                "Response to repair:",
                raw_reply,
                "",
            ]
        )
        fixed_reply = run_ollama_once(fix_request, model)
        try:
            return extract_json_array(fixed_reply)
        except recoverable as exc:
            raise WhatError(
                "Model output was not valid JSON after repair. "
                f"Raw output was:\n{fixed_reply}"
            ) from exc


def search(query: str, entries: list[dict[str, str]], model: str) -> list[dict[str, str]]:
    """Ask the model which catalog entries match *query*.

    First verifies Ollama and *model* are available. The prompt is built from
    the lexical shortlist of *entries*. Model suggestions are kept only if
    their path exists in the full catalog, and only the first occurrence of
    each path is kept, in model order. Each result is a copy of the catalog
    entry with a ``"reason"`` key added.
    """
    ensure_ollama_available(model)
    candidates = shortlist_entries(query, entries)
    suggestions = run_ollama(build_prompt(query, candidates), model)
    by_path = {entry["path"]: entry for entry in entries}

    matched: list[dict[str, str]] = []
    emitted: set[str] = set()
    for suggestion in suggestions:
        key = suggestion["path"]
        if key in emitted or key not in by_path:
            continue
        emitted.add(key)
        matched.append({**by_path[key], "reason": suggestion.get("reason", "")})
    return matched


def list_entries(entries: list[dict[str, str]]) -> None:
    """Print every catalog entry as its path plus indented goal and usage lines."""
    for item in entries:
        block = (
            f'{item["path"]}\n'
            f'  goal:  {item["goal"]}\n'
            f'  usage: {item["usage"]}'
        )
        print(block)


def show_results(query: str, results: list[dict[str, str]], model: str) -> None:
    """Print search results as a numbered list, or a no-match notice.

    Each result shows its path, goal and usage. A "Why:" line is added only
    when the result carries a non-empty reason. A blank line follows every
    result.
    """
    if results:
        for header_line in (f"Model: {model}", f"Query: {query}", ""):
            print(header_line)

        for rank, hit in enumerate(results, start=1):
            print(f"{rank}. {hit['path']}")
            print(f"   Goal: {hit['goal']}")
            print(f"   Usage: {hit['usage']}")
            reason = hit.get("reason")
            if reason:
                print(f"   Why: {reason}")
            print()
    else:
        print(f"No catalogued tool matched: {query}")


def main() -> int:
    """Command-line entry point; returns the process exit status.

    Loads the catalog first (any WhatError prints ``Error: ...`` to stderr and
    yields 1). Then it does one of three things:
      * ``-l``: list the catalog;
      * a query: run a search and print the results;
      * neither: print help plus the catalog source and selected model.
    """
    cli = argparse.ArgumentParser(description="README-driven repository search using Ollama")
    cli.add_argument("query", nargs="?", help="Natural-language search query")
    cli.add_argument("-l", "--list", action="store_true", help="List catalogued tools")
    cli.add_argument(
        "--model",
        default=DEFAULT_MODEL,
        help=f"Ollama model to use (default: {DEFAULT_MODEL})",
    )
    opts = cli.parse_args()

    def fail(exc: Exception) -> int:
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    try:
        catalog = extract_catalog(load_readme())
    except WhatError as exc:
        return fail(exc)

    if opts.list:
        list_entries(catalog)
        return 0

    if opts.query:
        try:
            hits = search(opts.query, catalog, opts.model)
        except WhatError as exc:
            return fail(exc)
        show_results(opts.query, hits, opts.model)
        return 0

    cli.print_help()
    print()
    print(f"Catalog source: {README_PATH}")
    print(f"Default model: {opts.model}")
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())