# A script to illustrate each of Euler's method, the improved Euler method, and RK4
# for a system (also works on a scalar equation).
# Key points:
# The spatial variables are y[0],...,y[d-1] for a system with d dependent variables.
# t0 is the initial time, T is the final time
# The vector "f" is an expression involving t,y[0],...,y[d-1] that defines y'=f(t,y)
# Define f as below (scalar and vector examples shown).
# Solution times are returned in row vector "tv"
# Solution values are returned in matrix "yv". Rows correspond to times, columns
# to variables y[0],...,y[d-1].

# Clear the Sage session before anything else is defined.
reset()

# SYSTEM EXAMPLE: declare the symbols t and y. The right-hand side f of
# y' = f(t,y) and the initial condition y(t0) = ic follow.
var('t y')
def f(t,y):
    """Return the right-hand side f(t,y) of the example system y' = f(t,y).

    t is the time and y is indexed as y[0], y[1]. The result is a vector
    with one entry per component of y.
    """
    dy0 = -sin(t)*y[1]
    dy1 = -y[0]+cos(y[1])
    return vector([dy0,dy1])
# Initial condition y(t0) = ic for the system example.
ic = vector(SR,[1.0,2.0])

#SCALAR EXAMPLE: Uncomment to test.
#def f(t,y):
#    return vector([-y[0]])
#ic = vector(SR,[1.0]);

# Solution interval t0 <= t <= T and the stepsize passed to each solver.
t0, T, stepsize = 0.0, 5.0, 0.1

# Each solver lives in its own .sage file: load it, then call it with the same
# problem data. Every call returns a (times, values) pair.

# Euler's method -> tv, yv
load('euler_sys.sage')
tv, yv = euler_sys(f,ic,t0,T,stepsize)

# Improved Euler method -> tv2, yv2
load('improved_euler_sys.sage')
tv2, yv2 = improved_euler_sys(f,ic,t0,T,stepsize)

# RK4 method -> tv3, yv3
load('rk4_sys.sage')
tv3, yv3 = rk4_sys(f,ic,t0,T,stepsize)

#Print solution values at final time.
def _print_final_values(method_name,times,values,num_vars):
    """Print the final time and the solution components at that time.

    method_name -- heading printed before the values
    times       -- sequence of solution times
    values      -- solution values, indexed as values[row, column], with one
                   row per entry of times
    num_vars    -- number of components (columns) to print
    """
    # Use this solution's own last index instead of one taken from a
    # different solver's output, so a length mismatch cannot misindex it.
    last = len(times)-1
    print(method_name)
    print('time',times[last])
    for j in range(num_vars):
        print('y[',j,'] = ',values[last,j])

# N and d stay defined at top level for any later code that relies on them.
N = len(tv)-1
d = len(ic)

_print_final_values('Euler Method',tv,yv,d)
_print_final_values('Improved Euler Method',tv2,yv2,d)
_print_final_values('RK4 Method',tv3,yv3,d)

# Plot the first solution component (column 0 of each value matrix) against
# time, one coloured curve per method, as a simple example.
curve_specs = [
    (tv,  yv,  [1,0,0], 'Euler method'),
    (tv2, yv2, [0,1,0], 'Improved Euler'),
    (tv3, yv3, [0,0,1], 'RK4'),
]
curves = []
for ts, ys, rgb, caption in curve_specs:
    yp = ys.column(0)
    curves.append(line(list(zip(ts,yp)),rgbcolor=rgb,legend_label=caption))

# Overlay the three curves, then add the legend and axis labels.
p1, p2, p3 = curves
p = p1+p2+p3
p.set_legend_options(loc='upper right')
p.axes_labels(['$t$','$y0(t)$'])
show(p)