import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# -----------------------------
# Domain parameters
# -----------------------------
L = 0.6              # spinal length (m)
R_outer = 0.01       # dura radius (m)
R_inner = 0.004      # spinal cord radius (m)

Nx = 400             # number of grid nodes along the canal (x)
Nr = 120             # number of grid nodes across the radius (r)

# The grids are built with linspace over [0, L] and [0, R_outer], endpoints
# included, so N nodes span N - 1 intervals. The spacing must use N - 1,
# otherwise the finite-difference operators see a slightly wrong step size.
dx = L / (Nx - 1)
dr = R_outer / (Nr - 1)

D = 1e-4             # diffusion coefficient (presumably m^2/s -- confirm)
v0 = 0.015           # amplitude of the oscillating axial velocity (presumably m/s)
freq = 1.2           # oscillation frequency used in sin(2*pi*freq*t) (presumably Hz)

# Simulated time advanced per animation frame, and number of frames.
# NOTE(review): with these values D*dt/dr**2 is about 14, far above the 0.5
# limit of an explicit (forward-Euler) diffusion step, so a single Euler step
# of size dt is unstable; the integrator has to sub-step or dt must shrink.
dt = 0.001
steps = 1500

# -----------------------------
# Grid
# -----------------------------
# Node coordinates along the canal (x) and across the radius (r); the
# "ij" indexing makes both 2-D coordinate arrays shaped (Nx, Nr).
x = torch.linspace(0, L, Nx)
r = torch.linspace(0, R_outer, Nr)
X, R = torch.meshgrid(x, r, indexing="ij")

# -----------------------------
# CSF mask
# -----------------------------
# True only in the annulus between the cord surface and the dura.
csf_mask = torch.logical_and(R >= R_inner, R <= R_outer)

# -----------------------------
# Initial concentration (CLEAR injection)
# -----------------------------
# Start from an empty field and place a unit-concentration bolus six nodes
# wide, centred near 20% of the canal length, restricted to the CSF annulus.
inj_x = int(0.2 * Nx)
C = torch.zeros(Nx, Nr)
C[inj_x - 3:inj_x + 3] = csf_mask[inj_x - 3:inj_x + 3].to(C.dtype)

# -----------------------------
# Operators
# -----------------------------
def laplacian(C, spacing_x=None, spacing_r=None):
    """Return the 5-point finite-difference Laplacian of a 2-D field.

    The second differences are taken along axis 0 (x) and axis 1 (r) and
    summed. Array edges use a zero-flux (mirror) condition: the value just
    outside the grid is taken equal to the edge value. A wrap-around shift
    would instead couple x = 0 with x = L and r = 0 with r = R_outer, letting
    material leave one end of the domain and re-enter at the other.

    NOTE: this is the Cartesian form; no (1/r) dC/dr axisymmetric term is
    included.

    Args:
        C: 2-D tensor indexed as [x, r].
        spacing_x: grid step along axis 0; defaults to the module-level ``dx``.
        spacing_r: grid step along axis 1; defaults to the module-level ``dr``.

    Returns:
        Tensor of the same shape as ``C``.
    """
    hx = dx if spacing_x is None else spacing_x
    hr = dr if spacing_r is None else spacing_r

    # Neighbour fields with the edge row/column replicated (ghost cell equals
    # edge cell), which makes the discrete flux through each wall zero.
    x_prev = torch.cat((C[:1], C[:-1]), dim=0)
    x_next = torch.cat((C[1:], C[-1:]), dim=0)
    r_prev = torch.cat((C[:, :1], C[:, :-1]), dim=1)
    r_next = torch.cat((C[:, 1:], C[:, -1:]), dim=1)

    d2x = (x_prev - 2 * C + x_next) / hx**2
    d2r = (r_prev - 2 * C + r_next) / hr**2
    return d2x + d2r

def grad_x(C, spacing=None):
    """Return the finite-difference derivative of ``C`` along axis 0 (x).

    Interior rows use a central difference. The first and last rows use
    one-sided differences, so the two ends of the domain are never coupled;
    a wrap-around shift would difference x = 0 against x = L.

    Args:
        C: tensor whose first axis is the x direction.
        spacing: grid step along axis 0; defaults to the module-level ``dx``.

    Returns:
        Tensor of the same shape as ``C``.
    """
    h = dx if spacing is None else spacing

    # A single row has no neighbours to difference against.
    if C.shape[0] < 2:
        return torch.zeros_like(C)

    first = (C[1:2] - C[0:1]) / h
    interior = (C[2:] - C[:-2]) / (2 * h)
    last = (C[-1:] - C[-2:-1]) / h
    return torch.cat((first, interior, last), dim=0)

# -----------------------------
# Visualization setup
# -----------------------------
fig, ax = plt.subplots(figsize=(10, 3))

# C is indexed [x, r]; transpose so x runs horizontally and r vertically.
# The colour scale is pinned to [0, 1] so frames stay comparable over time.
img = ax.imshow(
    C.T.numpy(),
    extent=[0, L, 0, R_outer],
    origin="lower",
    aspect="auto",
    vmin=0,
    vmax=1,
)

# Shade the region 0 <= r <= R_inner occupied by the spinal cord.
ax.fill_between(
    x.numpy(),
    0,
    R_inner,
    color="black",
    alpha=0.7,
    label="Spinal Cord",
)

ax.set_xlabel("Spinal length (m)")
ax.set_ylabel("Radius (m)")
ax.set_title("Intrathecal Drug Injection and Transport")
# The label given to fill_between is only displayed once a legend exists.
ax.legend(loc="upper right")
fig.colorbar(img, ax=ax, label="Drug concentration")

# -----------------------------
# Time stepping + animation
# -----------------------------
def update(frame):
    """Advance the concentration field by ``dt`` and refresh the image.

    The field is integrated with explicit (forward-Euler) steps of
    diffusion plus axial advection by a sinusoidally oscillating velocity.
    A single Euler step of size ``dt`` exceeds the explicit stability limit
    for the grid used here, which makes the field blow up to inf/NaN, so the
    frame interval is split into equal sub-steps that respect the limit.

    Args:
        frame: animation frame index; the frame starts at time ``frame * dt``.

    Returns:
        A one-element list holding the updated image artist.
    """
    global C

    # Sufficient stability bounds for forward-Euler advection-diffusion:
    #   dt_sub <= 1 / (2 * D * (1/dx**2 + 1/dr**2))   (diffusion)
    #   dt_sub <= 2 * D / v0**2                       (central-difference advection)
    n_sub = 1
    if D > 0:
        dt_limit = 0.5 / (D * (1.0 / dx**2 + 1.0 / dr**2))
        if v0 != 0:
            dt_limit = min(dt_limit, 2.0 * D / v0**2)
        # 0.9 keeps a small safety margin below the theoretical limit.
        n_sub = max(1, int(np.ceil(dt / (0.9 * dt_limit))))
    dt_sub = dt / n_sub

    t_start = frame * dt
    for k in range(n_sub):
        t = t_start + k * dt_sub
        # Plain float velocity; avoids allocating a tensor on every step.
        v = v0 * float(np.sin(2 * np.pi * freq * t))

        C = C + dt_sub * (
            D * laplacian(C)
            - v * grad_x(C)
        )

        # Keep the drug confined to the CSF annulus and non-negative.
        C *= csf_mask
        C = torch.clamp(C, min=0)

    img.set_data(C.T.numpy())
    return [img]

# Keep a reference to the animation object so it stays alive while the window
# is open.
# - init_func: without one, FuncAnimation draws the first frame an extra time
#   during initialisation, which would advance the simulation state once more
#   than intended. The no-op below leaves the state untouched.
# - repeat=False: the frame index restarts at 0 on repeat while the
#   concentration field keeps evolving, which would make the oscillating
#   velocity jump in phase; run the configured number of frames once instead.
ani = FuncAnimation(
    fig,
    update,
    frames=steps,
    init_func=lambda: [img],
    interval=25,
    repeat=False,
)
plt.show()
