

from pathlib import Path
from typing import List

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from ..scraping.utils import get_project_root


# Outlets included in the plots, in display order. The plotting functions in
# this module keep only rows whose "source" value appears in this list and
# reindex their pivot tables by it, so this order drives the bar/column order.
MEDIA_ORDER: List[str] = ["bbc", "guardian", "sky", "telegraph", "ft"]


def load_data():
    """Read the framing and sentiment summary CSVs from the results folder.

    Both files are looked up under ``<project root>/data/results``. The two
    paths are printed before either file is read.

    Returns:
        A ``(framing, sentiment)`` tuple of DataFrames, in that order.
    """
    results_dir: Path = get_project_root() / "data" / "results"

    framing_csv = results_dir / "framing_summary.csv"
    sentiment_csv = results_dir / "sentiment_summary.csv"

    print(f"[Compare] Loading framing from:   {framing_csv}")
    print(f"[Compare] Loading sentiment from: {sentiment_csv}")

    # Tuple elements are evaluated left to right: framing first, then sentiment.
    return pd.read_csv(framing_csv), pd.read_csv(sentiment_csv)

def plot_framing_grouped(framing_df: pd.DataFrame) -> None:
    """Save a grouped bar chart of headline counts per frame.

    Each frame gets one cluster of bars on the x axis, with one bar per
    outlet in ``MEDIA_ORDER``. The image is written to
    ``<project root>/data/results/plots/framing_grouped_by_frame.png``.

    Args:
        framing_df: Long-format table with ``source``, ``frame`` and
            ``count`` columns. Rows whose ``source`` is not listed in
            ``MEDIA_ORDER`` are ignored.
    """
    out_dir = get_project_root() / "data" / "results" / "plots"
    out_dir.mkdir(exist_ok=True, parents=True)

    framing = framing_df[framing_df["source"].isin(MEDIA_ORDER)]

    pivot = (
        framing.pivot_table(
            index="frame",
            columns="source",
            values="count",
            # Counts are additive: duplicate (frame, source) rows must be
            # summed, not averaged (pivot_table's default is the mean).
            aggfunc="sum",
            fill_value=0,
        )
        # pivot_table's fill_value only covers cells of sources present in the
        # data; an outlet with no rows at all would come back from reindex as
        # a NaN column unless it is zero-filled here too.
        .reindex(columns=MEDIA_ORDER, fill_value=0)
        .sort_index()
    )

    frames = pivot.index.tolist()
    x = range(len(frames))

    total_width = 0.8  # fraction of each x slot shared by one cluster of bars
    n_media = len(MEDIA_ORDER)
    bar_width = total_width / n_media

    out_path = out_dir / "framing_grouped_by_frame.png"

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for i, src in enumerate(MEDIA_ORDER):
            # Shift each outlet's bars so the whole cluster is centred on the tick.
            offset = (i - n_media / 2) * bar_width + bar_width / 2
            x_pos = [v + offset for v in x]
            ax.bar(x_pos, pivot[src].values, width=bar_width, label=src)

        ax.set_xticks(list(x))
        ax.set_xticklabels(frames, rotation=30, ha="right")
        ax.set_ylabel("Number of headlines")
        ax.set_title("Framing comparison by frame (5 media)")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"[Compare] Saved framing grouped plot -> {out_path}")


def plot_framing_percentage(framing_df: pd.DataFrame) -> None:
    """Save a stacked bar chart of each outlet's frame composition in percent.

    For every outlet in ``MEDIA_ORDER`` the frame counts are converted to a
    share of that outlet's total count and stacked into one bar. The image is
    written to
    ``<project root>/data/results/plots/framing_percentage_by_source.png``.

    Args:
        framing_df: Long-format table with ``source``, ``frame`` and
            ``count`` columns. Rows whose ``source`` is not listed in
            ``MEDIA_ORDER`` are ignored.
    """
    out_dir = get_project_root() / "data" / "results" / "plots"
    out_dir.mkdir(exist_ok=True, parents=True)

    # Copy because a "percentage" column is added below.
    framing = framing_df[framing_df["source"].isin(MEDIA_ORDER)].copy()

    # Per-row total of the row's own source, aligned with `framing`.
    source_totals = framing.groupby("source")["count"].transform("sum")
    framing["percentage"] = framing["count"] / source_totals * 100.0

    pivot = framing.pivot_table(
        index="source",
        columns="frame",
        values="percentage",
        # Shares of duplicate (source, frame) rows add up; the default mean
        # would make a bar stack to less than 100 %.
        aggfunc="sum",
        fill_value=0,
    )
    # Outlets with no rows at all get an all-zero bar; without fill_value the
    # NaN row would also leak into the running `bottom` of the stack.
    pivot = pivot.reindex(index=MEDIA_ORDER, fill_value=0)

    sources = pivot.index.tolist()
    x = list(range(len(sources)))

    out_path = out_dir / "framing_percentage_by_source.png"

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        bottom = [0.0] * len(sources)
        for frame in pivot.columns:
            vals = pivot[frame].values
            ax.bar(x, vals, bottom=bottom, label=frame)
            # Next frame's segment starts where this one ended.
            bottom = [b + v for b, v in zip(bottom, vals)]

        ax.set_xticks(x)
        ax.set_xticklabels(sources)
        ax.set_ylabel("Percentage (%)")
        ax.set_title("Frame composition within each media (percentage)")
        # Legend is anchored just outside the right edge of the axes.
        ax.legend(title="Frame", bbox_to_anchor=(1.02, 1), loc="upper left")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"[Compare] Saved framing percentage plot -> {out_path}")


def plot_sentiment(sentiment_df: pd.DataFrame) -> None:
    """Save a stacked bar chart of sentiment counts per outlet.

    One bar per outlet in ``MEDIA_ORDER``, stacked in the order negative,
    neutral, positive. The image is written to
    ``<project root>/data/results/plots/sentiment_by_source.png``.

    Args:
        sentiment_df: Long-format table with ``source``, ``sentiment`` and
            ``count`` columns. Rows whose ``source`` is not listed in
            ``MEDIA_ORDER`` are ignored, and sentiment values other than the
            three above are not plotted.
    """
    out_dir = get_project_root() / "data" / "results" / "plots"
    out_dir.mkdir(exist_ok=True, parents=True)

    sentiment = sentiment_df[sentiment_df["source"].isin(MEDIA_ORDER)]

    sentiment_order = ["negative", "neutral", "positive"]

    pivot = sentiment.pivot_table(
        index="source",
        columns="sentiment",
        values="count",
        # Counts are additive: sum duplicate (source, sentiment) rows rather
        # than averaging them (pivot_table's default).
        aggfunc="sum",
        fill_value=0,
    )
    # One reindex zero-fills both missing outlets (rows) and missing sentiment
    # classes (columns), and fixes the row/column order.
    pivot = pivot.reindex(index=MEDIA_ORDER, columns=sentiment_order, fill_value=0)

    sources = pivot.index.tolist()
    x = list(range(len(sources)))

    out_path = out_dir / "sentiment_by_source.png"

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        bottom = [0] * len(sources)
        for sent in sentiment_order:
            vals = pivot[sent].values
            ax.bar(x, vals, bottom=bottom, label=sent)
            # Next sentiment's segment starts where this one ended.
            bottom = [b + v for b, v in zip(bottom, vals)]

        ax.set_xticks(x)
        ax.set_xticklabels(sources)
        ax.set_ylabel("Number of headlines")
        ax.set_title("Sentiment distribution by media")
        ax.legend(title="Sentiment")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"[Compare] Saved sentiment plot -> {out_path}")


def plot_frame_sentiment_heatmap(headlines_df: pd.DataFrame) -> None:
    """Save one frame-by-sentiment heatmap of headline counts per outlet.

    For each outlet in ``MEDIA_ORDER`` the headlines are counted per
    (frame, sentiment) pair and drawn as an annotated heatmap. All outlets
    share the same rows (every frame seen across the filtered data) and the
    same columns (negative, neutral, positive). Images are written to
    ``<project root>/data/results/plots/<source>_frame_sentiment_heatmap.png``.

    Args:
        headlines_df: One row per headline with ``source`` and ``frame``
            columns plus a ``sentiment`` or ``sentiment_label`` column. When
            ``sentiment`` is present it is used as the label, replacing any
            existing ``sentiment_label``.

    Raises:
        KeyError: If neither ``sentiment`` nor ``sentiment_label`` is present.
    """
    out_dir = get_project_root() / "data" / "results" / "plots"
    out_dir.mkdir(exist_ok=True, parents=True)

    # Filter first, then copy, so the column assignment below writes to an
    # independent frame rather than to a slice of one.
    df = headlines_df[headlines_df["source"].isin(MEDIA_ORDER)].copy()

    if "sentiment" in df.columns:
        df["sentiment_label"] = df["sentiment"]
    elif "sentiment_label" not in df.columns:
        raise KeyError(
            "headlines_df needs a 'sentiment' or 'sentiment_label' column"
        )

    sentiment_order = ["negative", "neutral", "positive"]
    # Drop missing frames: groupby ignores them anyway, and sorting NaN
    # together with strings raises TypeError.
    frames = sorted(df["frame"].dropna().unique())

    if not frames:
        # Nothing to draw; an empty matrix cannot be rendered as a heatmap.
        print("[Compare] No headlines for the selected media; skipping heatmaps.")
        return

    for src in MEDIA_ORDER:
        sub = df[df["source"] == src]

        pivot = (
            sub.groupby(["frame", "sentiment_label"])
            .size()
            .reset_index(name="count")
            .pivot(index="frame", columns="sentiment_label", values="count")
            .reindex(index=frames, columns=sentiment_order, fill_value=0)
        )

        # reindex only zero-fills newly added labels; combinations missing
        # inside the pivot itself are still NaN. Integer dtype is required by
        # the fmt="d" annotations.
        pivot = pivot.fillna(0).astype(int)

        out_path = out_dir / f"{src}_frame_sentiment_heatmap.png"

        fig, ax = plt.subplots(figsize=(7, 5))
        try:
            sns.heatmap(
                pivot,
                annot=True,
                fmt="d",
                cmap="YlOrRd",
                cbar=True,
                ax=ax,
            )

            ax.set_title(f"Frame × Sentiment Heatmap — {src}")
            ax.set_ylabel("Frame")
            ax.set_xlabel("Sentiment")
            fig.tight_layout()
            fig.savefig(out_path)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)

        print(f"[Compare] Saved heatmap for {src} -> {out_path}")

def main():
    """Create all comparison plots from the CSV files under data/results.

    Draws the three summary charts from the framing and sentiment summaries,
    then loads the per-headline CSV and draws the per-outlet heatmaps.
    """
    framing_df, sentiment_df = load_data()

    summary_plots = (
        (plot_framing_grouped, framing_df),       # Figure 1
        (plot_framing_percentage, framing_df),    # Figure 2
        (plot_sentiment, sentiment_df),           # Figure 3
    )
    for draw, table in summary_plots:
        draw(table)

    headlines_csv = (
        get_project_root() / "data" / "results" / "headlines_with_sentiment.csv"
    )
    print(f"[Compare] Loading per-headline sentiment from: {headlines_csv}")

    plot_frame_sentiment_heatmap(pd.read_csv(headlines_csv))

    print("[Compare] All comparison plots created.")


# Build every comparison plot when this module is executed as a script.
if __name__ == "__main__":
    main()
