

from pathlib import Path
from typing import Dict, List

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.feature_extraction.text import TfidfVectorizer, ENGLISH_STOP_WORDS

from ..scraping.utils import get_project_root


# Extra stop words that make_vectorizer() adds on top of scikit-learn's
# built-in English list. They are kept in three groups purely for readability;
# the groups are concatenated in order into the single public list below.

# Words naming the topic itself, plus year tokens.
_TOPIC_STOPWORDS: List[str] = [
    "budget", "budgets",
    "2025", "2024", "2023",
    "budget 2025",
]

# Words that look like coverage/format labels rather than content.
_FORMAT_STOPWORDS: List[str] = [
    "live", "update", "updates",
    "reaction", "reactions",
    "latest", "breaking", "analysis", "special",
    "podcast", "show", "watch", "listen",
]

# Generic section-style words.
_GENERIC_STOPWORDS: List[str] = ["politics", "news", "uk"]

CUSTOM_STOPWORDS: List[str] = (
    _TOPIC_STOPWORDS + _FORMAT_STOPWORDS + _GENERIC_STOPWORDS
)


def load_headlines() -> pd.DataFrame:
    """Read the headlines CSV from the project's results directory.

    The file is ``data/results/headlines_with_frames.csv`` under the path
    returned by ``get_project_root()``. Progress is reported on stdout.

    Returns:
        The CSV contents as parsed by ``pandas.read_csv`` with default options.
    """
    csv_path = get_project_root() / "data" / "results" / "headlines_with_frames.csv"
    print(f"[NLP] Loading headlines from: {csv_path}")

    headlines = pd.read_csv(csv_path)
    print(f"[NLP] Loaded {len(headlines)} rows.")
    return headlines


def build_corpus(df: pd.DataFrame) -> Dict[str, str]:
    """Collapse all headlines of each source into one document string.

    Args:
        df: Frame with a ``source`` column and a ``headline`` column.

    Returns:
        Mapping of source -> that source's non-null headlines, each converted
        with ``str`` and joined with ``" . "``.
    """
    corpus: Dict[str, str] = {
        name: " . ".join(map(str, group["headline"].dropna()))
        for name, group in df.groupby("source")
    }

    print("[NLP] Built corpus for sources:", list(corpus.keys()))
    return corpus


def make_vectorizer() -> TfidfVectorizer:
    """Create the TF-IDF vectorizer used for keyword extraction.

    The stop-word list is scikit-learn's English list plus the lower-cased
    entries of ``CUSTOM_STOPWORDS``. Text is lower-cased, both unigrams and
    bigrams are produced, and no minimum document frequency is enforced.

    NOTE(review): stop words appear to be matched against single tokens, so a
    multi-word entry such as "budget 2025" probably never matches on its own
    (its individual words are listed separately anyway) - confirm against the
    scikit-learn documentation.

    Returns:
        An unfitted ``TfidfVectorizer``.
    """
    stop_words = set(ENGLISH_STOP_WORDS)
    stop_words.update(word.lower() for word in CUSTOM_STOPWORDS)

    return TfidfVectorizer(
        lowercase=True,
        stop_words=list(stop_words),
        ngram_range=(1, 2),  # unigrams and bigrams
        min_df=1,
    )


def extract_top_keywords(corpus: Dict[str, str], top_k: int = 5) -> pd.DataFrame:
    """Extract the ``top_k`` highest-scoring TF-IDF terms for each source.

    Each value of ``corpus`` is treated as one document; the vectorizer from
    ``make_vectorizer()`` is fitted on all documents together.

    Terms shorter than three characters or consisting only of digits are
    ignored. Filtering happens while walking the ranking, so a rejected term
    does not use up one of the ``top_k`` slots. Terms with a score of zero
    (i.e. not present in that source's document) are never reported, so a
    source may yield fewer than ``top_k`` rows. Ties are broken by the
    vectorizer's feature order.

    Args:
        corpus: Mapping of source name -> concatenated headline text.
        top_k: Maximum number of keywords per source. Values <= 0 yield an
            empty result.

    Returns:
        DataFrame with columns ``source``, ``term``, ``tfidf``. The columns
        are present even when there are no rows.

    Raises:
        ValueError: Propagated from the vectorizer when the documents contain
            no usable vocabulary (e.g. only stop words).
    """
    columns = ["source", "term", "tfidf"]
    rows: List[Dict[str, object]] = []

    # Nothing to vectorize / nothing requested: skip fitting altogether, the
    # vectorizer cannot be fitted on an empty list of documents.
    if corpus and top_k > 0:
        sources = list(corpus)
        docs = [corpus[src] for src in sources]

        vectorizer = make_vectorizer()
        X = vectorizer.fit_transform(docs)
        terms = vectorizer.get_feature_names_out()

        for i, src in enumerate(sources):
            scores = X[i].toarray().ravel()
            # Stable sort on the negated scores: descending by score, with
            # deterministic ordering among equal scores.
            ranked = (-scores).argsort(kind="stable")

            kept = 0
            for idx in ranked:
                score = float(scores[idx])
                if score <= 0.0:
                    # Everything after this point is absent from this source.
                    break

                term = terms[idx].strip()
                if len(term) < 3 or term.replace(" ", "").isdigit():
                    continue

                rows.append({"source": src, "term": term, "tfidf": score})
                kept += 1
                if kept == top_k:
                    break

    # Explicit columns keep the schema stable when no rows were collected.
    df_keywords = pd.DataFrame(rows, columns=columns)
    print("[NLP] Extracted keywords:\n", df_keywords)
    return df_keywords


def plot_keywords(df_keywords: pd.DataFrame) -> None:
    """Draw a bar chart of keyword TF-IDF scores, coloured by source.

    Bars are ordered by descending score across all sources; each source gets
    its own colour and legend entry. The image is written to
    ``data/results/plots/nlp_keywords_by_source.png`` under the path returned
    by ``get_project_root()`` (the directory is created if needed).

    Args:
        df_keywords: Frame with ``source``, ``term`` and ``tfidf`` columns.
            When it has no rows, nothing is plotted and no file is written.
    """
    # An empty frame may not even have the expected columns, and an empty
    # chart carries no information, so bail out before touching the disk.
    if df_keywords.empty:
        print("[NLP] No keywords to plot; skipping.")
        return

    out_dir = get_project_root() / "data" / "results" / "plots"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "nlp_keywords_by_source.png"

    # After the reset, the index doubles as the bar's x position.
    df_plot = df_keywords.sort_values("tfidf", ascending=False).reset_index(drop=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for src in df_plot["source"].unique():
            sub = df_plot[df_plot["source"] == src]
            ax.bar(sub.index, sub["tfidf"], label=src)

        ax.set_xticks(range(len(df_plot)))
        ax.set_xticklabels(df_plot["term"], rotation=45, ha="right")
        ax.set_ylabel("TF-IDF score")
        ax.set_title("Top keywords by source (TF-IDF, custom stopwords)")
        ax.legend()
        fig.tight_layout()

        fig.savefig(out_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f"[NLP] Saved keyword plot -> {out_path}")


def main() -> None:
    """Run the keyword pipeline: load headlines, group by source, rank, plot."""
    headlines = load_headlines()
    docs_by_source = build_corpus(headlines)

    # Five keywords per source.
    keywords = extract_top_keywords(docs_by_source, top_k=5)
    plot_keywords(keywords)

    print("[NLP] Done.")


# Script entry point. This module uses a package-relative import
# (``from ..scraping.utils import ...``), so it has to be executed as part of
# its package (e.g. via ``python -m``) rather than as a standalone file path.
if __name__ == "__main__":
    main()
