'''
Negative Locational Marginal Prices (LMPs) Example - PyConvexity Version

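This example demonstrates how negative locational marginal prices (LMPs) can occur 
in electricity systems due to line congestion. Using a simple 3-bus linearised DC 
power flow model, we reproduce the phenomenon where cheap generation is trapped behind 
a congested line, causing LMPs to drop below zero as the system redistributes power 
flows to meet demand. The price is negative because an extra MWh consumed behind the 
bottleneck would relieve the congested line, letting cheaper generation displace 
expensive generation at the load bus, so serving that extra MWh lowers total cost.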

This is the PyConvexity version of the PyPSA example.
PyConvexity stores the model in a SQLite database and then builds/solves PyPSA networks from it.

Key differences from PyPSA:
1. The model is stored in a SQLite database instead of an in-memory PyPSA network
2. Components (buses, generators, loads, lines) are created using create_component()
3. Attributes are set using set_static_attribute() and set_timeseries_attribute()
4. The network's time horizon (start time, end time, resolution) must be declared explicitly; time periods are generated from it
5. Carriers must be created before components
6. Solving uses solve_network() which builds a PyPSA network from the database

Source (original PyPSA):
Based on Kyri Baker's "3bus_LMP" example
'''

import pyconvexity as px
import numpy as np
from datetime import datetime, timedelta
import os

# The example database is written to a "dB" folder beside this script;
# that folder is created first if it does not exist yet.
db_path = os.path.join(
    os.path.dirname(__file__),
    "dB",
    "002_three_bus_lmp_example.db",
)
os.makedirs(os.path.dirname(db_path), exist_ok=True)

# Initialise the SQLite file with the PyConvexity schema.
# NOTE(review): not clear from here how this behaves when the file already
# exists from a previous run -- confirm before re-running the example.
px.create_database_with_schema(db_path)

# Horizon length and RNG seed are fixed so that every run of the example writes
# the same model (same number of periods, same load profile) into the database.
N_HOURS = 24
LOAD_SEED = 42

with px.database_context(db_path) as conn:
    # Network with N_HOURS hourly time periods starting at today's midnight.
    # Periods are generated from start_time, end_time and time_resolution;
    # end_time is the timestamp of the last period, hence N_HOURS - 1.
    start_time = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    end_time = start_time + timedelta(hours=N_HOURS - 1)

    network_req = px.CreateNetworkRequest(
        name="Three Bus LMP Example",
        description="3-bus system demonstrating negative LMPs due to line congestion",
        start_time=start_time.strftime("%Y-%m-%d %H:%M:%S"),
        end_time=end_time.strftime("%Y-%m-%d %H:%M:%S"),
        time_resolution="PT1H",  # ISO 8601 duration format for hourly
    )
    px.create_network(conn, network_req)

    # A single AC carrier shared by every component (must exist before them).
    carrier_id = px.create_carrier(
        conn,
        name="AC",
        co2_emissions=0.0,
        color="#1f77b4",
        nice_name="Electricity",
    )

    # Approximate Texas locations forming a triangular layout.
    bus_coords = {
        "Bus1": {"latitude": 30.3, "longitude": -97.7},   # Austin area (top-left)
        "Bus2": {"latitude": 30.3, "longitude": -95.4},   # Houston area (top-right)
        "Bus3": {"latitude": 29.4, "longitude": -98.5},   # San Antonio area (bottom, load)
    }

    # Buses, all at 400 kV nominal voltage.
    bus_ids = {}
    for bus_name, coords in bus_coords.items():
        bus_id = px.create_component(
            conn,
            component_type="BUS",
            name=bus_name,
            carrier_id=carrier_id,
            latitude=coords["latitude"],
            longitude=coords["longitude"],
        )
        bus_ids[bus_name] = bus_id
        px.set_static_attribute(conn, bus_id, "v_nom", px.StaticValue(400.0))

    # Generators: (name, bus, marginal cost in €/MWh); each is 100 MW.
    generators = [
        ("Gen1", "Bus1", 10.0),   # cheap, but behind the Line13 bottleneck
        ("Gen2", "Bus2", 20.0),   # medium
        ("Gen3", "Bus3", 100.0),  # expensive, co-located with the load
    ]
    for gen_name, bus_name, marginal_cost in generators:
        gen_id = px.create_component(
            conn,
            component_type="GENERATOR",
            name=gen_name,
            bus_id=bus_ids[bus_name],
            carrier_id=carrier_id,
            latitude=bus_coords[bus_name]["latitude"] + 0.05,  # slightly offset from bus
            longitude=bus_coords[bus_name]["longitude"] + 0.05,
        )
        px.set_static_attribute(conn, gen_id, "p_nom", px.StaticValue(100.0))
        px.set_static_attribute(conn, gen_id, "marginal_cost", px.StaticValue(marginal_cost))

    # Load at Bus3: 100 MW base with ±20% uniform hourly variation.
    # A seeded Generator keeps the profile reproducible between runs.
    rng = np.random.default_rng(LOAD_SEED)
    base_load = 100.0
    load_variation = base_load * (1 + rng.uniform(-0.2, 0.2, size=N_HOURS))

    load_id = px.create_component(
        conn,
        component_type="LOAD",
        name="Load3",
        bus_id=bus_ids["Bus3"],
        carrier_id=carrier_id,
        latitude=bus_coords["Bus3"]["latitude"] + 0.1,  # offset from bus
        longitude=bus_coords["Bus3"]["longitude"] + 0.1,
    )
    px.set_timeseries_attribute(conn, load_id, "p_set", load_variation.tolist())

    # Transmission lines: (name, bus0, bus1, s_nom in MW); all have x = 1.
    # With equal reactances, DC power flow routes 2/3 of a Bus1->Bus3 transfer
    # over Line13 (1/3 via Bus2), and 1/3 of a Bus2->Bus3 transfer also over
    # Line13 (via Bus1). Even Gen2's output therefore loads the 10 MW
    # bottleneck, which is what drives the price at Bus1 below zero.
    lines = [
        ("Line12", "Bus1", "Bus2", 100.0),
        ("Line23", "Bus2", "Bus3", 100.0),
        ("Line13", "Bus1", "Bus3", 10.0),  # BOTTLENECK - the key constraint!
    ]
    for line_name, bus0, bus1, s_nom in lines:
        line_id = px.create_component(
            conn,
            component_type="LINE",
            name=line_name,
            bus0_id=bus_ids[bus0],
            bus1_id=bus_ids[bus1],
            carrier_id=carrier_id,
            # Line drawn at the midpoint of its two buses
            latitude=(bus_coords[bus0]["latitude"] + bus_coords[bus1]["latitude"]) / 2.0,
            longitude=(bus_coords[bus0]["longitude"] + bus_coords[bus1]["longitude"]) / 2.0,
        )
        px.set_static_attribute(conn, line_id, "x", px.StaticValue(1.0))
        px.set_static_attribute(conn, line_id, "s_nom", px.StaticValue(s_nom))

    # Dashboard: one chart of bus marginal prices (the LMPs).
    from pyconvexity.dashboard import set_dashboard_config, DashboardConfig, auto_layout

    charts = [{
        "id": "lmp-main",
        "title": "Locational Marginal Prices",
        "visible": True,
        "view": {
            "timeseries": {
                "component": "Bus",
                "attribute": "marginal_price",
                "group_by": None
            }
        }
    }]

    config = DashboardConfig(charts=charts, layout=auto_layout(charts))
    set_dashboard_config(conn, config)

    conn.commit()
    print("✅ Model created successfully!")
    print("   📍 Location: Texas, USA")
    print("   🚌 Buses: 3 (Bus1: Austin area, Bus2: Houston area, Bus3: San Antonio area)")
    print("   ⚡ Generators: 3 (Gen1: 10 €/MWh, Gen2: 20 €/MWh, Gen3: 100 €/MWh)")
    print(f"   🏠 Load: 1 at Bus3 (100 MW base, ±20% variation over {N_HOURS} hours)")
    print("   🔌 Lines: 3 (Line13 is bottlenecked at 10 MW)")
    print(f"   ⏰ Time periods: {N_HOURS} hours")

# Solve the network: solve_network() builds a PyPSA network from the database
# and optimises it with the HiGHS solver.
print("\n🔧 Solving network...")
result = px.solve_network(
    db_path=db_path,
    solver_name="highs",
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}"),
)

# result is read as a dict here ("success", "objective_value" keys).
# Only claim completion when the solver actually reports success.
solved = result.get("success", False)
if solved:
    print("\n✅ Optimization complete!")
else:
    print("\n❌ Optimization did not report success!")
print(f"   Success: {solved}")
print(f"   Objective: {result.get('objective_value', 'N/A')}")
print("\n💡 Note: This example demonstrates negative LMPs when cheap generation")
print("   (Gen1 at Bus1) is trapped behind the congested Line13 (10 MW limit).")

# Export the model to an Excel workbook next to the database file.
print("\n📊 Exporting to Excel...")
from pyconvexity.io.excel_exporter import ExcelModelExporter

# Swap only the file extension: str.replace('.db', ...) would also rewrite any
# '.db' appearing elsewhere in the path (e.g. in a directory name).
xlsx_path = os.path.splitext(db_path)[0] + ".xlsx"
exporter = ExcelModelExporter()
exporter.export_model_to_excel(
    db_path=db_path,
    output_path=xlsx_path,
    # Falsy progress values (0 or None) are not printed.
    progress_callback=lambda progress, msg: print(f"  {progress}%: {msg}") if progress else None,
)
print(f"✅ Excel export complete: {xlsx_path}")

