'''
Electricity Markets - PyConvexity Version

Example:
Three bidding zones connected by transmission over 24 hours.

Description:
In this example we have bidirectional lossless transmission capacity 
between three bidding zones. The power transfer is treated as 
controllable (like an A/NTC (Available/Net Transfer Capacity) or HVDC line). 
Note that in the physical grid, power would flow passively according to the network impedances.

This is the PyConvexity version of the PyPSA example.
PyConvexity stores the model in a SQLite database and then builds/solves PyPSA networks from it.

Key differences from PyPSA:
1. Model is stored in a SQLite database instead of in-memory PyPSA network
2. Components (buses, generators, loads, links) are created using create_component()
3. Attributes are set using set_static_attribute() and set_timeseries_attribute()
4. Network time periods must be explicitly created
5. Carriers must be created before components
6. Solving uses solve_network() which builds a PyPSA network from the database

Source (original PyPSA):
https://docs.pypsa.org/latest/examples/simple-electricity-market-examples/
'''

import pyconvexity as px
import numpy as np
from datetime import datetime, timedelta
import os

# Marginal generation cost of each technology, in EUR/MWh.
marginal_costs = dict(Wind=0, Hydro=0, Coal=30, Gas=60, Oil=80)

# Installed capacity (nominal power, MW) of every technology in each country.
# Illustrative figures only. Key order matters: the generator loop below walks
# these dicts in insertion order when laying generators out on the map.
power_plant_p_nom = {
    "South Africa": dict(Coal=35000, Wind=3000, Gas=8000, Oil=2000),
    "Mozambique": dict(Hydro=1200),
    "Eswatini": dict(Hydro=600),
}

# Transfer capacity (MW) between bidding zones, listed once per country pair as
# {from_country: {to_country: capacity}}. Illustrative figures only.
transmission = {
    "South Africa": dict(Mozambique=500, Eswatini=250),
    "Mozambique": dict(Eswatini=100),
}

# Average electrical demand (MW) of each country. Illustrative figures only.
baseload = {
    "South Africa": 22000,
    "Mozambique": 250,
    "Eswatini": 150,
}

# Approximate reference point of each country (used to place buses on a map).
country_coords = {
    name: {"latitude": lat, "longitude": lon}
    for name, lat, lon in (
        ("South Africa", -25.7, 28.2),  # Pretoria/Johannesburg area
        ("Mozambique", -25.9, 32.6),  # Maputo area
        ("Eswatini", -26.3, 31.1),  # Mbabane area
    )
}

# Database path: a "dB" subdirectory next to this script.
db_path = os.path.join(os.path.dirname(__file__), "dB", "001_three_country_example.db")
os.makedirs(os.path.dirname(db_path), exist_ok=True)

# Start every run from a clean file so the model below is always built from
# scratch. NOTE(review): how create_database_with_schema treats an existing file
# is not visible here; removing the leftover database (and any SQLite journal /
# WAL sidecar files, which must never be paired with a different database file)
# avoids depending on that.
for stale_path in (db_path, db_path + "-wal", db_path + "-shm", db_path + "-journal"):
    try:
        os.remove(stale_path)
    except FileNotFoundError:
        pass

# Create database with schema
px.create_database_with_schema(db_path)

with px.database_context(db_path) as conn:
    # Number of hourly time periods. One constant drives both the network
    # horizon and the length of every load time series so they stay in sync.
    n_hours = 24

    # Seeded random generator: the "random" load profiles, and therefore the
    # optimization result, are identical on every run of the example.
    rng = np.random.default_rng(42)

    # Create the network. Time periods are created automatically from
    # start_time, end_time and time_resolution; end_time is the timestamp of the
    # last period (24 hourly periods end at start + 23 h), i.e. it is inclusive.
    start_time = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    end_time = start_time + timedelta(hours=n_hours - 1)

    network_req = px.CreateNetworkRequest(
        name="Three Country Electricity Market",
        description=f"Three bidding zones connected by transmission over {n_hours} hours",
        start_time=start_time.strftime("%Y-%m-%d %H:%M:%S"),
        end_time=end_time.strftime("%Y-%m-%d %H:%M:%S"),
        time_resolution="PT1H",  # ISO 8601 duration format for hourly
    )
    px.create_network(conn, network_req)

    # Create carriers: "AC" for buses, loads and links, plus one per generation
    # technology. Carriers must exist before the components that reference them.
    # NOTE: co2_emissions is 0.0 for every carrier, including Coal/Gas/Oil. This
    # script sets no CO2 price or limit, so it should not affect the dispatch.
    carrier_colors = {
        "AC": "#1f77b4",
        "Wind": "#2ca02c",
        "Hydro": "#17becf",
        "Coal": "#8C8C8C",
        "Gas": "#CCCCCC",
        "Oil": "#d62728",
    }
    carriers = {}  # carrier name -> carrier id
    for carrier_name, color in carrier_colors.items():
        carriers[carrier_name] = px.create_carrier(
            conn,
            name=carrier_name,
            co2_emissions=0.0,
            color=color,
            nice_name=carrier_name,
        )

    countries = ["Eswatini", "Mozambique", "South Africa"]
    bus_ids = {}  # country -> bus component id

    # One AC bus per bidding zone, placed at the country's reference point.
    for country in countries:
        coords = country_coords[country]
        bus_ids[country] = px.create_component(
            conn,
            component_type="BUS",
            name=country,
            carrier_id=carriers["AC"],
            latitude=coords["latitude"],
            longitude=coords["longitude"],
        )
        # Set bus voltage (optional, but good practice)
        px.set_static_attribute(conn, bus_ids[country], "v_nom", px.StaticValue(400.0))

    # Generators, scattered around each country's reference point so they do
    # not overlap on a map: columns of three, 0.3 degrees apart.
    for country in countries:
        coords = country_coords[country]
        for idx, tech in enumerate(power_plant_p_nom[country]):
            offset_lat = (idx % 3 - 1) * 0.3  # -0.3, 0, or 0.3
            offset_lon = (idx // 3) * 0.3  # 0, 0.3, 0.6, etc.

            gen_id = px.create_component(
                conn,
                component_type="GENERATOR",
                name=f"{country} {tech}",
                bus_id=bus_ids[country],
                carrier_id=carriers[tech],
                latitude=coords["latitude"] + offset_lat,
                longitude=coords["longitude"] + offset_lon,
            )

            # Explicitly convert to float to ensure type matching
            px.set_static_attribute(
                conn, gen_id, "p_nom", px.StaticValue(float(power_plant_p_nom[country][tech]))
            )
            px.set_static_attribute(
                conn, gen_id, "marginal_cost", px.StaticValue(float(marginal_costs[tech]))
            )

    # Loads with n_hours of randomly varying demand.
    for country in countries:
        coords = country_coords[country]
        load_id = px.create_component(
            conn,
            component_type="LOAD",
            name=f"{country} load",
            bus_id=bus_ids[country],
            carrier_id=carriers["AC"],
            # Offset slightly from the bus so the load centre is distinguishable
            latitude=coords["latitude"] + 0.1,
            longitude=coords["longitude"] + 0.1,
        )

        # Uniform random variation of +/-20% around the baseload, one value per hour
        load_profile = baseload[country] * (1 + rng.uniform(-0.2, 0.2, size=n_hours))
        px.set_timeseries_attribute(conn, load_id, "p_set", load_profile.tolist())

    # Transmission links: one per (country, neighbour) entry in `transmission`.
    # p_min_pu = -1 lets power flow in either direction up to p_nom, modelling a
    # controllable, lossless bidirectional transfer capacity.
    for country in countries:
        neighbours = transmission.get(country, {})
        for other_country in countries:
            if other_country not in neighbours:
                continue

            coords0 = country_coords[country]
            coords1 = country_coords[other_country]
            link_id = px.create_component(
                conn,
                component_type="LINK",
                name=f"{country} - {other_country} link",
                bus0_id=bus_ids[country],
                bus1_id=bus_ids[other_country],
                carrier_id=carriers["AC"],
                # Place the link at the midpoint between the two countries
                latitude=(coords0["latitude"] + coords1["latitude"]) / 2.0,
                longitude=(coords0["longitude"] + coords1["longitude"]) / 2.0,
            )

            # Explicitly convert to float to ensure type matching
            px.set_static_attribute(
                conn, link_id, "p_nom", px.StaticValue(float(neighbours[other_country]))
            )
            px.set_static_attribute(conn, link_id, "p_min_pu", px.StaticValue(-1.0))

    # Set up dashboard configuration
    from pyconvexity.dashboard import set_dashboard_config, DashboardConfig, auto_layout

    charts = [{
        "id": "dispatch-main",
        "title": "Generation Dispatch by Carrier",
        "visible": True,
        "view": {
            "timeseries": {
                "component": "Generator",
                "attribute": "p",
                "group_by": "carrier"
            }
        }
    }]

    config = DashboardConfig(charts=charts, layout=auto_layout(charts))
    set_dashboard_config(conn, config)

    conn.commit()
    print("✅ Model created successfully!")

# Solve the network: pyconvexity builds a PyPSA network from the database and
# optimizes it with the HiGHS solver.
print("🔧 Solving network...")
result = px.solve_network(
    db_path=db_path,
    solver_name="highs",
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}"),
    verbose=True,
)

# Only announce completion when the solver actually reports success; a missing
# "success" key is treated as failure.
success = result.get("success", False)
if success:
    print("✅ Optimization complete!")
else:
    print("⚠️ Optimization finished without success.")
print(f"   Success: {success}")
print(f"   Objective: {result.get('objective_value', 'N/A')}")

# Export the model to an Excel workbook next to the database file.
print("\n📊 Exporting to Excel...")
from pyconvexity.io.excel_exporter import ExcelModelExporter

# Swap only the final extension. str.replace('.db', ...) would also rewrite any
# ".db" occurring elsewhere in the path (e.g. in a directory name).
xlsx_path = os.path.splitext(db_path)[0] + ".xlsx"
exporter = ExcelModelExporter()
export_result = exporter.export_model_to_excel(
    db_path=db_path,
    output_path=xlsx_path,
    # Falsy progress values (e.g. 0 or None) are not printed
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}") if progress else None,
)
print(f"✅ Excel export complete: {xlsx_path}")

