#---------------------------------------------------------------------------------------------------
#Interactive Numerical Simulation of the Atlantic Meridional Overturning Circulation (AMOC)
#Based on the 4-box model of the AMOC by Stefan Rahmstorf (1996)
#The box model is parameterized so that it reproduces realistic present-day ocean values
#---------------------------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle

# Initial conditions and forcing parameters
dt = 0.5  # Time step for simulation (model time advanced per animation frame)
initial_dt = dt  # Store initial dt value; start_hosing() resets dt to it
Tn0 = 2  # Forcing (restoring) temperature in the north: Tn relaxes towards it
Ts0 = 7  # Forcing (restoring) temperature in the south: Ts relaxes towards it
Fn = 0.2  # Salt-flux forcing: lowers N salinity and raises E salinity (see timestep)
Fs = -0.2  # Salt-flux forcing: raises S salinity and lowers E salinity; negative, so S freshens

# Initial box states
Se_0 = 35.09  # Equatorial salinity (psu)
Te_0 = 25.39  # Equatorial temperature (°C); also passed to timestep() as E's restoring temperature
Sn_0 = 34.75  # Northern salinity
Tn_0 = 4.2   # Northern temperature
Ss_0 = 34.5  # Southern salinity
Ts_0 = 3.87   # Southern temperature
Sd_0 = 34.9  # Deep water salinity
Td_0 = 1.54   # Deep water temperature

# Ocean volumes in 10^16 m^3
Ve = 4.1  # Equatorial box volume
Vn = 6.2  # Northern box volume
Vs = 8.3  # Southern box volume
Vd = 23.0  # Deep ocean volume

# Each box is [salinity, temperature, volume]; restart_simulation() rebuilds these
boxE = [Se_0, Te_0, Ve]
boxN = [Sn_0, Tn_0, Vn]
boxS = [Ss_0, Ts_0, Vs]
boxD = [Sd_0, Td_0, Vd]


# Global simulation parameters
Temp = True  # When True, timestep() also updates temperatures; toggled by the Temp button
hosing_active = False  # Flag for gradual hosing (re-initialized again after the figure setup)
hosing_rate = 0.0  # Default hosing strength: salt moved from box N to box S by animate()
show_density = True  # NOTE(review): not referenced in the code shown here — possibly dead

# Density function
def density(box):
    """Return the seawater density of ``box`` from a linear equation of state.

    ``box[0]`` is the salinity (psu) and ``box[1]`` the temperature (°C); any
    further entries (such as the box volume) are ignored. Density increases
    with salinity (haline contraction) and decreases with temperature
    (thermal expansion), both relative to a reference state of 35 psu and
    10 °C, at which the density is 1027.
    """
    salinity, temperature = box[0], box[1]

    rho_ref, sal_ref, temp_ref = 1027, 35, 10
    thermal_expansion = 0.0002   # alpha, per °C
    haline_contraction = 0.0008  # beta, per psu

    # Relative density change; multiplied by rho_ref below to get the anomaly.
    relative_change = (haline_contraction * (salinity - sal_ref)
                       - thermal_expansion * (temperature - temp_ref))
    return rho_ref + rho_ref * relative_change

# Anomaly function
def anomaly(boxN, boxS):
    """Return the north-minus-south density difference ``density(boxN) - density(boxS)``.

    A positive value means the northern box is denser than the southern one.
    """
    rho_north = density(boxN)
    rho_south = density(boxS)
    return rho_north - rho_south

# Timestep function
def timestep(boxE, boxN, boxS, boxD, dt, Tn0, Ts0, Fs, Fn, Te_0):
    """Advance the four-box model by one forward-Euler step of length ``dt``.

    Each box is ``[salinity, temperature, volume]``. The overturning driver is
    the N-minus-S density difference ``m = anomaly(boxN, boxS)``. Its sign
    selects the loop direction and ``|m|`` the exchange rate:

    * ``m > 0``:  E -> N -> D -> S -> E
    * ``m <= 0``: E -> S -> D -> N -> E

    Salinity is additionally forced by the salt-flux terms ``Fn`` (taken from
    N, given to E) and ``Fs`` (given to S, taken from E). When the
    module-level ``Temp`` flag is true, temperatures are advected the same
    way. The surface boxes N, S and E also relax towards ``Tn0``, ``Ts0`` and
    ``Te_0`` respectively, independently of the flow strength.

    All tendencies are evaluated from the state at the start of the step, so
    advection and the salt-flux terms conserve total (volume-weighted) salt.

    Returns
    -------
    tuple
        ``(boxE, boxN, boxS, boxD, amoc_strength)``: new box lists (each keeps
        its own volume) and ``42 * anomaly`` of the updated state.
    """
    Se, Te, Ve = boxE
    Sn, Tn, Vn = boxN
    Ss, Ts, Vs = boxS
    Sd, Td, Vd = boxD

    m = anomaly(boxN, boxS)

    # For each box, pick the upstream box it receives water from.
    if m > 0:
        q = m
        Sn_up, Ss_up, Se_up, Sd_up = Se, Sd, Ss, Sn
        Tn_up, Ts_up, Te_up, Td_up = Te, Td, Ts, Tn
    else:
        q = -m
        Sn_up, Ss_up, Se_up, Sd_up = Sd, Se, Sn, Ss
        Tn_up, Ts_up, Te_up, Td_up = Td, Te, Tn, Ts

    # Forward Euler: every right-hand side uses only the old state, so the
    # volume-weighted advective and salt-flux tendencies sum to zero.
    new_Sn = Sn + (q * (Sn_up - Sn) - Fn) * dt / Vn
    new_Ss = Ss + (q * (Ss_up - Ss) + Fs) * dt / Vs
    new_Se = Se + (q * (Se_up - Se) + Fn - Fs) * dt / Ve
    new_Sd = Sd + (q * (Sd_up - Sd)) * dt / Vd

    if Temp:  # module-level flag; temperatures are frozen when it is False
        new_Tn = Tn + (q * (Tn_up - Tn) - (Tn - Tn0)) * dt / Vn
        new_Ts = Ts + (q * (Ts_up - Ts) - (Ts - Ts0)) * dt / Vs
        # Restoring term kept outside the flow factor, as for N and S.
        new_Te = Te + (q * (Te_up - Te) - (Te - Te_0)) * dt / Ve
        new_Td = Td + (q * (Td_up - Td)) * dt / Vd
    else:
        new_Tn, new_Ts, new_Te, new_Td = Tn, Ts, Te, Td

    # Each box keeps its own volume (the N box previously got Ve by mistake).
    boxE = [new_Se, new_Te, Ve]
    boxN = [new_Sn, new_Tn, Vn]
    boxS = [new_Ss, new_Ts, Vs]
    boxD = [new_Sd, new_Td, Vd]

    c = 42  # conversion factor from density difference to AMOC strength (Sv)
    amoc_strength = c * anomaly(boxN, boxS)
    return boxE, boxN, boxS, boxD, amoc_strength

def toggle_temperature(event):
    """Button callback: flip the module-level ``Temp`` flag and report the new state.

    ``event`` is supplied by the Button and is not used.
    """
    global Temp
    Temp = False if Temp else True
    state = 'enabled' if Temp else 'disabled'
    print(f"Temperature calculations {state}")

def restart_simulation(event):
    """Button callback: reset the model, hosing state and plot to their initial values.

    What it resets:

    * the four boxes, restored to the module-level initial conditions;
    * the recorded time series, including ``animate.ydata_hosing``;
    * hosing, switched off with all past hosing periods forgotten;
    * the hosing rectangles on ``ax1``, which are removed;
    * every plotted line, which is emptied, after which both axes are rescaled.

    Parameters
    ----------
    event : matplotlib event
        Supplied by ``Button.on_clicked``; unused.
    """
    global boxE, boxN, boxS, boxD, current_time, xdata, ydata_m, ydata_Tn0, ydata_Ts0, ydata_Fn, ydata_Fs, hosing_active, hosing_periods, hosing_start_time

    # Fresh lists: animate() edits box salinities in place.
    boxE = [Se_0, Te_0, Ve]
    boxN = [Sn_0, Tn_0, Vn]
    boxS = [Ss_0, Ts_0, Vs]
    boxD = [Sd_0, Td_0, Vd]

    # Reset time and recorded data.
    current_time = 0
    xdata, ydata_m = [], []
    ydata_Tn0, ydata_Ts0, ydata_Fn, ydata_Fs = [], [], [], []
    # animate() creates this attribute lazily; an empty list is equivalent
    # and keeps it aligned with xdata.
    animate.ydata_hosing = []

    # Reset hosing state, including any half-open period from the old timeline.
    hosing_active = False
    hosing_start_time = None
    hosing_periods = []
    hosing_button.color = axcolor
    hosing_button.hovercolor = '0.975'
    s_hosing.set_val(0)  # triggers update_hosing_strength -> hosing_rate = 0

    # Remove hosing shading via Artist.remove(); the Axes.patches container
    # itself must not be mutated directly.
    for patch in list(ax1.patches):
        if isinstance(patch, Rectangle):
            patch.remove()

    for line in (line1, line3_Tn0, line4_Ts0, line5_Fn, line6_Fs, line7_hosing):
        line.set_data([], [])

    for ax in (ax1, ax2):
        ax.relim()
        ax.autoscale_view()

    print("Simulation fully reset!")
    fig.canvas.draw_idle()

# Set up the figure and axes
fig, ax1 = plt.subplots()
plt.subplots_adjust(left=0.25, bottom=0.4)  # Leave room left of and below the plot for buttons and sliders

# Primary y-axis: AMOC strength as returned by timestep()
ax1.set_xlabel('Time')
ax1.set_ylabel('AMOC [Sv]', color='tab:blue')
line1, = ax1.plot([], [], lw=4, color='tab:blue', label='AMOC-Strength')
ax1.grid(True)

# Secondary y-axis (shares x with ax1) for the forcing parameters Tn0, Ts0, Fn, Fs and the hosing strength
ax2 = ax1.twinx()
ax2.set_ylabel('Tn0, Ts0, Fn, Fs and Hosing Strength', color='tab:red')
line3_Tn0, = ax2.plot([], [], lw=2, color='tab:orange', linestyle='--', label='Tn0')
line4_Ts0, = ax2.plot([], [], lw=2, color='tab:pink', linestyle='--', label='Ts0')
line5_Fn, = ax2.plot([], [], lw=2, color='tab:green', linestyle='--', label='Fn')
line6_Fs, = ax2.plot([], [], lw=2, color='tab:blue', linestyle='--', label='Fs')
line7_hosing, = ax2.plot([], [], lw=2, color='tab:purple', linestyle=':', label='Hosing')

# Rolling time series filled by animate() (trimmed there to the last 2000 points)
xdata, ydata_m = [], []
ydata_Tn0, ydata_Ts0, ydata_Fn, ydata_Fs = [], [], [], []

# Simulation state
current_time = 0  # Model time; animate() advances it by dt every frame
running = True  # NOTE(review): not referenced in the code shown here — possibly dead
hosing_active = False  # Re-initializes the value already set in the global parameters above
hosing_rate = 0.0
hosing_start_time = None  # Model time at which the current hosing period began (None when off)
hosing_periods = []  # Completed hosing periods as (start, end) model-time tuples

# Function to handle the gradual hosing button click
def start_hosing(event):
    """Button callback: toggle gradual hosing on or off.

    Switching on:

    * resets ``dt`` and the dt slider to ``initial_dt``;
    * remembers the current model time as the period start;
    * turns the button red.

    Switching off:

    * closes the running period in ``hosing_periods``;
    * shades that single period on ``ax1`` with a pink rectangle;
    * restores the button colours.

    Parameters
    ----------
    event : matplotlib event
        Supplied by ``Button.on_clicked``; unused.
    """
    global hosing_active, hosing_start_time, dt
    hosing_active = not hosing_active
    if hosing_active:
        print("Gradual hosing started!")
        dt = initial_dt
        s_dt.set_val(dt)
        hosing_start_time = current_time
        hosing_button.color = 'red'  # Red while hosing is active
        hosing_button.hovercolor = 'darkred'
    else:
        print("Gradual hosing stopped!")
        if hosing_start_time is not None:
            start, end = hosing_start_time, current_time
            hosing_periods.append((start, end))
            hosing_start_time = None
            # Shade only the period that just ended; earlier periods already
            # have their rectangle. x is in data units, y spans 0..1 in axes
            # units, so the band covers the full height however the y-limits
            # autoscale later. It does not widen the y data limits either.
            ax1.add_patch(Rectangle((start, 0), end - start, 1,
                                    transform=ax1.get_xaxis_transform(),
                                    color='pink'))
        hosing_button.color = axcolor  # Return to original color
        hosing_button.hovercolor = '0.975'
    fig.canvas.draw_idle()
# Function to update hosing strength
def update_hosing_strength(val):
    """Slider callback: store the hosing slider's current value in ``hosing_rate``.

    The value is read from ``s_hosing`` directly; ``val`` is not used.
    """
    global hosing_rate
    slider_value = s_hosing.val
    hosing_rate = slider_value

# Function to update parameters when sliders are changed
def update(val):
    """Slider callback: copy the Tn0, Ts0, Fn, Fs and dt slider values into the module globals.

    ``val`` (the value of whichever slider moved) is ignored. All five
    sliders are re-read, so the globals always mirror the UI.
    """
    global Tn0, Ts0, Fn, Fs, dt
    Tn0, Ts0, Fn, Fs, dt = (s_Tn0.val, s_Ts0.val, s_Fn.val, s_Fs.val, s_dt.val)

# Animation function
def _box_summary_title(boxE, boxN, boxS, boxD):
    """Return the one-line plot title listing temperature, salinity and density of each box.

    The N and S labels are rendered bold with Matplotlib mathtext. Raw
    f-strings are used so the mathtext backslashes are not read as (invalid)
    string escapes.
    """
    temp_status = "ON" if Temp else "OFF"
    temperatures = (f"Temperature: E={boxE[1]:.2f}, "
                    rf"$\mathbf{{N}}$={boxN[1]:.2f}, "
                    rf"$\mathbf{{S}}$={boxS[1]:.2f}, "
                    f"D={boxD[1]:.2f} (Temp: {temp_status})")
    salinities = (f"Salinity: E={boxE[0]:.2f}, "
                  rf"$\mathbf{{N}}$={boxN[0]:.2f}, "
                  rf"$\mathbf{{S}}$={boxS[0]:.2f}, "
                  f"D={boxD[0]:.2f}")
    densities = (f"Density: E={density(boxE):.2f}, "
                 rf"$\mathbf{{N}}$={density(boxN):.2f}, "
                 rf"$\mathbf{{S}}$={density(boxS):.2f}, "
                 f"D={density(boxD):.2f}")
    return f"{temperatures} | {salinities} | {densities}"


def animate(frame):
    """FuncAnimation callback: advance the model one step and refresh the plot.

    Steps, in order:

    1. Apply hosing if it is active.
    2. Call ``timestep``.
    3. Append the AMOC strength, the forcing values and the hosing strength
       to the rolling time series (only the last 2000 samples are kept).
    4. Update every line and rescale both axes.
    5. Write the box temperatures, salinities and densities into the title.

    Parameters
    ----------
    frame : int
        Frame index supplied by FuncAnimation; unused.

    Returns
    -------
    tuple
        The Line2D artists that were updated.
    """
    global boxE, boxN, boxS, boxD, current_time, xdata, ydata_m, ydata_Tn0, ydata_Ts0, ydata_Fn, ydata_Fs
    max_points = 2000  # length of the rolling window kept on screen

    if hosing_active:
        # Move salt from N to S (freshening N for a positive rate). The
        # volume-weighted changes cancel, so total salt is unchanged.
        boxS[0] += hosing_rate * dt / Vs
        boxN[0] -= hosing_rate * dt / Vn

    boxE, boxN, boxS, boxD, m = timestep(boxE, boxN, boxS, boxD, dt, Tn0, Ts0, Fs, Fn, Te_0)

    xdata.append(current_time)
    ydata_m.append(m)
    ydata_Tn0.append(Tn0)
    ydata_Ts0.append(Ts0)
    ydata_Fn.append(Fn)
    ydata_Fs.append(Fs)

    # Hosing strength series (0 while hosing is off); stored on the function
    # so restart_simulation() can clear it.
    if not hasattr(animate, 'ydata_hosing'):
        animate.ydata_hosing = []
    animate.ydata_hosing.append(hosing_rate if hosing_active else 0)

    xdata = xdata[-max_points:]
    ydata_m = ydata_m[-max_points:]
    ydata_Tn0 = ydata_Tn0[-max_points:]
    ydata_Ts0 = ydata_Ts0[-max_points:]
    ydata_Fn = ydata_Fn[-max_points:]
    ydata_Fs = ydata_Fs[-max_points:]
    animate.ydata_hosing = animate.ydata_hosing[-max_points:]

    line1.set_data(xdata, ydata_m)
    line3_Tn0.set_data(xdata, ydata_Tn0)
    line4_Ts0.set_data(xdata, ydata_Ts0)
    line5_Fn.set_data(xdata, ydata_Fn)
    line6_Fs.set_data(xdata, ydata_Fs)
    line7_hosing.set_data(xdata, animate.ydata_hosing)

    current_time += dt

    ax1.relim()
    ax1.autoscale_view()
    ax2.relim()
    ax2.autoscale_view()

    ax1.set_title(_box_summary_title(boxE, boxN, boxS, boxD), fontsize=10)

    return line1, line3_Tn0, line4_Ts0, line5_Fn, line6_Fs, line7_hosing


# Create sliders
# axcolor is also used by restart_simulation() and start_hosing() to restore the hosing button colour
axcolor = 'lightgoldenrodyellow'
# Slider axes rectangles are [left, bottom, width, height] in figure coordinates
ax_Tn0 = plt.axes([0.25, 0.25, 0.65, 0.03], facecolor=axcolor)
ax_Ts0 = plt.axes([0.25, 0.20, 0.65, 0.03], facecolor=axcolor)
ax_Fn = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)
ax_Fs = plt.axes([0.25, 0.10, 0.65, 0.03], facecolor=axcolor)
ax_dt = plt.axes([0.25, 0.05, 0.65, 0.03], facecolor=axcolor)
ax_hosing_strength = plt.axes([0.25, 0.30, 0.65, 0.03], facecolor=axcolor)

s_Tn0 = Slider(ax_Tn0, 'Tn0', 0.0, 20.0, valinit=Tn0)
s_Ts0 = Slider(ax_Ts0, 'Ts0', 0.0, 20.0, valinit=Ts0)
s_Fn = Slider(ax_Fn, 'Fn', -1.5,1.5, valinit=Fn)
s_Fs = Slider(ax_Fs, 'Fs', -1.5, 1.5, valinit=Fs)
s_dt = Slider(ax_dt, 'dt', 0.0001, 3, valinit=dt)  # Positive lower bound, presumably so model time always advances
s_hosing = Slider(ax_hosing_strength, 'Hosing Strength', -1.5, 1.5, valinit=0.0)  # Negative values move salt from S to N instead

# Any forcing/dt slider change re-reads all five via update(); hosing has its own callback
s_Tn0.on_changed(update)
s_Ts0.on_changed(update)
s_Fn.on_changed(update)
s_Fs.on_changed(update)
s_dt.on_changed(update)
s_hosing.on_changed(update_hosing_strength)

# Function to print box values
def print_box_values(event):
    """Button callback: print the salinity, temperature and density of every box to stdout.

    ``event`` is supplied by the Button and is not used.
    """
    labelled_boxes = (("E", boxE), ("N", boxN), ("S", boxS), ("D", boxD))

    print("\nCurrent Box Values:")
    for label, box in labelled_boxes:
        print(f"Box {label} - Salinity: {box[0]:.2f}, Temperature: {box[1]:.2f}")

    print("\nDensities:")
    for label, box in labelled_boxes:
        print(f"Box {label}: {density(box):.2f}")

# Create control buttons, stacked top-down in the left margin
button_height = 0.04
button_width = 0.15
left_margin = 0.03
vertical_spacing = 0.01
y_pos = 0.25 - button_height - vertical_spacing  # Bottom of the first button; lowered after each one


# Toggle Temperature button: flips the Temp flag via toggle_temperature() (the label never changes)
ax_temp = plt.axes([left_margin, y_pos, button_width, button_height], facecolor=axcolor)
temp_button = Button(ax_temp, 'Temp Constant', color=axcolor, hovercolor='0.975')
temp_button.on_clicked(toggle_temperature)
y_pos -= (button_height + vertical_spacing)

# Gradual Hosing button: toggles hosing on/off via start_hosing()
ax_hosing = plt.axes([left_margin, y_pos, button_width, button_height], facecolor=axcolor)
hosing_button = Button(ax_hosing, 'Toggle Hosing', color=axcolor, hovercolor='0.975')
hosing_button.on_clicked(start_hosing)
y_pos -= (button_height + vertical_spacing)


# Print Values button: prints box salinities, temperatures and densities to stdout
ax_print = plt.axes([left_margin, y_pos, button_width, button_height], facecolor=axcolor)
print_button = Button(ax_print, 'Print Values', color=axcolor, hovercolor='0.975')
print_button.on_clicked(print_box_values)
y_pos -= (button_height + vertical_spacing)


# Restart button: resets model state, hosing state and plots via restart_simulation()
ax_restart = plt.axes([left_margin, y_pos, button_width, button_height], facecolor=axcolor)
restart_button = Button(ax_restart, 'Restart', color='lightblue', hovercolor='lightcyan')
restart_button.on_clicked(restart_simulation)



# Run the animation: animate() is called every 50 ms.
# Keep the `ani` reference alive; Matplotlib stops an Animation object that gets garbage-collected.
ani = FuncAnimation(fig, animate, interval=50, blit=False)

# Legends: AMOC strength on the left axis, forcing/hosing lines on the right axis
lines = [line1, line3_Tn0, line4_Ts0, line5_Fn, line6_Fs, line7_hosing]  # Added line7_hosing
labels = [line.get_label() for line in lines]
ax1.legend(lines[:1], labels[:1], loc='upper left')
ax2.legend(lines[1:], labels[1:], loc='upper right')

plt.show()
