'''
Multi-Year Capacity Expansion Planning - Ghana and Togo
PyConvexity Version

This example demonstrates multi-year capacity expansion planning between 2020-2024
for Ghana and Togo. The model optimizes investment decisions in generation capacity,
transmission links, and battery storage over a 5-year horizon.

Key features:
- 6-hourly resolution (4 snapshots per day)
- Continuous time axis (each year from 2020-2024 modeled)
- Year-on-year demand growth with random variation
- Technologies: Gas, Biomass, Hydro, Transmission Links, Batteries
- Constraint: No new gas capacity after 2025 (non-binding within the 2020-2024 horizon)
- CO2 cap reduced linearly to zero by 2024

Based on PyPSA multi-investment optimization example:
https://docs.pypsa.org/latest/examples/multi-investment-optimisation/
'''

import pyconvexity as px
import numpy as np
from datetime import datetime, timedelta
import os

# The model database lives next to this script, in the examples/dB folder.
db_dir = os.path.join(os.path.dirname(__file__), "dB")
os.makedirs(db_dir, exist_ok=True)
db_path = os.path.join(db_dir, "003_ghana_togo_expansion.db")

# Create the database file and its PyConvexity schema.
px.create_database_with_schema(db_path)

with px.database_context(db_path) as conn:
    # --- Network definition ---------------------------------------------------
    # Planning horizon 2020-2024 (five calendar years), kept short so the
    # example solves quickly. A 6-hour step gives 4 snapshots per day, i.e.
    # 1460 per regular year and 1464 in the leap years 2020 and 2024.
    start_time = datetime(2020, 1, 1, 0, 0, 0)
    end_time = datetime(2024, 12, 31, 18, 0, 0)  # final 6-hourly snapshot
    stamp_format = "%Y-%m-%d %H:%M:%S"

    network_req = px.CreateNetworkRequest(
        name="Ghana-Togo Capacity Expansion 2020-2024",
        description="Multi-year capacity expansion planning for Ghana and Togo with 6-hourly resolution (5 years)",
        start_time=start_time.strftime(stamp_format),
        end_time=end_time.strftime(stamp_format),
        time_resolution="PT6H",  # ISO-8601 duration: 6 hours
    )

    # create_network derives the time periods itself from start/end/resolution.
    px.create_network(conn, network_req)

    # Rebuild the same snapshot grid locally. It is used for logging here and
    # later to generate the load time series.
    step = timedelta(hours=6)
    snapshot_count = (end_time - start_time) // step + 1
    timestamps = [start_time + k * step for k in range(snapshot_count)]

    first_year_count = sum(1 for ts in timestamps if ts.year == 2020)
    print(f"✅ Created {len(timestamps)} time periods (6-hourly from 2020-2024)")
    print(f"   Years: 2020-2024 ({2024-2020+1} years)")
    print(f"   Snapshots per year: {first_year_count}")
    
    # --- Carriers -------------------------------------------------------------
    # Energy carriers: emission factor (tCO2/MWh, passed as co2_emissions),
    # plot colour and human-readable name.
    carrier_data = {
        "AC": {"co2": 0.0, "color": "#1f77b4", "nice": "Electricity"},
        "gas": {"co2": 0.4, "color": "#ff7f0e", "nice": "Natural Gas"},
        "biomass": {"co2": 0.0, "color": "#2ca02c", "nice": "Biomass"},
        "hydro": {"co2": 0.0, "color": "#17becf", "nice": "Hydroelectric"},
    }

    # Register every carrier and keep its database id under its short name.
    carriers = {
        carrier_name: px.create_carrier(
            conn,
            name=carrier_name,
            co2_emissions=spec["co2"],
            color=spec["color"],
            nice_name=spec["nice"],
        )
        for carrier_name, spec in carrier_data.items()
    }
    
    # --- Buses ----------------------------------------------------------------
    # One AC bus per country (v_nom = 400, kV in PyPSA's convention).
    # The coordinates are rough country centroids. They are NOT Kumasi or
    # Lomé: both cities lie further south, and Lomé is on the coast.
    country_coords = {
        "Ghana": {"latitude": 7.9, "longitude": -1.0},
        "Togo": {"latitude": 8.0, "longitude": 1.2},
    }

    bus_ids = {}
    for country, coords in country_coords.items():
        new_bus = px.create_component(
            conn,
            component_type="BUS",
            name=country,
            carrier_id=carriers["AC"],
            latitude=coords["latitude"],
            longitude=coords["longitude"],
        )
        bus_ids[country] = new_bus
        px.set_static_attribute(conn, new_bus, "v_nom", px.StaticValue(400.0))
    
    # --- Capacity limits ------------------------------------------------------
    # Upper bound on system demand, used to cap expandable gas and biomass
    # capacity.
    # Base demand 2020 (MW)
    total_base_demand_2020 = 2000.0 + 300.0  # Ghana + Togo
    annual_growth_rate = 0.03  # 3% per year
    # Grown base demand in the final modelled year (four growth steps after 2020).
    total_demand_2024 = total_base_demand_2020 * ((1 + annual_growth_rate) ** 4)
    # Worst case of the multiplicative factors used when the load profiles are
    # generated: seasonal (up to +10%), random (up to +15%) and daytime (x1.1).
    # Their product is ~1.39. A flat 1.3 margin would sit *below* that peak, so
    # it would not actually be conservative.
    seasonal_peak = 1.0 + 0.1
    random_peak = 1.0 + 0.15
    daytime_peak = 1.1
    peak_demand_factor = seasonal_peak * random_peak * daytime_peak
    max_demand_estimate = total_demand_2024 * peak_demand_factor

    # Technology parameters. p_nom_max is set on every build-year vintage
    # individually, so it caps new build per vintage, not the technology total.
    tech_params = {
        "gas": {
            "marginal_cost": 60.0,  # $/MWh
            "capital_cost": 800000.0,  # $/MW
            "lifetime": 25.0,  # years
            "max_build_year": 2025,  # No new gas after 2025
            "p_nom_max": max_demand_estimate * 0.4,  # Max 40% of max demand
        },
        "biomass": {
            "marginal_cost": 45.0,  # $/MWh
            "capital_cost": 2000000.0,  # $/MW
            "lifetime": 30.0,  # years
            "max_build_year": 2040,  # beyond the horizon: buildable every year
            "p_nom_max": max_demand_estimate * 0.2,  # Max 20% of max demand
        },
        "hydro": {
            "marginal_cost": 5.0,  # $/MWh (very cheap)
            "capital_cost": 3000000.0,  # $/MW (expensive upfront)
            "lifetime": 50.0,  # years (long lifetime)
            "max_build_year": 2040,  # beyond the horizon: buildable every year
            "p_nom_max": None,  # No limit (unlimited)
        },
    }

    print(f"📊 Capacity limits calculated:")
    print(f"   Max demand estimate: {max_demand_estimate:.1f} MW")
    print(f"   Gas max capacity: {tech_params['gas']['p_nom_max']:.1f} MW (40% of max demand)")
    print(f"   Biomass max capacity: {tech_params['biomass']['p_nom_max']:.1f} MW (20% of max demand)")
    print(f"   Hydro max capacity: Unlimited")
    
    # --- Generators -----------------------------------------------------------
    # One extendable generator per (country, technology, build year). Splitting
    # capacity into build-year vintages lets the optimiser choose *when* to
    # invest as well as how much.
    #
    # Build years run from the first modelled year up to the earlier of the
    # last modelled year and the technology's max_build_year. This is what
    # enforces "no new gas after 2025". Vintages dated after the last modelled
    # year are not created, because (assuming the investment periods are the
    # modelled years 2020-2024) they could never operate.
    #
    # NOTE: p_nom_max is applied to each vintage separately, so it bounds the
    # build per vintage rather than a technology's total installed capacity.
    first_model_year, last_model_year = 2020, 2024
    generator_counter = {"Ghana": 0, "Togo": 0}

    for country in ["Ghana", "Togo"]:
        coords = country_coords[country]

        for tech_name, params in tech_params.items():
            last_build_year = min(params["max_build_year"], last_model_year)
            build_years = list(range(first_model_year, last_build_year + 1))

            for build_year in build_years:
                gen_name = f"{country} {tech_name} {build_year}"

                # Offset generators slightly so they don't overlap on a map.
                offset = generator_counter[country] * 0.05
                gen_lat = coords["latitude"] + offset
                gen_lon = coords["longitude"] + offset
                generator_counter[country] += 1

                gen_id = px.create_component(
                    conn,
                    component_type="GENERATOR",
                    name=gen_name,
                    bus_id=bus_ids[country],
                    carrier_id=carriers[tech_name],
                    latitude=gen_lat,
                    longitude=gen_lon,
                )

                # Greenfield (p_nom = 0), extendable capacity with the
                # technology's cost and lifetime parameters.
                px.set_static_attribute(conn, gen_id, "p_nom", px.StaticValue(0.0))
                px.set_static_attribute(conn, gen_id, "p_nom_extendable", px.StaticValue(True))
                px.set_static_attribute(conn, gen_id, "marginal_cost", px.StaticValue(float(params["marginal_cost"])))
                px.set_static_attribute(conn, gen_id, "capital_cost", px.StaticValue(float(params["capital_cost"])))
                px.set_static_attribute(conn, gen_id, "build_year", px.StaticValue(int(build_year)))
                px.set_static_attribute(conn, gen_id, "lifetime", px.StaticValue(float(params["lifetime"])))

                # Capacity limits: gas 40%, biomass 20%, hydro unlimited (None).
                if params["p_nom_max"] is not None:
                    px.set_static_attribute(conn, gen_id, "p_nom_max", px.StaticValue(float(params["p_nom_max"])))
    
    # --- Interconnector -------------------------------------------------------
    # A single extendable Ghana -> Togo link placed midway between the two
    # buses. p_min_pu = -1 allows flow in both directions.
    ghana_xy = country_coords["Ghana"]
    togo_xy = country_coords["Togo"]
    link_id = px.create_component(
        conn,
        component_type="LINK",
        name="Ghana-Togo Interconnector",
        bus0_id=bus_ids["Ghana"],
        bus1_id=bus_ids["Togo"],
        carrier_id=carriers["AC"],
        latitude=(ghana_xy["latitude"] + togo_xy["latitude"]) / 2.0,
        longitude=(ghana_xy["longitude"] + togo_xy["longitude"]) / 2.0,
    )

    link_settings = (
        ("p_nom", 0.0),  # no existing capacity; all of it is new build
        ("p_nom_extendable", True),
        ("p_min_pu", -1.0),  # bidirectional
        ("capital_cost", 500000.0),  # $/MW
        ("build_year", 2020),
        ("lifetime", 40.0),  # years
    )
    for attr_name, attr_value in link_settings:
        px.set_static_attribute(conn, link_id, attr_name, px.StaticValue(attr_value))
    
    # --- Battery storage ------------------------------------------------------
    # One extendable 4-hour battery vintage per country and build year
    # (2020-2024). All vintages of a country share one map location.
    for country in ["Ghana", "Togo"]:
        coords = country_coords[country]

        for build_year in range(2020, 2025):
            battery_id = px.create_component(
                conn,
                component_type="STORAGE_UNIT",
                name=f"{country} Battery {build_year}",
                bus_id=bus_ids[country],
                carrier_id=carriers["AC"],
                latitude=coords["latitude"] + 0.1,
                longitude=coords["longitude"] + 0.1,
            )

            battery_settings = {
                "p_nom": 0.0,  # greenfield: no existing capacity
                "p_nom_extendable": True,
                "capital_cost": 600000.0,  # $/MW
                "build_year": int(build_year),
                "lifetime": 15.0,  # years
                "max_hours": 4.0,  # hours of storage at full power
                "efficiency_store": 0.9,
                "efficiency_dispatch": 0.9,
                "p_min_pu": -1.0,  # may charge at full power
                "p_max_pu": 1.0,  # may discharge at full power
                "cyclic_state_of_charge": True,
                "cyclic_state_of_charge_per_period": True,
            }
            for attr_name, attr_value in battery_settings.items():
                px.set_static_attribute(conn, battery_id, attr_name, px.StaticValue(attr_value))
    
    # --- Loads ----------------------------------------------------------------
    # One load per country. Its p_set time series is the product of:
    #   * 3%/yr compound growth from the 2020 base demand,
    #   * a +/-10% sinusoidal seasonal swing over the day of year,
    #   * a uniform random +/-15% perturbation per snapshot,
    #   * a daytime factor of 1.1 (06:00-22:00) or 0.8 at night.
    # The random draws come from a seeded generator, so re-running the script
    # rebuilds the same model. Unseeded draws would give a different demand
    # profile, and therefore a different optimum, on every run.
    base_demand_2020 = {
        "Ghana": 2000.0,  # MW
        "Togo": 300.0,  # MW
    }
    annual_growth_rate = 0.03  # 3% per year
    demand_seed = 42  # change to explore other random demand realisations
    rng = np.random.default_rng(demand_seed)

    for country in ["Ghana", "Togo"]:
        load_id = px.create_component(
            conn,
            component_type="LOAD",
            name=f"{country} Load",
            bus_id=bus_ids[country],
            carrier_id=carriers["AC"],
            latitude=country_coords[country]["latitude"] + 0.15,
            longitude=country_coords[country]["longitude"] + 0.15,
        )

        demand_values = []
        for timestamp in timestamps:
            year_index = timestamp.year - 2020  # years since 2020

            # Base demand grows year-on-year.
            base_demand = base_demand_2020[country] * ((1 + annual_growth_rate) ** year_index)

            # Seasonality from the day of year.
            day_of_year = timestamp.timetuple().tm_yday
            seasonal_factor = 1.0 + 0.1 * np.sin(2 * np.pi * day_of_year / 365.25)

            # Reproducible random variation.
            random_factor = 1.0 + rng.uniform(-0.15, 0.15)

            # Higher demand during the day. With 6-hourly snapshots (00, 06,
            # 12 and 18h), only the midnight snapshot counts as night.
            daily_factor = 1.1 if 6 <= timestamp.hour <= 22 else 0.8

            demand = base_demand * seasonal_factor * random_factor * daily_factor
            # Every factor is positive, so this clamp is only a safeguard.
            demand_values.append(max(0.0, demand))

        px.set_timeseries_attribute(conn, load_id, "p_set", demand_values)

        print(f"✅ Created load for {country}")
        print(f"   Base demand 2020: {base_demand_2020[country]:.1f} MW")
        print(f"   Projected demand 2024: {base_demand_2020[country] * ((1 + annual_growth_rate) ** 4):.1f} MW")
    
    # Add emissions constraint: linearly reduce to zero by 2024
    # Calculate initial emissions estimate (based on gas capacity limit and typical emissions)
    # Gas has 0.4 tCO2/MWh, assume max gas generation could be 40% of max demand
    initial_emissions_estimate = max_demand_estimate * 0.4 * 0.4 * 8760  # Max gas capacity * emissions factor * hours/year
    # Use a more conservative initial value
    initial_emissions = initial_emissions_estimate * 0.5  # Start at 50% of theoretical max
    
    # Create constraint code for emissions limits
    constraint_code = f"""
# Primary Energy Limit - Limit total CO2 emissions
# Linearly reduce emissions from {initial_emissions:.0f} tCO2 in 2020 to 0 tCO2 in 2024

import numpy as np
import pandas as pd

years = np.arange(2020, 2025)
emissions = np.linspace({initial_emissions:.0f}, 0, len(years))

df = pd.DataFrame(
    {{
        'year' : years,
        'emissions' : emissions
    }}
)

for i, row in df.iterrows():
    n.add(
        "GlobalConstraint",
        "co2_limit_" + str(int(row.year)),
        type="primary_energy",
        carrier_attribute="co2_emissions",
        investment_period=int(row.year),
        sense="<=",
        constant=float(row.emissions),
    )

print("Added CO2 emissions limit constraint: linearly reducing from {initial_emissions:.0f} tCO2 (2020) to 0 tCO2 (2024)")
""".strip()
    
    # Create constraint component
    constraint_id = px.create_component(
        conn,
        component_type="CONSTRAINT",
        name="CO2 Emissions Limit",
    )
    
    # Set constraint attributes
    px.set_static_attribute(conn, constraint_id, "constraint_code", px.StaticValue(constraint_code))
    px.set_static_attribute(conn, constraint_id, "description", px.StaticValue("Linearly reduce CO2 emissions from 2020 to zero by 2024"))
    px.set_static_attribute(conn, constraint_id, "is_active", px.StaticValue(True))
    px.set_static_attribute(conn, constraint_id, "priority", px.StaticValue(1))
    
    print(f"\n✅ Added emissions constraint:")
    print(f"   Initial emissions (2020): {initial_emissions:.0f} tCO2")
    print(f"   Final emissions (2024): 0 tCO2")
    print(f"   Linear reduction over 5 years")
    
    # --- Dashboard and commit -------------------------------------------------
    # A single dashboard chart: optimal capacity per carrier.
    from pyconvexity.dashboard import set_dashboard_config, DashboardConfig, auto_layout

    capacity_chart = {
        "id": "capacity-main",
        "title": "Optimal Power Capacity by Carrier",
        "visible": True,
        "view": {"statistic": {"statistic": "optimal_capacity", "metric": "capacity"}},
    }
    charts = [capacity_chart]
    set_dashboard_config(conn, DashboardConfig(charts=charts, layout=auto_layout(charts)))

    # Persist everything written above before the connection closes.
    conn.commit()

    for summary_line in (
        "\n✅ Model created successfully!",
        "   📍 Countries: Ghana and Togo",
        "   ⏰ Period: 2020-2024 (5 years)",
        f"   📊 Resolution: 6-hourly ({len(timestamps)} snapshots)",
        "   ⚡ Technologies: Gas (until 2025), Biomass, Hydro, Links, Batteries",
        "   📈 Demand: Year-on-year growth (3% annually) with random variation",
        "   🌍 Emissions: Linear reduction to zero by 2024",
    ):
        print(summary_line)

# --- Solve --------------------------------------------------------------------
# Run the multi-investment-period optimisation on the saved model with HiGHS,
# echoing solver progress to stdout.
def _print_solve_progress(progress, msg):
    """Print one solver progress update as '  <progress>%: <msg>'."""
    print(f"  {progress}%: {msg}")


print("\n🔧 Solving network with multi-investment period optimization...")
result = px.solve_network(
    db_path=db_path,
    solver_name="highs",
    progress_callback=_print_solve_progress,
)

print("\n✅ Optimization complete!")
print(f"   Success: {result.get('success', False)}")
print(f"   Objective: {result.get('objective_value', 'N/A')}")
print("\n💡 This model optimizes capacity expansion decisions over 5 years,")
print("   determining when to build new generation, transmission, and storage")
print("   to meet growing demand at minimum cost.")

# --- Excel export -------------------------------------------------------------
# Write the model to an .xlsx file with the same base name, next to the
# database.
print("\n📊 Exporting to Excel...")
from pyconvexity.io.excel_exporter import ExcelModelExporter

# Swap only the file extension. str.replace('.db', '.xlsx') would also rewrite
# any '.db' that appears earlier in the path (e.g. a directory 'my.dbfiles').
xlsx_path = os.path.splitext(db_path)[0] + ".xlsx"
exporter = ExcelModelExporter()
export_result = exporter.export_model_to_excel(
    db_path=db_path,
    output_path=xlsx_path,
    # Progress updates of 0 (falsy) are not echoed.
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}") if progress else None,
)
print(f"✅ Excel export complete: {xlsx_path}")

