---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.6
kernelspec:
  display_name: Python 3 (phys-581)
  language: python
  name: phys-581
---

```{code-cell}
:tags: [hide-cell]

import mmf_setup;mmf_setup.nbinit()
import logging;logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
%matplotlib inline
import numpy as np, matplotlib.pyplot as plt
```

(sec:RK4)=
# Runge-Kutta Truncation Error

Consider the Runge-Kutta RK4 formulation given in {cite:p}`Gezerlis:2023` for evolving
$y'(t) = f(t, y)$:
\begin{align*}
  k_0 &= h f(t, y)\\
  k_1 &= h f\Bigl(t+\tfrac{1}{2}h, y+\tfrac{1}{2}k_0\Bigr)\\
  k_2 &= h f\Bigl(t+\tfrac{1}{2}h, y+\tfrac{1}{2}k_1\Bigr) \tag{8.64}\\
  k_3 &= h f\Bigl(t+h, y+k_2\Bigr)\\
  y(t+h) &\approx y + \frac{k_0 + 2k_1 + 2k_2 + k_3}{6}.
\end{align*}
Tracing through, one can expand the final step $y_{j+1} \approx y(t+h)$ in powers of
$h$ to obtain an expression in terms of $f(t, y)$ and its partial derivatives, written
here as $f_{,t} = \partial f/\partial t$, etc.  This expression can be simplified by noting:
\begin{align*}
  y'(t) &= f,\\
  y''(t) &= f_{,t} + f_{,y}y',\\
  y'''(t) &= f_{,tt} + 2f_{,ty}y' + f_{,yy}(y')^2 + f_{,y}y'',\\
  y''''(t) &= f_{,ttt} + 3f_{,tty}y' + 3f_{,tyy}(y')^2 + f_{,yyy}(y')^3
           + 3f_{,ty}y'' + 3f_{,yy}y'y'' + f_{,y}y'''.
\end{align*}
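
For reference, the target of this comparison is the Taylor expansion of the exact
solution (a standard result, stated here as a reminder rather than taken from the text above):
\begin{align*}
  y(t+h) = y + hy' + \frac{h^2}{2}y'' + \frac{h^3}{6}y''' + \frac{h^4}{24}y'''' + O(h^5).
\end{align*}
A fourth-order method must reproduce each of these terms once the derivatives are
written in terms of $f$.  Whatever remains at order $h^5$ is the local truncation error.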

*(Incomplete.  Probably Mathematica or Maple can do a better job.)*
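
Before attempting the symbolic expansion, the expected $h^5$ scaling of the one-step
error can be checked numerically.  The following is a minimal sketch, not part of the
referenced text: the helper names `rk4_step` and `one_step_error` are hypothetical, and
the test problem $y' = y$, $y(0) = 1$ (exact solution $e^{t}$) is an assumption chosen
for simplicity.

```python
import math


def rk4_step(f, t, y, h):
    """Advance y by one step of size h using the RK4 stages of Eq. (8.64)."""
    k0 = h * f(t, y)
    k1 = h * f(t + h / 2, y + k0 / 2)
    k2 = h * f(t + h / 2, y + k1 / 2)
    k3 = h * f(t + h, y + k2)
    return y + (k0 + 2 * k1 + 2 * k2 + k3) / 6


def one_step_error(h):
    """Error of a single RK4 step for the assumed test problem y' = y, y(0) = 1."""
    return abs(math.exp(h) - rk4_step(lambda t, y: y, 0.0, 1.0, h))


# Halving h should reduce the one-step error by roughly 2**5 = 32.
hs = [0.1, 0.05]
errs = [one_step_error(h) for h in hs]
order = math.log2(errs[0] / errs[1])
print(f"errors: {errs}, observed order: {order:.2f}")
```

For this particular $f$, the RK4 update reproduces the Taylor series of $e^{h}$ through
$h^4$, so the observed order should come out close to 5.  Other choices of $f$ will have
a different error constant but should show the same scaling.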

In principle, the following code should be able to give the correct error formula, but
SymPy has problems simplifying the various expressions.
```{code-cell}
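from sympy import *

t, h = var('t, h', real=True)
f = Function('f', real=True)
y = Function('y', real=True)

# Derivatives of y(t), expressed through f by the chain rule.
dy = f(t, y(t))
d2y = dy.diff(t)
d3y = d2y.diff(t)
d4y = d3y.diff(t)

# RK4 stages and update, as in Eq. (8.64).
k0 = h*f(t, y(t))
k1 = h*f(t+h/2,  y(t)+k0/2)
k2 = h*f(t+h/2,  y(t)+k1/2)
k3 = h*f(t+h,  y(t)+k2)
y1 = y(t) + (k0 + 2*k1 + 2*k2 + k3)/6

# Expand the one-step error in h, then eliminate derivatives of y in favour
# of f, highest order first (each replacement introduces lower-order ones).
err = ((y(t+h) - y1).series(h, 0, 5).simplify()
                    .replace(y(t).diff(t,t,t,t), d4y).simplify()
                    .replace(y(t).diff(t,t,t), d3y).simplify()
                    .replace(y(t).diff(t,t), d2y).simplify()
                    .replace(diff(y(t), t), dy).simplify())
display(err)

# RK4 is fourth order, so this h**3 coefficient should vanish if the
# simplification succeeds.
err.coeff(h, 3).doit().simplify()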
```
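
Because the general expansion is hard for SymPy to simplify, one possible sanity check
is to specialize to a case where everything stays polynomial.  The sketch below assumes
the linear test problem $f(t, y) = \lambda y$ (an illustrative choice, not the general
$f$ above), for which the exact solution over one step is $y_0 e^{\lambda h}$.  The names
`lam`, `y0`, `low_order`, and `leading` are hypothetical.

```python
import sympy as sp

h, lam, y0 = sp.symbols('h lambda y0')


def f(t, y):
    # Assumed test case (not the general f of the text): f(t, y) = lambda*y.
    return lam * y


# RK4 stages and update, as in Eq. (8.64), starting from y(0) = y0.
k0 = h * f(0, y0)
k1 = h * f(h / 2, y0 + k0 / 2)
k2 = h * f(h / 2, y0 + k1 / 2)
k3 = h * f(h, y0 + k2)
y1 = sp.expand(y0 + (k0 + 2 * k1 + 2 * k2 + k3) / 6)

# Exact solution y0*exp(lambda*h), Taylor-expanded through h**6.
exact = sp.series(y0 * sp.exp(lam * h), h, 0, 7).removeO()

err = sp.expand(exact - y1)
low_order = [err.coeff(h, n) for n in range(5)]  # coefficients of h**0 ... h**4
leading = err.coeff(h, 5)                        # first surviving coefficient
print(low_order, leading)
```

In this special case the coefficients through $h^4$ cancel and the leading error term is
$\lambda^5 y_0 h^5/120$.  This confirms the order of the method for one family of
problems only.  It does not give the general error formula, which also involves the
partial derivatives of $f$ with respect to $t$.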
