import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# -------------------------
# Parameters
# -------------------------
L = 0.5            # spinal length (m)
Nx = 300           # spatial grid nodes, both ends included
# The grid is built with torch.linspace(0, L, Nx): Nx nodes span Nx - 1
# intervals, so the node spacing (x[1] - x[0]) is L / (Nx - 1), not L / Nx.
dx = L / (Nx - 1)  # grid spacing (m)

T = 10.0           # total time (s)
dt = 0.002         # time step (s)
# round(), not int(): T / dt is a float quotient, and int() would truncate a
# result such as 4999.999... down to one step too few.
Nt = round(T / dt)

D = 1e-4           # diffusion coefficient (m^2/s)
v0 = 0.01          # peak CSF velocity (m/s)
freq = 1.2         # cardiac frequency (Hz)

# The solver is explicit forward Euler with central differences (FTCS). For
# advection-diffusion it is stable only when d = D*dt/dx^2 <= 1/2 and the
# Courant number c = v0*dt/dx satisfies c^2 <= 2*d. Fail fast if a parameter
# change violates either bound instead of animating a numerical blow-up.
_diffusion_number = D * dt / dx**2
_courant_number = v0 * dt / dx
if _diffusion_number > 0.5 or _courant_number**2 > 2 * _diffusion_number:
    raise ValueError(
        f"unstable explicit scheme: D*dt/dx^2 = {_diffusion_number:.3g} "
        f"(must be <= 0.5), (v0*dt/dx)^2 = {_courant_number**2:.3g} "
        f"(must be <= {2 * _diffusion_number:.3g})"
    )

device = "cpu"     # torch device for the simulation tensors

# -------------------------
# Spatial grid
# -------------------------
# Node positions used for plotting: Nx evenly spaced points from 0 to L,
# with both end points included.
x = torch.linspace(start=0, end=L, steps=Nx, device=device)

# -------------------------
# Initial condition: bolus injection
# -------------------------
# The injection site sits 15% of the way along the grid. The bolus slice
# covers the six nodes inj_center-3 .. inj_center+2, so it extends one node
# further below the centre than above it.
inj_center = int(0.15 * Nx)
C = torch.zeros(Nx, device=device)
C[inj_center - 3 : inj_center + 3].fill_(1.0)  # in-place on a view of C

# -------------------------
# Finite difference operators
# -------------------------
def laplacian(C, *, spacing=None, periodic=True):
    """Return the second-order central-difference approximation of d2C/dx2.

    Args:
        C: Concentration values on a uniform grid (the script passes a 1-D
            tensor; non-periodic mode slices along dim 0).
        spacing: Grid spacing. When None (the default), the module-level
            ``dx`` is read at call time, so ``laplacian(C)`` behaves exactly
            as before.
        periodic: If True (the default, original behaviour), the ends wrap
            around via ``torch.roll``, so the last node neighbours the first.
            If False, each end value is replicated into a ghost node
            (zero-flux boundary). The stencil then sums to zero, so it neither
            creates nor removes total ``C.sum()``.

    Returns:
        Tensor with the same shape as ``C``.
    """
    h = dx if spacing is None else spacing
    if periodic:
        left = torch.roll(C, 1)
        right = torch.roll(C, -1)
    else:
        # Ghost nodes copy the end values: C[-1] := C[0], C[N] := C[N-1].
        left = torch.cat((C[:1], C[:-1]))
        right = torch.cat((C[1:], C[-1:]))
    return (left - 2 * C + right) / h**2

def gradient(C, *, spacing=None, periodic=True):
    """Return the second-order central-difference approximation of dC/dx.

    Args:
        C: Concentration values on a uniform grid (the script passes a 1-D
            tensor; non-periodic mode slices along dim 0).
        spacing: Grid spacing. When None (the default), the module-level
            ``dx`` is read at call time, so ``gradient(C)`` behaves exactly
            as before.
        periodic: If True (the default, original behaviour), the ends wrap
            around via ``torch.roll``. If False, each end value is replicated
            into a ghost node, giving (C[1] - C[0]) / (2h) at the first node
            and (C[-1] - C[-2]) / (2h) at the last.

    Returns:
        Tensor with the same shape as ``C``.
    """
    h = dx if spacing is None else spacing
    if periodic:
        left = torch.roll(C, 1)
        right = torch.roll(C, -1)
    else:
        # Same edge-replicated ghost nodes as laplacian(..., periodic=False).
        left = torch.cat((C[:1], C[:-1]))
        right = torch.cat((C[1:], C[-1:]))
    return (right - left) / (2 * h)

# -------------------------
# Visualization setup
# -------------------------
fig, ax = plt.subplots()
# Matplotlib needs NumPy arrays. Tensor.numpy() only works on CPU tensors,
# so move to the CPU first; this keeps plotting working for any `device`.
line, = ax.plot(x.cpu().numpy(), C.cpu().numpy())
ax.set_ylim(0, 1.2)  # bolus starts at 1.0; leave headroom above it
ax.set_xlabel("Spinal position (m)")
ax.set_ylabel("Drug concentration")
ax.set_title("Intrathecal Drug Transport (PyTorch)")

# -------------------------
# Time stepping
# -------------------------
def update(frame):
    """Advance the concentration by one explicit time step and redraw it.

    FuncAnimation callback. Each call applies one forward-Euler step of
    dC/dt = D * d2C/dx2 - v(t) * dC/dx to the module-level ``C``. The
    oscillating velocity v(t) = v0 * sin(2*pi*freq*t) is evaluated at
    t = frame * dt. The new profile is then pushed to ``line``.

    Args:
        frame: Animation frame index; sets the simulated time t = frame * dt.

    Returns:
        A tuple containing the updated line artist.
    """
    global C
    t = frame * dt
    # Plain float64 scalar: no per-frame tensor allocation, and the phase
    # 2*pi*freq*t keeps full double precision instead of float32.
    v = v0 * float(np.sin(2 * np.pi * freq * t))

    C = C + dt * (D * laplacian(C) - v * gradient(C))

    # Concentration cannot be negative; clip any undershoot from the scheme.
    C = torch.clamp(C, min=0)
    # .cpu() first: Tensor.numpy() only works on CPU tensors.
    line.set_ydata(C.cpu().numpy())
    return line,

# update() mutates the global C, so the animation must call it exactly once
# per frame:
# - init_func just returns the already-drawn line. Without it, FuncAnimation
#   calls update(frames[0]) for its initial draw, which advances the
#   simulation an extra step before the first real frame.
# - repeat=False stops after Nt frames. Replaying the frame indices would
#   restart t at 0 while C keeps evolving past the simulated period T.
ani = FuncAnimation(
    fig,
    update,
    frames=Nt,
    init_func=lambda: (line,),
    interval=30,
    repeat=False,
)
plt.show()
