---
execution:
  timeout: 300
jupytext:
  notebook_metadata_filter: all
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.0
kernelspec:
  display_name: Python 3 (phys-581)
  language: python
  name: phys-581
---

```{code-cell} ipython3
:tags: [hide-cell]

import mmf_setup

mmf_setup.nbinit()
import logging

logging.getLogger("matplotlib").setLevel(logging.CRITICAL)
%matplotlib inline
import numpy as np, matplotlib.pyplot as plt
```

(sec:SEQ)=
# Solving The Schrödinger Equation

We assume you have worked through {ref}`sec:ODEs` and are familiar with the approach
taken there to solve the Schrödinger equation.  Here we focus on developing a good set
of software for solving this problem more generally.

Below we present a comprehensive worked example of developing a framework for solving the
Schrödinger equation.  It follows many of the structures and methodologies developed in
our [GPE project](https://gpe.readthedocs.io/en/latest/), which solves the
Gross-Pitaevskii Equation (GPE), or non-linear Schrödinger Equation (NLSEQ), for
superfluid dynamics.

## Test Problem

Following recommendation [PP41: Test to Code][], we first would like to identify a test
problem.  We choose the 1D quantum harmonic oscillator since we know it has exact
solutions:
\begin{gather*}
  \I\hbar \dot{\psi}(x, t) = \frac{-\hbar^2\psi''(x, t)}{2m} + \frac{m\omega^2}{2}x^2\psi(x, t).
\end{gather*}
This problem is characterized by the following scales:
* $T = \frac{2\pi}{\omega}$: The trap period, which sets the time scale.
* $a_0 = \sqrt{\hbar/m\omega}$: The trap length, which sets the length scale.
* $m$: The mass of the particle, which sets the mass scale.

In the code below we use units where $\hbar = m = T = 1$, so that $\omega = 2\pi$ and
$a_0 = 1/\sqrt{2\pi} \approx 0.4$.

Some facts that are easily checked by substitution into the Schrödinger equation are
that the ground state is
\begin{gather*}
  \psi_0(x, t) \propto e^{E_0 t/\I\hbar} e^{-x^2/2a_0^2}, \qquad
  E_0 = \tfrac{1}{2}\hbar \omega.
\end{gather*}
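Another useful fact is that all motion is periodic with period $T$, up to a global
phase.  The energy levels are $E_n = \hbar\omega(n+\tfrac{1}{2})$, so every eigenstate
picks up the same phase $e^{-\I E_n T/\hbar} = e^{-2\pi\I(n + 1/2)} = -1$ after one
period.  Thus, for any state $\psi(x, t)$, we have $\psi(x, t+T) = -\psi(x, t)$, and the
density $|\psi(x, t)|^2$ is exactly periodic.  This provides us with an easy test case:
we will evolve some initial state (say, the shifted ground state) for one period, and
see how close $|\psi|^2$ gets to where it started.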
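The periodicity claim can be checked directly from the spectrum.  A minimal sketch,
assuming the units $\hbar = m = T = 1$ used in the code below, computes the phase
$e^{-\I E_n T/\hbar}$ acquired by each eigenstate $E_n = \hbar\omega(n+\tfrac{1}{2})$ in one
period.  Every level picks up the same phase, $-1$, so a general superposition returns to
itself up to a global sign, which is why the test below compares $|\psi|^2$ rather than
$\psi$.

```python
import numpy as np

# Assumed units, matching the SEQ class below: hbar = m = T = 1, so omega = 2*pi.
hbar = m = T = 1.0
w = 2 * np.pi / T

n = np.arange(20)
E_n = hbar * w * (n + 0.5)              # Exact oscillator levels.
phases = np.exp(-1j * E_n * T / hbar)   # Phase of each eigenstate after one period.

# All levels pick up the same phase -1, so any superposition returns to
# itself up to a global sign and |psi|^2 is exactly periodic.
assert np.allclose(phases, -1)
```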

As a byproduct, by trying to write this test, we might gain some insight into how we
would like to design our Schrödinger equation solver.
:::{doit} Write a test function.

Write a function to test your desired code using the above invariant.  How would you
like to interact with your code?  What functions do you need?
:::

```{code-cell}
from scipy.integrate import solve_ivp


def test_ho(seq, x0=2.0):
    """Check that |psi|^2 returns to itself after one trap period."""
    a0 = np.sqrt(seq.hbar / seq.m / seq.w)
    T = 2*np.pi / seq.w
    x = seq.x
    psi0 = np.exp(-((x-x0)/a0)**2/2)  # Ground state shifted to x0.
    psi1 = seq.evolve(psi0, t=T)
    # Compare densities: psi itself only returns up to a global phase.
    assert np.allclose(abs(psi0)**2, abs(psi1)**2, rtol=1e-4, atol=1e-4)


class SEQ:
    """Class to help solve the Schrödinger equation for a harmonic trap.

    Attributes
    ----------
    hbar, m, T, w : float
        Physical constants for the system: reduced Planck constant, mass,
        trap period, and trap frequency.
    N : int
        Number of points in the lattice.
    L : float
        Box size.
    """
    hbar = 1.0
    T = 1.0
    m = 1.0
    w = 2*np.pi / T

    def __init__(self, N=64, L=10.0):
        self.N = N
        self.L = L
        self.dx = self.L/self.N
        self.x = np.arange(self.N) * self.dx - self.L/2.0

    def evolve(self, psi, t, **kw):
        """Return `psi(t)` evolving `psi` for time `t`."""
        return psi  # Wrong - but quickly get our tests to pass.


test_ho(SEQ())
```

Now let's do a non-trivial implementation based on {ref}`sec:ODEs`:


```{code-cell}
:tags: [raises-exception]

class SEQ1(SEQ):
    def __init__(self, **kw):
        super().__init__(**kw)
        self.init()

    def init(self):
        """Build the Hamiltonian matrix H = K + V."""
        D2 = self.get_D2()
        K = -self.hbar**2 / 2 / self.m * D2
        Vx = self.m * (self.w * self.x)**2 / 2  # V(x) = m w^2 x^2 / 2
        V = np.diag(Vx)
        self.H = K + V

    def get_D2(self):
        """Return a matrix approximation for the laplacian (3-point stencil)."""
        ones = np.ones(self.N)
        D2 = (
            np.diag(ones[1:], 1)
            + np.diag(ones[1:], -1)
            - 2*np.diag(ones)
            ) / self.dx**2
        return D2

    def compute_dy_dt(self, t, psi):
        """Return dpsi_dt = H psi / (i hbar)."""
        Hpsi = self.H @ psi
        dpsi_dt = Hpsi / (1j * self.hbar)
        return dpsi_dt

    def evolve(self, psi, t, method="DOP853", **kw):
        """Return `psi(t)` evolving `psi` for time `t`."""
        res = solve_ivp(self.compute_dy_dt, y0=psi+0j, t_span=(0, t),
                        method=method, **kw)
        assert res.success
        psi1 = res.y[:, -1]
        self._psi0, self._psi1 = psi, psi1  # Keep for debugging.
        return psi1


test_ho(SEQ1(N=128))
```

To better understand why the test is not passing, let's compute the error as a function
of the number of lattice points $N$:

```{code-cell}
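def get_err(N=256, L=10.0, x0=2.0, SEQ=SEQ1, **kw):
    """Return the maximum error in |psi|^2 after one period."""
    s = SEQ(N=N, L=L)
    a = np.sqrt(s.hbar / s.m / s.w)
    psi0 = np.exp(-((s.x-x0)/a)**2/2)
    psi1 = s.evolve(psi0, t=s.T, **kw)
    return abs(abs(psi0)**2 - abs(psi1)**2).max()


Ns = 2**np.arange(2, 9)
errs = [get_err(N=N, SEQ=SEQ1) for N in Ns]

fig, ax = plt.subplots()
ax.loglog(Ns, errs, '+-')
ax.set(xlabel="$N$", ylabel=r"max error in $|\psi|^2$");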
```
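The errors above mix two effects: time-stepping errors from `solve_ivp` and spatial
discretization errors in `H`.  A minimal sketch (a standalone re-implementation of the
`SEQ1` Hamiltonian, assuming $\hbar = m = T = 1$) isolates the spatial part by
diagonalizing `H` directly with `numpy.linalg.eigvalsh` and comparing with the exact
levels $\hbar\omega(n+\tfrac{1}{2})$.

```python
import numpy as np

# Assumed units, matching the SEQ class: hbar = m = T = 1, so omega = 2*pi.
hbar = m = T = 1.0
w = 2 * np.pi / T
N, L = 128, 10.0
dx = L / N
x = np.arange(N) * dx - L / 2

# Same 3-point Laplacian and harmonic potential as SEQ1.
ones = np.ones(N)
D2 = (np.diag(ones[1:], 1) + np.diag(ones[1:], -1) - 2 * np.diag(ones)) / dx**2
H = -hbar**2 / 2 / m * D2 + np.diag(m * (w * x)**2 / 2)

E = np.linalg.eigvalsh(H)                  # Eigenvalues in ascending order.
E_exact = hbar * w * (np.arange(N) + 0.5)  # Exact oscillator levels.
rel_err = abs(E - E_exact) / E_exact

for n in [0, 5, 10, 20]:
    print(f"n={n:2d}: E={E[n]:.4f}, exact={E_exact[n]:.4f}, rel. err={rel_err[n]:.1e}")
```

One should find that the low levels are accurate to a fraction of a percent, while the
relative error grows with $n$.  Since the shifted initial state populates many levels,
these growing errors spoil the exact periodicity.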

Perhaps a higher-order stencil will help.  Here we try the 5-point stencil from
{ref}`sec:Derivatives`:
```{code-cell}
class SEQ2(SEQ1):
    def get_D2(self):
        """Return a matrix approximation for the laplacian (5-point stencil)."""
        stencil = np.array([-1, 16, -30, 16, -1])/12
        offsets = [-2, -1, 0, 1, 2]
        ones = np.ones(self.N)
        # Place each stencil coefficient on the corresponding diagonal.
        D2 = np.sum([c * np.diag(ones[abs(k):], k)
                     for (c, k) in zip(stencil, offsets)], axis=0)
        return D2 / self.dx**2


test_ho(SEQ2(N=256, L=10.0), x0=1.0)
```
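This works, but we should think a little about why.  The IR errors come from the finite
box (the finite-difference matrix implicitly sets $\psi = 0$ outside it), and are largest
when the packet, which oscillates between $\pm x_0$, is closest to a wall.  We can thus
roughly estimate the error by the density at the wall: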
\begin{gather*}
   \epsilon \sim e^{-(L/2 - x_0)^2/a_0^2}.
\end{gather*}
For $L=10$ and $x_0=2.0$ (with $a_0^2 = 1/2\pi$ in our units), this is below
$3\times 10^{-25}$, so we should be fine, and we can probably reduce the box size to
$L=8$ without any worry.
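The IR estimate is easy to tabulate.  A minimal sketch, assuming the code's units
$\hbar = m = T = 1$ (so $a_0^2 = 1/2\pi$) and using a hypothetical helper `ir_error`,
evaluates $e^{-(L/2 - x_0)^2/a_0^2}$ for a few box sizes: $L=10$ and $L=8$ are far below
the test tolerance of $10^{-4}$, while by this estimate $L=6$ would not be.

```python
import numpy as np

# Assumed units, matching the SEQ class: hbar = m = T = 1, so omega = 2*pi.
hbar = m = T = 1.0
w = 2 * np.pi / T
a0 = np.sqrt(hbar / m / w)   # Trap length, 1/sqrt(2*pi) in these units.

def ir_error(L, x0):
    """Rough IR error: density |psi|^2 at the wall x = L/2 of a packet centered at x0."""
    return np.exp(-((L / 2 - x0) / a0)**2)

for L in [10.0, 8.0, 6.0]:
    print(f"L = {L}: eps ~ {ir_error(L, x0=2.0):.1e}")
```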

:::{margin}
Note that this is exactly the same as the error obtained by comparing the modification
of the kinetic energy in {ref}`sec:ODEs`:
\begin{gather*}
  K(p) = \frac{p^2}{2m} - \frac{p^4 (\d{x})^2}{24m\hbar^2} + \cdots,\\
  \epsilon_{\text{rel}} \sim \frac{\frac{p^4 (\d{x})^2}{24m\hbar^2}}{\frac{p^2}{2m}}
  = \frac{p^2 L^2 }{12\hbar^2 N^2}.
\end{gather*}
:::
To estimate the UV errors, we note that the truncation error in our approximation of the
derivative is
\begin{gather*}
  f''(x) = \frac{f(x-h) + f(x+h) - 2f(x)}{h^2} - \frac{h^2}{12}f^{(4)}(x^*),
\end{gather*}
where $h = \d{x} = L/N$ is the lattice spacing and $x^*$ is some point in $[x-h, x+h]$.
To estimate $f^{(4)}$ we note that a particle moving with definite momentum $p$ will
have wavefunction $\psi(x) \propto e^{\I p x/\hbar}$, so $|\psi^{(n)}(x)| =
(p/\hbar)^n|\psi(x)|$.  The relative error is the ratio of the two terms, thus
\begin{gather*}
  \epsilon_{\text{rel}} \sim \frac{h^2}{12}\frac{p^2}{\hbar^2}
  = \frac{p^2L^2}{12\hbar^2N^2}.
\end{gather*}
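These truncation-error estimates can be checked with a plane-wave (Fourier) analysis of
the stencil: applying a centered stencil $\sum_j c_j f(x + jh)/h^2$ to $e^{\I k x}$ simply
multiplies it by $\sum_j c_j e^{\I j k h}/h^2$, which should be compared with the exact
factor $-k^2$.  The sketch below, using a hypothetical helper `stencil_rel_err`, computes
this relative error for the 3-point stencil of `SEQ1` and the 5-point stencil of `SEQ2`;
for small $kh$ one expects $(kh)^2/12$ and $(kh)^4/90$ respectively.

```python
import numpy as np

def stencil_rel_err(k, h, stencil):
    """Relative error of a centered second-derivative stencil acting on exp(i*k*x).

    `stencil` holds the coefficients c_j for offsets j = -M, ..., M.
    """
    stencil = np.asarray(stencil, dtype=float)
    offsets = np.arange(len(stencil)) - len(stencil) // 2
    # Factor multiplying exp(i*k*x) after applying the stencil.
    symbol = np.sum(stencil * np.exp(1j * k * offsets * h)) / h**2
    return abs(symbol.real + k**2) / k**2

stencil3 = [1, -2, 1]                              # As in SEQ1.get_D2.
stencil5 = np.array([-1, 16, -30, 16, -1]) / 12    # As in SEQ2.get_D2.

for kh in [0.4, 0.2, 0.1]:
    err3 = stencil_rel_err(1.0, kh, stencil3)
    err5 = stencil_rel_err(1.0, kh, stencil5)
    print(f"kh={kh}: 3-point {err3:.2e} vs {kh**2/12:.2e}, "
          f"5-point {err5:.2e} vs {kh**4/90:.2e}")
```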
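We can estimate the momentum from conservation of energy: a particle released from rest
at $x_0$ reaches the trap center with kinetic energy $p^2/2m = m\omega^2x_0^2/2$.  With
$x_0=2.0$ and $L=10.0$, we have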
\begin{gather*}
  p \sim m\omega x_0 \approx 13, \qquad
  \epsilon \approx \frac{1300}{N^2}.
\end{gather*}
Thus, to get $\epsilon \sim 10^{-4}$, we would need $N \gtrsim 3600$ points.  Repeating
this analysis for the 5-point stencil, whose truncation error is of order
$\frac{h^4}{90}f^{(6)}$, gives
\begin{gather*}
  \epsilon_{\text{rel}} \sim \frac{h^4}{90}\frac{p^4}{\hbar^4} 
  = \frac{p^4L^4}{90\hbar^4 N^4},
\end{gather*}
hence, reducing to $x_0 = 1.0$ (so that $p \approx 6.3$) and using $N=256$ points gives
$\epsilon \sim 4\times 10^{-5}$, and our test passes.
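These numbers are easy to reproduce.  A minimal sketch with a hypothetical helper
`uv_error`, assuming the code's units $\hbar = m = T = 1$ and $p \sim m\omega x_0$, gives
$N \approx 3600$ for the 3-point stencil, and shows that with the 5-point stencil and
$N=256$ the estimate exceeds $10^{-4}$ for $x_0 = 2$ but is about $4\times 10^{-5}$ for
$x_0 = 1$, consistent with the choice $x_0 = 1.0$ in the test above.

```python
import numpy as np

# Assumed units, matching the SEQ class: hbar = m = T = 1, so omega = 2*pi.
hbar = m = T = 1.0
w = 2 * np.pi / T

def uv_error(N, L, x0, points=3):
    """Rough relative UV error of the 3- or 5-point Laplacian.

    Uses the maximum momentum p ~ m*w*x0 of a packet released from rest at x0.
    """
    kh = m * w * x0 * (L / N) / hbar     # (p/hbar) * dx
    if points == 3:
        return kh**2 / 12
    if points == 5:
        return kh**4 / 90
    raise ValueError("points must be 3 or 5")

# Points the 3-point stencil would need for eps ~ 1e-4 with x0=2, L=10.
N_needed = m * w * 2.0 * 10.0 / hbar / np.sqrt(12 * 1e-4)
print(f"3-point: N ~ {N_needed:.0f}")
print(f"5-point, N=256, x0=2: eps ~ {uv_error(256, 10.0, 2.0, points=5):.1e}")
print(f"5-point, N=256, x0=1: eps ~ {uv_error(256, 10.0, 1.0, points=5):.1e}")
```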


## Spectral Methods

:::{margin}
One must be a little careful to ensure that the conventions match those implemented by
the numerical FFT.  The conventions stated here match {func}`numpy.fft.fft` and
{func}`numpy.fft.ifft` if we include the factor of $1/N$ in the inverse transform and the
phase shift by $L/2$.
:::
We can get a much better approximation for the derivative by using the Fourier transform
(see {ref}`sec:FourierTechniques`).  The idea is to write $\psi(x)$ as a sum of
plane waves $e^{\I k x}$ so that, to compute the second derivative, we simply multiply
by $-k_m^2$ in momentum space:
\begin{gather*}
  D_2\psi = \mathcal{F}^{-1}\bigl(-k^2\mathcal{F}(\psi)\bigr),\\
  \psi_{m} = \mathcal{F}\bigl(\psi(x_n)\bigr)
            = \sum_{n=0}^{N-1} e^{-\I k_m (x_{n}+L/2)}\psi(x_n),\\
  \psi(x_n) = \mathcal{F}^{-1}(\psi_{m}) 
            = \frac{1}{N}\sum_{m=0}^{N-1}\psi_m e^{\I k_m (x_{n}+L/2)},\\
  x_{n} = n\d{x} - L/2, \qquad \d{x} = \frac{L}{N}.
\end{gather*}
The only thing remaining is to get the appropriate $k_m$ in the correct order, which we
do by calling {func}`numpy.fft.fftfreq` and multiplying by $2\pi$ (it returns frequencies
in cycles per unit length).
```{code-cell}
N = 64
L = 18.0
dx = L/N
n = np.arange(N)
x = n * dx - L/2
k = 2*np.pi * np.fft.fftfreq(N, dx)

def diff(f, d=2):
    """Return the dth derivative of f."""
    return np.fft.ifft((1j*k)**d * np.fft.fft(f))
```
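To check the conventions stated in the margin, a small sketch (random test data; the
names `phase` and `no_shift` are ours) builds the transforms directly from the sums
above, including the phase $x_n + L/2 = n\,\d{x}$, and compares them with
`numpy.fft.fft` and `numpy.fft.ifft`.  Dropping the $L/2$ shift would multiply each
$\psi_m$ by $e^{\I k_m L/2} = \pm 1$, an alternating sign.

```python
import numpy as np

N, L = 16, 18.0
dx = L / N
x_n = np.arange(N) * dx - L / 2
k_m = 2 * np.pi * np.fft.fftfreq(N, dx)   # Same ordering as np.fft.fft output.

rng = np.random.default_rng(0)
psi = rng.normal(size=N) + 1j * rng.normal(size=N)

# phase[m, n] = exp(-i k_m (x_n + L/2)), as in the forward transform.
phase = np.exp(-1j * np.outer(k_m, x_n + L / 2))
psi_m = phase @ psi                        # Forward transform.
psi_back = phase.conj().T @ psi_m / N      # Inverse transform with the 1/N factor.

assert np.allclose(psi_m, np.fft.fft(psi))
assert np.allclose(psi_back, np.fft.ifft(psi_m))
assert np.allclose(psi_back, psi)

# Without the L/2 shift, each coefficient picks up exp(i k_m L/2) = +/-1.
no_shift = np.exp(-1j * np.outer(k_m, x_n)) @ psi
assert np.allclose(no_shift, np.exp(1j * k_m * L / 2) * psi_m)
```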
```{code-cell}
:tags: [margin, hide-input]
# Test the Gaussian
f = np.exp(-x**2/2)
ddf_exact = (x**2 - 1)*f
ddf = diff(f, 2).real
fig, axs = plt.subplots(3, 1, sharex=True)
ax = axs[0]
ax.semilogy(x, f)
ax.set(ylabel="$f(x)$")
ax = axs[1]
ax.plot(x, ddf, '.', label="FFT")
ax.plot(x, ddf_exact, '-', label="Exact")
ax.set(ylabel="$f''(x)$")
ax.legend();
ax = axs[2]
ax.plot(x, ddf-ddf_exact)
ax.set(ylabel="Error");
```
The margin figure shows that we can reach machine precision with $N=64$ points in a
sufficiently large box, $L=18$.  This is close to optimal for this problem, as we can see
by computing the analytic Fourier transform of our Gaussian:
\begin{gather*}
  f(x) = e^{-x^2/2}, \qquad
  \mathcal{F}(f) = \int e^{-\I k x}e^{-x^2/2}\d{x} =  \sqrt{2\pi} e^{-k^2/2}.
\end{gather*}
To achieve precision $\epsilon = 2^{-52} \approx 2\times 10^{-16}$, both functions must
drop by this amount: $f(x)$ at the edge of the box $x=\pm L/2$, and $\mathcal{F}(f)$ at
the largest wavenumbers $k = \pm \pi N/L$:
\begin{gather*}
  \epsilon \approx e^{-L^2/8} \approx e^{-\pi^2N^2/2L^2}\\
  L \gtrapprox \sqrt{-8\ln \epsilon} \approx 17, \qquad
  N \gtrapprox \frac{L}{\pi}\sqrt{-2\ln \epsilon}
  \approx \frac{-4}{\pi}\ln \epsilon \approx 46.
\end{gather*}
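The arithmetic can be reproduced in a few lines, taking $\epsilon$ to be the
double-precision machine epsilon reported by `numpy.finfo`:

```python
import numpy as np

eps = np.finfo(float).eps                          # 2**-52, about 2.2e-16
L_min = np.sqrt(-8 * np.log(eps))                  # From exp(-L**2/8) ~ eps.
N_min = L_min / np.pi * np.sqrt(-2 * np.log(eps))  # From exp(-(pi*N/L)**2/2) ~ eps.
print(f"L >~ {L_min:.1f}, N >~ {N_min:.1f}")
```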
:::{note}
These estimates are a little low for $f''(x)$, for which we should use
\begin{gather*}
  f''(x) = (x^2-1)e^{-x^2/2}, \qquad
  \mathcal{F}(f'') = -k^2\sqrt{2\pi} e^{-k^2/2}.
\end{gather*}
Evaluating these at the edges gives $L \gtrapprox 18$ and $k_{\max} = \pi N/L \gtrapprox
9$, hence $N \gtrapprox 52$.  Playing a bit with the code, however, we find that we
really need $N \gtrapprox 64$ to achieve maximum precision.  Do you see why?
:::
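For experimenting with $N$, here is a self-contained sketch of the same FFT derivative
used in the margin figure (the helper `max_spectral_error` is ours, not part of the code
above).  It prints the maximum error at fixed $L=18$ for several $N$:

```python
import numpy as np

def max_spectral_error(N, L=18.0):
    """Max error of the FFT second derivative of exp(-x**2/2) with N points in a box L."""
    dx = L / N
    x = np.arange(N) * dx - L / 2
    k = 2 * np.pi * np.fft.fftfreq(N, dx)
    f = np.exp(-x**2 / 2)
    ddf_exact = (x**2 - 1) * f
    ddf = np.fft.ifft(-k**2 * np.fft.fft(f)).real
    return abs(ddf - ddf_exact).max()

for N in [32, 46, 52, 58, 64, 128]:
    print(f"N = {N:3d}: max error = {max_spectral_error(N):.1e}")
```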






[PP19: Version Control]: <https://research.ebsco.com/plink/27d11cab-6c6d-30a9-96a5-1c641b0fb342>
[PP20: Debugging]: <https://research.ebsco.com/plink/c77495b4-50ac-39c9-9b8e-7a869c79b401a>
[PP23: Design by Contract]: <https://research.ebsco.com/plink/edd2ed7b-3a73-3cc7-b98d-4aa1b8fdea4d>
[PP25: Assertive Programming]: <https://research.ebsco.com/plink/528d9b1d-b898-3f25-8a12-40d53542a609>
[PP28: Decoupling]: <https://research.ebsco.com/plink/a7c94c21-2be0-3932-a392-ee63296f8afc>
[PP31: Inheritance Tax]: <https://research.ebsco.com/plink/4cdd8d35-6e25-3851-ad1d-e04e3bc8a7c8>
[PP41: Test to Code]: <https://research.ebsco.com/plink/6fb62a24-cc5f-31c7-9888-023f450bc87f>
