import json
import re
import time
from rapidfuzz import fuzz

def register_regex(conn):
  """
  Register a two-argument REGEXP function on *conn* so SQL can use `X REGEXP Y`.

  SQLite invokes the function as regexp(Y, X): pattern first, then the value
  being tested. Matching is a case-insensitive `re.search`, so the pattern
  may match anywhere in the text.

  The function returns 1 on a match and 0 otherwise. A NULL value, a
  malformed pattern, a NULL pattern or a non-text operand all count as
  "no match" instead of aborting the surrounding query.
  """
  def regexp(pattern, text):
    if text is None:
      return 0
    try:
      matched = re.search(pattern, text, flags=re.IGNORECASE)
    except (re.error, TypeError):
      # re.error: the pattern does not compile.
      # TypeError: pattern is NULL, or an operand is not text (int, bytes...).
      # An exception escaping a UDF would fail the whole statement.
      return 0
    return 1 if matched else 0
  conn.create_function("REGEXP", 2, regexp)

def log_search(conn, mode, query, filters):
  """
  Append one row to search_history describing a search, then commit.

  The row holds the current wall-clock time in milliseconds, the mode, the
  raw query and the filters serialized as JSON (an empty object when
  *filters* is falsy).
  """
  now_ms = int(time.time() * 1000)
  filters_json = json.dumps(filters if filters else {})
  record = (now_ms, mode, query, filters_json)
  insert_sql = "INSERT INTO search_history(created_at_ms, mode, query, filters_json) VALUES (?,?,?,?)"
  conn.execute(insert_sql, record)
  conn.commit()

def _like_contains_pattern(text: str) -> str:
  # Build a LIKE pattern that finds *text* literally anywhere in a column.
  # Must be paired with ESCAPE '\' in the SQL so % and _ lose their wildcard role.
  escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
  return f"%{escaped}%"

def run_search(conn, mode: str, query: str, filters: dict, limit: int = 200):
  """
  Run one search against the messages tables and return the fetched rows.

  modes:
    exact_phrase    - FTS phrase match on the body column
    relevance       - FTS match, ordered by bm25 rank ascending
    exact_substring - literal substring match via LIKE, newest first
    regex           - REGEXP match, newest first (needs a REGEXP function
                      registered on the connection, see register_regex)
    fuzzy           - up to 1000 FTS candidates re-ranked by
                      rapidfuzz partial_ratio, best score first
  filters: address, date_from_ms, date_to_ms, type, thread_id
           (turned into extra AND clauses by build_filters)
  limit: maximum number of rows returned

  Any other mode returns an empty list. Fuzzy mode reads rows by column
  name (row["body"]), so the connection must use a row factory that
  supports that.
  """
  where, params = build_filters(filters)

  if mode == "exact_phrase":
    # FTS phrase query: wrap in double quotes and drop embedded quotes so they
    # cannot close the phrase early. Built by concatenation because a
    # backslash inside an f-string expression is a SyntaxError before 3.12.
    phrase = '"' + query.replace('"', "") + '"'
    sql = f"""
      SELECT m.*
      FROM messages_fts f
      JOIN messages m ON m.id = f.rowid
      WHERE f.body MATCH ?
      {where}
      LIMIT ?
    """
    return conn.execute(sql, [phrase, *params, limit]).fetchall()

  if mode == "relevance":
    sql = f"""
      SELECT m.*, bm25(messages_fts) AS rank
      FROM messages_fts
      JOIN messages m ON m.id = messages_fts.rowid
      WHERE messages_fts MATCH ?
      {where}
      ORDER BY rank ASC
      LIMIT ?
    """
    return conn.execute(sql, [query, *params, limit]).fetchall()

  if mode == "exact_substring":
    # Escape LIKE wildcards so a query such as "100%" or "a_b" is matched
    # literally instead of acting as a pattern.
    sql = f"""
      SELECT m.*
      FROM messages m
      WHERE m.body LIKE ? ESCAPE '\\'
      {where}
      ORDER BY m.date_ms DESC
      LIMIT ?
    """
    like_arg = _like_contains_pattern(query)
    return conn.execute(sql, [like_arg, *params, limit]).fetchall()

  if mode == "regex":
    # Narrow first with LIKE if possible (cheap heuristic), then REGEXP.
    hint = extract_literal_hint(query)
    if hint:
      sql = f"""
        SELECT m.* FROM messages m
        WHERE m.body LIKE ?
        {where} AND m.body REGEXP ?
        ORDER BY m.date_ms DESC
        LIMIT ?
      """
      args = [f"%{hint}%", *params, query, limit]
    else:
      sql = f"""
        SELECT m.* FROM messages m
        WHERE m.body REGEXP ?
        {where}
        ORDER BY m.date_ms DESC
        LIMIT ?
      """
      args = [query, *params, limit]
    return conn.execute(sql, args).fetchall()

  if mode == "fuzzy":
    # Narrow with an FTS match first to keep it fast, then fuzzy-score the
    # capped candidate set in Python.
    candidates = conn.execute(f"""
      SELECT m.*
      FROM messages_fts
      JOIN messages m ON m.id = messages_fts.rowid
      WHERE messages_fts MATCH ?
      {where}
      LIMIT 1000
    """, [query, *params]).fetchall()

    needle = query.lower().strip()
    # sorted() is stable, so equally scored rows keep their candidate order.
    ranked = sorted(
      candidates,
      key=lambda row: fuzz.partial_ratio(needle, (row["body"] or "").lower()),
      reverse=True,
    )
    return ranked[:limit]

  return []

def build_filters(filters: dict):
  """
  Translate the optional filter dict into SQL AND-clauses plus bound values.

  Returns (where, params): *where* is a string of "AND ..." lines ready to be
  appended after an existing WHERE condition (empty when nothing applies),
  and *params* holds the matching values in clause order.

  address, thread_id, date_from_ms and date_to_ms are applied only when
  truthy; type is applied for any value other than None or "" (so 0 counts).
  All values except address are converted with int().
  """
  given = filters if filters else {}
  sql_parts = []
  values = []

  def add(condition, value):
    sql_parts.append(f"AND {condition}")
    values.append(value)

  address = given.get("address")
  if address:
    add("m.address = ?", address)

  msg_type = given.get("type")
  if msg_type is not None and msg_type != "":
    add("m.type = ?", int(msg_type))

  # The remaining filters share the same shape: truthy value, int conversion.
  for key, condition in (
    ("thread_id", "m.thread_id = ?"),
    ("date_from_ms", "m.date_ms >= ?"),
    ("date_to_ms", "m.date_ms <= ?"),
  ):
    raw = given.get(key)
    if raw:
      add(condition, int(raw))

  where = "".join(f"\n  {part}" for part in sql_parts)
  return where, values

def extract_literal_hint(regex_pattern: str):
  """
  Return the longest ASCII alphanumeric literal (4+ characters) that every
  match of *regex_pattern* must contain, or None when none is certain.

  The result is used as a mandatory LIKE pre-filter, so a wrong hint hides
  valid matches. The scan is therefore deliberately conservative:

  - alternation, groups, character classes and counted repetition make the
    function give up, because text inside them may be optional or not
    literal at all;
  - escape sequences never contribute to a hint, and backreferences or
    multi-character escapes (hex, unicode, named) make the function give up;
  - a character directly followed by ? or * is optional and is dropped from
    the run it ends.

  On ties the first longest run wins.
  """
  if not regex_pattern:
    return None

  runs = []   # finished runs of required literal characters
  run = []    # run currently being collected
  pos = 0
  end = len(regex_pattern)
  while pos < end:
    ch = regex_pattern[pos]
    if ch == "\\":
      escaped = regex_pattern[pos + 1:pos + 2]
      if escaped == "" or escaped.isdigit() or escaped in "xuUN":
        # Dangling backslash, backreference/octal, or an escape spanning
        # several characters: its extent is not tracked here, so bail out.
        return None
      # \d, \b, \. and similar only terminate the current run.
      runs.append("".join(run))
      run = []
      pos += 2
      continue
    if ch in "|([{":
      return None
    if ch.isascii() and ch.isalnum():
      run.append(ch)
    else:
      if ch in "?*" and run:
        # The quantifier makes the preceding character optional.
        run.pop()
      runs.append("".join(run))
      run = []
    pos += 1
  runs.append("".join(run))

  candidates = [r for r in runs if len(r) >= 4]
  return max(candidates, key=len) if candidates else None
