"""
Forecast vs Truth - Ensemble Price Forecasting Example

Example:
Single bus system with three gas generators (different prices), a wind generator,
and daily varying load over one week. Five ensemble scenarios with different
wind capacity factor forecasts. "Actual" scenario stores the true/measured values
for comparison with ensemble forecast outputs.

Purpose:
Demonstrates how to:
1. Create ensemble scenarios for probabilistic forecasting
2. Store "actual" (measured/true) values separately from forecasts
3. Compare ensemble outputs against actual values

The bus marginal_price output in each ensemble member can be compared
against the "Actual" marginal_price stored in the Actual scenario.
"""

import pyconvexity as px
import numpy as np
from datetime import datetime, timedelta
import os

# --- Simulation horizon ---
NUM_DAYS = 7
PERIODS_PER_DAY = 24
TOTAL_PERIODS = NUM_DAYS * PERIODS_PER_DAY  # 7 * 24 = 168 hourly periods
NUM_ENSEMBLE_MEMBERS = 5

# --- Component locations and generator data ---
# The bus sits in central England (near Birmingham); the gas plants are placed
# around it and the wind farm is offshore in the North Sea.
BUS_COORDS = dict(latitude=52.5, longitude=-1.5)

# Insertion order is the order in which the generators are created later on.
# p_nom is printed as MW and marginal_cost as $/MWh by the script below.
GENERATORS = {
    "Gas Peaker": dict(p_nom=200, marginal_cost=80, latitude=52.8, longitude=-1.2),
    "Gas CCGT": dict(p_nom=400, marginal_cost=50, latitude=52.3, longitude=-1.8),
    "Gas Base": dict(p_nom=300, marginal_cost=35, latitude=52.4, longitude=-2.0),
    # Wind has zero marginal cost and sits offshore.
    "Wind Farm": dict(p_nom=500, marginal_cost=0, latitude=53.8, longitude=2.5),
}

# The load is placed close to the bus (slightly to the southeast).
LOAD_COORDS = dict(latitude=52.45, longitude=-1.4)

# --- Output database ---
# The database lives in a "dB" folder next to this script; create the folder
# up front so the database file can be written into it.
_db_dir = os.path.join(os.path.dirname(__file__), "dB")
db_path = os.path.join(_db_dir, "004_forecast_truth.db")
os.makedirs(_db_dir, exist_ok=True)


def generate_daily_load_profile(num_days: int, base_load: float = 300) -> list:
    """Build an hourly load series covering ``num_days`` consecutive days.

    Every day reuses the same 24-value shape (multipliers of ``base_load``).
    Each day is scaled by one factor drawn uniformly within +/-10%, and each
    hour by its own factor drawn uniformly within +/-5%. All draws come from
    the global ``np.random`` state, so results vary unless it is seeded.

    Args:
        num_days: Number of days to generate.
        base_load: Load level that the hourly multipliers are applied to.

    Returns:
        A flat list with 24 values per day, in chronological order.
    """
    # Hour-of-day multipliers, grouped into six-hour stretches.
    night = [0.6, 0.55, 0.5, 0.5, 0.55, 0.65]           # 00:00 - 05:00
    morning = [0.75, 0.85, 0.95, 1.0, 1.0, 0.95]        # 06:00 - 11:00 (ramp-up)
    afternoon = [0.9, 0.85, 0.85, 0.9, 0.95, 1.1]       # 12:00 - 17:00
    evening = [1.15, 1.1, 1.0, 0.9, 0.8, 0.7]           # 18:00 - 23:00 (peak)
    daily_shape = night + morning + afternoon + evening

    series = []
    for _ in range(num_days):
        # One scaling factor per day, drawn before that day's hourly noise.
        daily_scale = 1.0 + np.random.uniform(-0.1, 0.1)
        series.extend(
            base_load * shape * daily_scale * (1.0 + np.random.uniform(-0.05, 0.05))
            for shape in daily_shape
        )
    return series


def generate_wind_profile(num_periods: int, mean_capacity_factor: float = 0.35) -> list:
    """Simulate a wind capacity-factor series as a mean-reverting random walk.

    The series starts from ``mean_capacity_factor``. Each period the level is
    pulled 10% of the way back towards the mean, perturbed by a normal draw
    with standard deviation 0.08 (global ``np.random`` state), and clipped to
    the range [0, 1].

    Args:
        num_periods: Number of values to generate.
        mean_capacity_factor: Level the walk starts from and reverts to.

    Returns:
        A list of ``num_periods`` capacity factors, each within [0, 1].
    """
    def _walk():
        level = mean_capacity_factor
        for _ in range(num_periods):
            # Pull towards the mean first, then apply this period's shock.
            pull = 0.1 * (mean_capacity_factor - level)
            level = np.clip(level + pull + np.random.normal(0, 0.08), 0.0, 1.0)
            yield level

    return list(_walk())


def generate_marginal_prices(num_periods: int, generators: dict, load_profile: list, wind_cf: list) -> list:
    """Approximate per-period prices with a simple merit-order dispatch.

    For each period, output of the "Wind Farm" entry (``p_nom`` times that
    period's capacity factor) is subtracted from the load. The remaining
    generators are then dispatched cheapest-first up to their ``p_nom``; the
    marginal cost of the last one dispatched sets the price. The price stays 0
    when wind alone covers the load. Normal noise (standard deviation 5, global
    ``np.random`` state) is added and the result is floored at 0.

    Args:
        num_periods: Number of periods to price.
        generators: Mapping of generator name to a dict with ``p_nom`` and
            ``marginal_cost``; must contain a "Wind Farm" entry.
        load_profile: Load per period (at least ``num_periods`` values).
        wind_cf: Wind capacity factor per period (at least ``num_periods`` values).

    Returns:
        A list of ``num_periods`` non-negative prices.
    """
    # Merit order: cheapest first, with the wind farm left out because its
    # output is netted off the load before the stack is dispatched.
    ranked = sorted(generators.items(), key=lambda item: item[1]["marginal_cost"])
    thermal_stack = [(name, spec) for name, spec in ranked if name != "Wind Farm"]

    prices = []
    for t in range(num_periods):
        # Load left to serve once wind (zero cost) has been used.
        unmet = load_profile[t] - generators["Wind Farm"]["p_nom"] * wind_cf[t]

        clearing_cost = 0
        for _, spec in thermal_stack:
            if unmet <= 0:
                break
            unmet -= min(unmet, spec["p_nom"])
            # The most recently dispatched unit is the marginal one so far.
            clearing_cost = spec["marginal_cost"]

        # Random noise stands in for real-market uncertainty; never go negative.
        noise = np.random.normal(0, 5)
        prices.append(max(0, clearing_cost + noise))

    return prices


# Title banner for the example run.
print("=" * 60, "PyConvexity - Forecast vs Truth Example", "=" * 60, sep="\n")

# Create the database file with the PyConvexity schema at db_path.
px.create_database_with_schema(db_path)
print("\n✅ Database created: {}".format(db_path))

with px.database_context(db_path) as conn:
    # --- Network metadata: hourly periods starting 2024-01-01 00:00 ---
    # end_time is the start of the last period, hence TOTAL_PERIODS - 1 hours.
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    start_time = datetime(2024, 1, 1)
    end_time = start_time + timedelta(hours=TOTAL_PERIODS - 1)

    # Time periods are automatically created from start_time, end_time, and time_resolution
    px.create_network(
        conn,
        px.CreateNetworkRequest(
            name="Forecast vs Truth Example",
            description="One-week ensemble price forecast with actual values",
            start_time=start_time.strftime(timestamp_format),
            end_time=end_time.strftime(timestamp_format),
            time_resolution="PT1H",
        ),
    )
    print(f"✅ Network created: {NUM_DAYS} days, {TOTAL_PERIODS} hours")

    # --- Carriers: map each carrier name to the id returned on creation ---
    carrier_colors = (("AC", "#1f77b4"), ("Gas", "#CCCCCC"), ("Wind", "#2ca02c"))
    carriers = {
        carrier_name: px.create_carrier(conn, name=carrier_name, color=color)
        for carrier_name, color in carrier_colors
    }

    # --- The single bus (central England) ---
    bus_lat, bus_lon = BUS_COORDS["latitude"], BUS_COORDS["longitude"]
    bus_id = px.create_component(
        conn,
        component_type="BUS",
        name="Main Bus",
        carrier_id=carriers["AC"],
        latitude=bus_lat,
        longitude=bus_lon,
    )
    px.set_static_attribute(conn, bus_id, "v_nom", px.StaticValue(110.0))
    print(f"✅ Bus created: Main Bus at ({bus_lat}, {bus_lon})")

    # --- Generators, all attached to the bus ---
    generator_ids = {}
    for gen_name, spec in GENERATORS.items():
        # Any generator with "Wind" in its name gets the Wind carrier; the rest are Gas.
        carrier_key = "Wind" if "Wind" in gen_name else "Gas"
        gen_id = px.create_component(
            conn,
            component_type="GENERATOR",
            name=gen_name,
            bus_id=bus_id,
            carrier_id=carriers[carrier_key],
            latitude=spec["latitude"],
            longitude=spec["longitude"],
        )
        # Static attributes are stored as floats: capacity first, then cost.
        for attribute in ("p_nom", "marginal_cost"):
            px.set_static_attribute(conn, gen_id, attribute, px.StaticValue(float(spec[attribute])))
        generator_ids[gen_name] = gen_id
        print(f"✅ Generator: {gen_name} ({spec['p_nom']} MW, ${spec['marginal_cost']}/MWh) at ({spec['latitude']}, {spec['longitude']})")
    wind_farm_id = generator_ids["Wind Farm"]

    # --- Load (near the bus) with its base time series ---
    load_id = px.create_component(
        conn,
        component_type="LOAD",
        name="System Load",
        bus_id=bus_id,
        carrier_id=carriers["AC"],
        latitude=LOAD_COORDS["latitude"],
        longitude=LOAD_COORDS["longitude"],
    )
    load_profile = generate_daily_load_profile(NUM_DAYS, base_load=350)
    px.set_timeseries_attribute(conn, load_id, "p_set", load_profile)
    print(f"✅ Load created: varying {min(load_profile):.0f}-{max(load_profile):.0f} MW at ({LOAD_COORDS['latitude']}, {LOAD_COORDS['longitude']})")

    # --- Base wind availability (written without a scenario_id) ---
    base_wind_cf = generate_wind_profile(TOTAL_PERIODS, mean_capacity_factor=0.35)
    px.set_timeseries_attribute(conn, wind_farm_id, "p_max_pu", base_wind_cf)

    # --- Ensemble scenarios: each member gets its own wind forecast ---
    print(f"\n📊 Creating {NUM_ENSEMBLE_MEMBERS} ensemble scenarios...")
    ensemble_scenario_ids = []
    for member in range(1, NUM_ENSEMBLE_MEMBERS + 1):
        scenario_name = f"ensemble/{member:02d}"
        scenario_id = px.create_scenario(conn, name=scenario_name, description=f"Ensemble member {member}")
        ensemble_scenario_ids.append(scenario_id)

        # Each member's mean capacity factor is drawn from [0.30, 0.40).
        member_mean_cf = 0.30 + 0.1 * np.random.random()
        ensemble_wind_cf = generate_wind_profile(TOTAL_PERIODS, mean_capacity_factor=member_mean_cf)

        # Override the wind p_max_pu for this scenario only.
        px.set_timeseries_attribute(
            conn, wind_farm_id, "p_max_pu", ensemble_wind_cf, scenario_id=scenario_id
        )
        print(f"   ✅ {scenario_name}: wind CF mean={np.mean(ensemble_wind_cf):.2f}")

    # --- "Actual" values: synthetic truth to compare the ensemble against ---
    # The actual prices come from the merit-order approximation applied to the
    # base load profile and a separately generated "actual" wind series.
    actual_wind_cf = generate_wind_profile(TOTAL_PERIODS, mean_capacity_factor=0.38)
    actual_prices = generate_marginal_prices(TOTAL_PERIODS, GENERATORS, load_profile, actual_wind_cf)

    # Store the actual marginal price on the bus.
    print("\n🎯 Setting actual marginal_price on bus...")
    px.set_actual_timeseries_value(conn, bus_id, "marginal_price", actual_prices)
    print(f"   ✅ Actual price range: ${min(actual_prices):.1f} - ${max(actual_prices):.1f}/MWh")
    print(f"   ✅ Actual price mean: ${np.mean(actual_prices):.1f}/MWh")

    # Store the actual wind capacity factor as well.
    px.set_actual_timeseries_value(conn, wind_farm_id, "p_max_pu", actual_wind_cf)
    print(f"   ✅ Actual wind CF mean: {np.mean(actual_wind_cf):.2f}")

    # --- Dashboard: one time-series chart of the bus marginal price ---
    from pyconvexity.dashboard import set_dashboard_config, DashboardConfig, auto_layout

    price_chart = {
        "id": "price-ensemble",
        "title": "Electricity Price Forecast vs Actual",
        "visible": True,
        "view": {
            "timeseries": {
                "component": "Bus",
                "attribute": "marginal_price",
                "group_by": None,
            },
        },
    }
    charts = [price_chart]

    config = DashboardConfig(charts=charts, layout=auto_layout(charts))
    set_dashboard_config(conn, config)

    conn.commit()

# Section banner for the solve phase (preceded by a blank line).
print("\n" + "=" * 60, "Solving Network", "=" * 60, sep="\n")

def format_objective(obj):
    """Render an objective value for display.

    Returns the string 'N/A' when ``obj`` is None or already 'N/A';
    otherwise formats ``obj`` with two decimal places.
    """
    missing = obj is None or obj == 'N/A'
    return 'N/A' if missing else format(obj, ".2f")

# Solve the base network first (no scenario_id is passed), echoing progress.
print("\n🔧 Solving base scenario...")
result = px.solve_network(
    db_path=db_path,
    solver_name="highs",
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}"),
)
# Let the status marker follow the reported outcome so that a failed solve is
# not announced with a success tick.
solved = result.get('success', False)
status_icon = "✅" if solved else "❌"
print(f"{status_icon} Base scenario: success={solved}, objective: {format_objective(result.get('objective_value'))}")

# Solve each ensemble scenario
print(f"\n🔧 Solving {NUM_ENSEMBLE_MEMBERS} ensemble scenarios...")
# Read the scenario list in a short-lived connection; the solves below are
# given db_path rather than this connection.
with px.database_context(db_path) as conn:
    scenarios = px.list_scenarios(conn)

for scenario in scenarios:
    # Only the ensemble members are solved here; other scenarios are skipped.
    if not scenario.name.startswith("ensemble/"):
        continue
    result = px.solve_network(
        db_path=db_path,
        solver_name="highs",
        scenario_id=scenario.id,
    )
    solved = result.get('success', False)
    status_icon = "✅" if solved else "❌"
    print(f"   {status_icon} {scenario.name}: success={solved}, objective: {format_objective(result.get('objective_value'))}")

# Closing summary of what was built and where it was saved.
print("\n" + "=" * 60)
print("Model Summary")
print("=" * 60)
print(f"Time horizon: {NUM_DAYS} days ({TOTAL_PERIODS} hours)")
# The generator count is derived from GENERATORS so the summary stays correct
# if the configuration changes (the script creates one bus and one load).
print(f"Components: 1 bus, {len(GENERATORS)} generators, 1 load")
print(f"Ensemble scenarios: {NUM_ENSEMBLE_MEMBERS}")
print("")
print("This model demonstrates forecast vs truth comparison:")
print("- Each ensemble scenario has different wind forecasts")
print("- The 'Actual' scenario stores true/measured values")
print("- Compare ensemble marginal_price outputs vs actual")
print("")
print(f"Database saved to: {db_path}")

# Export to Excel
print("\n📊 Exporting to Excel...")
from pyconvexity.io.excel_exporter import ExcelModelExporter

# Swap only the final extension. str.replace('.db', '.xlsx') would also rewrite
# any '.db' appearing earlier in the path (for example in a directory name),
# pointing the export at the wrong location.
xlsx_path = os.path.splitext(db_path)[0] + '.xlsx'
exporter = ExcelModelExporter()
export_result = exporter.export_model_to_excel(
    db_path=db_path,
    output_path=xlsx_path,
    # Progress lines are printed only when a truthy progress value is reported.
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}") if progress else None,
)
print(f"✅ Excel export complete: {xlsx_path}")
print("=" * 60)

