Code Optimization for Differential Equations
Code Optimization in Julia
Before starting this tutorial, we recommend that the reader check out one of the many tutorials on optimizing Julia code.
User-side optimizations are important because, for sufficiently difficult problems, most of the time will be spent inside your f function, the function defining the equation you are trying to solve. “Efficient” integrators are those that reduce the number of f calls required to hit the error tolerance. The main ideas for optimizing your DiffEq code, or any Julia function, are the following:
- Make it non-allocating
- Use StaticArrays for small arrays
- Use broadcast fusion
- Make it type-stable
- Reduce redundant calculations
- Make use of BLAS calls
- Optimize algorithm choice
We’ll discuss these strategies in the context of differential equations. Let’s start with small systems.
Example Accelerating a Non-Stiff Equation: The Lorenz Equation
Let’s take the classic Lorenz system and start by naively writing it in its out-of-place form:
function lorenz(u, p, t)
dx = 10.0 * (u[2] - u[1])
dy = u[1] * (28.0 - u[3]) - u[2]
dz = u[1] * u[2] - (8 / 3) * u[3]
[dx, dy, dz]
end
lorenz (generic function with 1 method)
Here, lorenz returns an object, [dx,dy,dz], which is created within the body of lorenz.
This is a common code pattern from high-level languages like MATLAB, SciPy, or R’s deSolve. However, the issue with this form is that it allocates a new vector, [dx,dy,dz], every time the function is called, which happens several times per step. Let’s benchmark the solution process with this choice of function:
using DifferentialEquations, BenchmarkTools
u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 100.0)
prob = ODEProblem(lorenz, u0, tspan)
@btime solve(prob, Tsit5());
3.795 ms (101102 allocations: 7.82 MiB)
The BenchmarkTools.jl package’s @btime runs the code multiple times to get an accurate measurement and reports the minimum time, which is the time it takes when your OS and other background processes aren’t getting in the way. Notice that in this case it takes about 4 ms to solve and allocates around 7.8 MiB. If we were to use this inside real user code, we’d also see a lot of time spent on garbage collection (GC) to clean up all the arrays we made. Even if we turn off saving, we still have these allocations.
@btime solve(prob, Tsit5(), save_everystep = false);
3.209 ms (89529 allocations: 6.83 MiB)
The problem, of course, is that arrays are created every time our derivative function is called. This function is called multiple times per step and is thus the main source of memory usage. To fix this, we can use the in-place form to make our code non-allocating:
function lorenz!(du, u, p, t)
du[1] = 10.0 * (u[2] - u[1])
du[2] = u[1] * (28.0 - u[3]) - u[2]
du[3] = u[1] * u[2] - (8 / 3) * u[3]
nothing
end
lorenz! (generic function with 1 method)
Here, instead of creating an array each time, we utilized the cache array du. When the in-place form is used, DifferentialEquations.jl takes a different internal route that minimizes the internal allocations as well.
Notice that nothing is returned. In the in-place form, the ODE solver ignores the return value. Instead, make sure that the original du array is mutated rather than replaced by a new array.
When we benchmark this function, we will see quite a difference.
u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 100.0)
prob = ODEProblem(lorenz!, u0, tspan)
@btime solve(prob, Tsit5());
762.725 μs (11415 allocations: 996.33 KiB)
@btime solve(prob, Tsit5(), save_everystep = false);
338.451 μs (49 allocations: 4.00 KiB)
That is roughly a 5x speedup just from that change (and nearly 10x with saving turned off)! Notice there are still some allocations; these are due to the construction of the integration cache. But they don’t grow with the length of the integration:
tspan = (0.0, 500.0) # 5x longer than before
prob = ODEProblem(lorenz!, u0, tspan)
@btime solve(prob, Tsit5(), save_everystep = false);
2.093 ms (49 allocations: 4.00 KiB)
Since those are all setup allocations, the user-side optimization is complete.
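The solver benchmarks mix the right-hand side with the integrator’s own work. As a minimal sketch, the right-hand sides alone can be checked with Base’s @allocated macro. The lorenz and lorenz! definitions from above are repeated so the block runs on its own; the helpers measure_oop and measure_inplace are hypothetical names introduced only for this illustration, and exact byte counts may differ between Julia versions.

```julia
# Out-of-place and in-place Lorenz right-hand sides, as defined above
function lorenz(u, p, t)
    dx = 10.0 * (u[2] - u[1])
    dy = u[1] * (28.0 - u[3]) - u[2]
    dz = u[1] * u[2] - (8 / 3) * u[3]
    [dx, dy, dz]
end

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    nothing
end

# Hypothetical helpers: call once to compile, then count the bytes of a second call.
# Measuring inside a function avoids the extra cost of untyped global variables.
function measure_oop(f, u)
    f(u, nothing, 0.0)
    local out
    bytes = @allocated(out = f(u, nothing, 0.0))
    return bytes, out          # returning `out` keeps the new array alive
end

function measure_inplace(f!, du, u)
    f!(du, u, nothing, 0.0)
    return @allocated f!(du, u, nothing, 0.0)
end

u = [1.0, 0.0, 0.0]
du = similar(u)
bytes_oop, out = measure_oop(lorenz, u)
bytes_inplace = measure_inplace(lorenz!, du, u)
println("out-of-place: ", bytes_oop, " bytes per call")
println("in-place:     ", bytes_inplace, " bytes per call")
```

The out-of-place version pays for a fresh Vector on every call, while the in-place version only writes into du, which is why the solver can reuse one cache for the whole integration.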
Further Optimizations of Small Non-Stiff ODEs with StaticArrays
Allocations are only expensive if they are “heap allocations”. For a more in-depth definition of heap allocations, there are many sources online. But a good working definition is that heap allocations are variable-sized slabs of memory which have to be pointed to, and this pointer indirection costs time. Additionally, the heap has to be managed, and the garbage collector has to actively keep track of what’s on the heap.
However, there’s an alternative to heap allocations, known as stack allocations. The stack is statically-sized (known at compile time) and thus its accesses are quick. Additionally, the exact block of memory is known in advance by the compiler, and thus re-using the memory is cheap. This means that allocating on the stack has essentially no cost!
Arrays have to be heap-allocated because their size (and thus the amount of memory they take up) is determined at runtime. But there are structures in Julia which are stack-allocated. Immutable structs, for example, are stack-allocated “value types”. Tuples are a stack-allocated collection. The most useful data structure for DiffEq, though, is the StaticArray from the package StaticArrays.jl. These arrays have their length determined at compile time. They can be created by prefixing a normal array literal with SA (macros such as @SVector also exist), for example:
using StaticArrays
A = SA[2.0, 3.0, 5.0]
typeof(A) # SVector{3, Float64} (alias for SArray{Tuple{3}, Float64, 1, 3})
SVector{3, Float64} (alias for SArray{Tuple{3}, Float64, 1, 3})
Notice that the 3 after SVector gives the size of the SVector. It cannot be changed. Additionally, SVectors are immutable, so we have to create a new SVector to change values. But remember, we don’t have to worry about allocations because this data structure is stack-allocated. SArrays have numerous extra optimizations as well: they have fast matrix multiplication, fast QR factorizations, etc. which directly make use of the information about the size of the array. Thus, when possible, they should be used.
Unfortunately, static arrays can only be used for sufficiently small arrays. Beyond a certain size, some operations are forced to heap-allocate and compile times balloon. Thus, static arrays shouldn’t be used if your system has more than ~20 variables. Additionally, only the native Julia algorithms can fully utilize static arrays.
Let’s optimize lorenz using static arrays. Note that in this case, we want to use the out-of-place allocating form, but this time we want to output a static array:
function lorenz_static(u, p, t)
dx = 10.0 * (u[2] - u[1])
dy = u[1] * (28.0 - u[3]) - u[2]
dz = u[1] * u[2] - (8 / 3) * u[3]
SA[dx, dy, dz]
end
lorenz_static (generic function with 1 method)
To make the solver internally use static arrays, we simply give it a static array as the initial condition:
u0 = SA[1.0, 0.0, 0.0]
tspan = (0.0, 100.0)
prob = ODEProblem(lorenz_static, u0, tspan)
@btime solve(prob, Tsit5());
355.294 μs (1293 allocations: 387.39 KiB)
@btime solve(prob, Tsit5(), save_everystep = false);
233.695 μs (22 allocations: 2.25 KiB)
And that’s pretty much all there is to it. With static arrays, you don’t have to worry about allocating, so use operations like * and don’t worry about fusing operations (discussed later in this tutorial). Write “the vectorized code” of R/MATLAB/Python, or directly use the numbers/values, and your code in this case will be fast.
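Tuples, mentioned above as a stack-allocated collection, show the same effect without any package (an SVector stores its elements in a Tuple internally). The sketch below uses a hypothetical lorenz_tuple variant and a hypothetical euler_steps loop purely to measure allocations; Tuples are not AbstractArrays, so for the solvers themselves the SVector version above is the one to use.

```julia
# Hypothetical tuple-returning Lorenz right-hand side (illustration only)
function lorenz_tuple(u, p, t)
    dx = 10.0 * (u[2] - u[1])
    dy = u[1] * (28.0 - u[3]) - u[2]
    dz = u[1] * u[2] - (8 / 3) * u[3]
    (dx, dy, dz)                 # size known at compile time, no heap memory
end

# Hypothetical forward Euler loop: every step builds new Tuples on the stack
function euler_steps(f, u, dt, n)
    for _ in 1:n
        du = f(u, nothing, 0.0)
        u = u .+ dt .* du        # broadcasting over Tuples returns another Tuple
    end
    return u
end

# Measure inside a function so the arguments have concrete types
function bytes_for(n)
    u0 = (1.0, 0.0, 0.0)
    euler_steps(lorenz_tuple, u0, 1e-3, 10)         # compile first
    return @allocated euler_steps(lorenz_tuple, u0, 1e-3, n)
end

println("bytes allocated for 1000 steps: ", bytes_for(1000))
println("one step from (1, 0, 0): ", euler_steps(lorenz_tuple, (1.0, 0.0, 0.0), 1e-3, 1))
```

The same loop written with Vectors would allocate new arrays on every step; with Tuples the loop body stays entirely in registers and stack memory.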
Example Accelerating a Stiff Equation: the Robertson Equation
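For these next examples, let’s solve the Robertson equations (also known as ROBER):
$$\begin{aligned}
\frac{dy_1}{dt} &= -k_1 y_1 + k_3 y_2 y_3 \\
\frac{dy_2}{dt} &= k_1 y_1 - k_2 y_2^2 - k_3 y_2 y_3 \\
\frac{dy_3}{dt} &= k_2 y_2^2
\end{aligned}$$
with $k_1 = 0.04$, $k_2 = 3 \times 10^7$, and $k_3 = 10^4$.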
Given that these equations are stiff, non-stiff ODE solvers like Tsit5 or Vern9 will fail to solve them. The automatic algorithm will detect this and switch to something more robust. For example:
using DifferentialEquations
using Plots
function rober!(du, u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃
du[3] = k₂ * y₂^2
nothing
end
prob = ODEProblem(rober!, [1.0, 0.0, 0.0], (0.0, 1e5), [0.04, 3e7, 1e4])
sol = solve(prob)
plot(sol, tspan = (1e-2, 1e5), xscale = :log10)
using BenchmarkTools
@btime solve(prob);
121.170 μs (675 allocations: 58.75 KiB)
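One way to see the stiffness is to look at the eigenvalues of the Jacobian along the solution. As a sketch, the block below evaluates the analytical ROBER Jacobian (rober_jacobian is a hypothetical helper) at the state u ≈ (0.99989, 3.64e-5, 7.73e-5), which the Rosenbrock23 solution reaches at t ≈ 0.0028, and compares the eigenvalue magnitudes using the LinearAlgebra standard library.

```julia
using LinearAlgebra

# Hypothetical helper: the ROBER Jacobian J[i, j] = ∂f_i/∂u_j as a new 3×3 matrix
function rober_jacobian(u, p)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    J = zeros(3, 3)
    J[1, 1] = -k₁
    J[1, 2] = k₃ * y₃
    J[1, 3] = k₃ * y₂
    J[2, 1] = k₁
    J[2, 2] = -2k₂ * y₂ - k₃ * y₃
    J[2, 3] = -k₃ * y₂
    J[3, 2] = 2k₂ * y₂
    return J                       # J[3, 1] and J[3, 3] stay zero
end

p = (0.04, 3e7, 1e4)
# State taken from the Rosenbrock23 solution at t ≈ 0.0028
u = [0.9998862913763157, 3.6412401619257426e-5, 7.729622206475401e-5]
λ = sort(real.(eigvals(rober_jacobian(u, p))))   # the eigenvalues are real here
println("eigenvalues: ", λ)
println("fastest / slowest decaying rate: ", λ[1] / λ[2])
```

The fastest and slowest decaying modes differ by roughly four orders of magnitude, which is the kind of scale separation that forces explicit methods into tiny steps. The zero eigenvalue reflects that the three right-hand sides sum to zero, so the total y₁ + y₂ + y₃ is conserved.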
Choosing a Good Solver
Choosing a good solver is required for getting top-notch speed. General recommendations can be found on the solver page (for example, the ODE Solver Recommendations). The current recommendations can be simplified to a Rosenbrock method (Rosenbrock23 or Rodas5) for smaller (<50 ODEs) problems, ESDIRK methods for slightly larger ones (TRBDF2 or KenCarp4 for <2000 ODEs), and QNDF for even larger problems. lsoda from LSODA.jl is sometimes worth a try for the medium-sized category.
More details on the solver to choose can be found by benchmarking. See the SciMLBenchmarks to compare many solvers on many problems.
From this, we try the recommendation of Rosenbrock23() for stiff ODEs at default tolerances:
@btime solve(prob, Rosenbrock23());
100.705 μs (500 allocations: 40.64 KiB)
Declaring Jacobian Functions
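In order to reduce the Jacobian construction cost, one can describe a Jacobian function by using the jac argument for the ODEFunction. First we have to derive the Jacobian, J[i,j] $= \frac{\partial f_i}{\partial u_j}$, i.e. the derivative of the i-th component of the right-hand side with respect to the j-th state variable. From this, we get: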
function rober_jac!(J, u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
J[1, 1] = k₁ * -1
J[2, 1] = k₁
J[3, 1] = 0
J[1, 2] = y₃ * k₃
J[2, 2] = y₂ * k₂ * -2 + y₃ * k₃ * -1
J[3, 2] = y₂ * 2 * k₂
J[1, 3] = k₃ * y₂
J[2, 3] = k₃ * y₂ * -1
J[3, 3] = 0
nothing
end
f! = ODEFunction(rober!, jac = rober_jac!)
prob_jac = ODEProblem(f!, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 100000.0)
u0: 3-element Vector{Float64}:
1.0
0.0
0.0
@btime solve(prob_jac, Rosenbrock23());
86.306 μs (422 allocations: 34.64 KiB)
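Hand-derived Jacobians are easy to get wrong, so it can be worth checking one numerically before handing it to a solver. A minimal sketch, assuming the rober! and rober_jac! definitions above (repeated so the block is self-contained) and a hypothetical fd_jacobian helper based on central differences:

```julia
# rober! and rober_jac! as defined above
function rober!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃
    du[3] = k₂ * y₂^2
    nothing
end

function rober_jac!(J, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    J[1, 1] = k₁ * -1
    J[2, 1] = k₁
    J[3, 1] = 0
    J[1, 2] = y₃ * k₃
    J[2, 2] = y₂ * k₂ * -2 + y₃ * k₃ * -1
    J[3, 2] = y₂ * 2 * k₂
    J[1, 3] = k₃ * y₂
    J[2, 3] = k₃ * y₂ * -1
    J[3, 3] = 0
    nothing
end

# Hypothetical helper: central-difference approximation of J[i, j] = ∂f_i/∂u_j
function fd_jacobian(f!, u, p, t)
    n = length(u)
    J = zeros(n, n)
    fplus, fminus = similar(u), similar(u)
    for j in 1:n
        h = 1e-6 * max(1.0, abs(u[j]))    # perturbation size for the j-th state
        up = copy(u); up[j] += h
        um = copy(u); um[j] -= h
        f!(fplus, up, p, t)
        f!(fminus, um, p, t)
        J[:, j] .= (fplus .- fminus) ./ (2h)
    end
    return J
end

p = (0.04, 3e7, 1e4)
u = [0.9998862913763157, 3.6412401619257426e-5, 7.729622206475401e-5]
J = zeros(3, 3)
rober_jac!(J, u, p, 0.0)
J_fd = fd_jacobian(rober!, u, p, 0.0)
relerr = maximum(abs.(J .- J_fd)) / maximum(abs.(J))
println("largest difference relative to max |J|: ", relerr)
```

Because each right-hand side is at most quadratic in any single variable, central differences are exact up to rounding here, so any sizable mismatch would point to an error in the hand-derived entries.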
Automatic Derivation of Jacobian Functions
But that was hard! If you want the symbolic Jacobian of numerical code, you can use ModelingToolkit.jl to convert the numerical code into symbolic form, do the symbolic calculation, and generate the Julia code for the result.
using ModelingToolkit
de = modelingtoolkitize(prob)
Model ##MTKizedODE#15999 with 3 equations
States (3):
x₁(t) [defaults to 1.0]
x₂(t) [defaults to 0.0]
x₃(t) [defaults to 0.0]
Parameters (3):
α₁ [defaults to 0.04]
α₂ [defaults to 3.0e7]
α₃ [defaults to 10000.0]
We can tell it to compute the Jacobian if we want to see the code:
ModelingToolkit.generate_jacobian(de)[2] # Second is in-place
:(function (ˍ₋out, ˍ₋arg1, ˍ₋arg2, t)
#= /root/.julia/packages/SymbolicUtils/YVse6/src/code.jl:373 =#
#= /root/.julia/packages/SymbolicUtils/YVse6/src/code.jl:374 =#
#= /root/.julia/packages/SymbolicUtils/YVse6/src/code.jl:375 =#
begin
begin
begin
#= /root/.julia/packages/Symbolics/rvztO/src/build_function.jl:537 =#
#= /root/.julia/packages/SymbolicUtils/YVse6/src/code.jl:422 =# @inbounds begin
#= /root/.julia/packages/SymbolicUtils/YVse6/src/code.jl:418 =#
ˍ₋out[1] = (*)(-1, ˍ₋arg2[1])
ˍ₋out[2] = ˍ₋arg2[1]
ˍ₋out[3] = 0
ˍ₋out[4] = (*)(ˍ₋arg2[3], ˍ₋arg1[3])
ˍ₋out[5] = (+)((*)((*)(-2, ˍ₋arg2[2]), ˍ₋arg1[2]), (*)((*)(-1, ˍ₋arg2[3]), ˍ₋arg1[3]))
ˍ₋out[6] = (*)((*)(2, ˍ₋arg2[2]), ˍ₋arg1[2])
ˍ₋out[7] = (*)(ˍ₋arg2[3], ˍ₋arg1[2])
ˍ₋out[8] = (*)((*)(-1, ˍ₋arg2[3]), ˍ₋arg1[2])
ˍ₋out[9] = 0
#= /root/.julia/packages/SymbolicUtils/YVse6/src/code.jl:420 =#
nothing
end
end
end
end
end)
Now let’s use that to give the solver the analytical Jacobian:
prob_jac2 = ODEProblem(de, [], (0.0, 1e5), jac = true)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 100000.0)
u0: 3-element Vector{Float64}:
1.0
0.0
0.0
@btime solve(prob_jac2);
107.581 μs (608 allocations: 59.53 KiB)
See the ModelingToolkit.jl documentation for more details.
Accelerating Small ODE Solves with Static Arrays
If the ODE is sufficiently small (<20 ODEs or so), using StaticArrays.jl for the state variables can greatly enhance the performance. This is done by making u0 a StaticArray and writing an out-of-place, non-mutating dispatch for static arrays. For the ROBER problem, this looks like:
using StaticArrays
function rober_static(u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
du1 = -k₁ * y₁ + k₃ * y₂ * y₃
du2 = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃
du3 = k₂ * y₂^2
SA[du1, du2, du3]
end
prob = ODEProblem(rober_static, SA[1.0, 0.0, 0.0], (0.0, 1e5), SA[0.04, 3e7, 1e4])
sol = solve(prob, Rosenbrock23())
retcode: Success
Interpolation: specialized 2nd order "free" stiffness-aware interpolation
t: 61-element Vector{Float64}:
0.0
3.196206628740808e-5
0.00014400709336278452
0.00025605212043816096
0.00048593871402339607
0.0007179482102678373
0.0010819240251828343
0.0014801655107859655
0.0020679567717440095
0.002843584518457066
⋮
25371.93159838571
30784.11718374498
37217.42390396605
44850.61094811346
53893.688830057334
64593.73530179436
77241.71691097679
92180.81843146283
100000.0
u: 61-element Vector{SVector{3, Float64}}:
[1.0, 0.0, 0.0]
[0.9999987215181657, 1.2780900152625978e-6, 3.9181897521319503e-10]
[0.9999942397329006, 5.7185104612947566e-6, 4.175663804739006e-8]
[0.9999897579688383, 9.992106612572491e-6, 2.49924549040571e-7]
[0.9999805626683271, 1.7833623941038088e-5, 1.6037077316934769e-6]
[0.9999712826607852, 2.403488562731424e-5, 4.682453587410618e-6]
[0.9999567250114038, 3.0390689334989113e-5, 1.2884299261094982e-5]
[0.9999407986095145, 3.388427339038224e-5, 2.531711709494679e-5]
[0.9999172960310598, 3.583508669306405e-5, 4.686888224684217e-5]
[0.9998862913763157, 3.6412401619257426e-5, 7.729622206475401e-5]
⋮
[0.05563508171413305, 2.3546322394505495e-7, 0.9443646828226426]
[0.047925352159210115, 2.012149653947115e-7, 0.9520744466258239]
[0.04123342367542567, 1.7192789116847091e-7, 0.9587664043966821]
[0.03543700020207701, 1.4688362762022196e-7, 0.9645628529142937]
[0.03042537309965345, 1.2546809864160592e-7, 0.9695745014322467]
[0.026099133126498978, 1.071560882169501e-7, 0.9739007597174127]
[0.022369692367946337, 9.149845157822593e-8, 0.977630216133602]
[0.019158563494465593, 7.811096455346351e-8, 0.9808413583945704]
[0.017827893845894716, 7.258919980139027e-8, 0.982172033564906]
If we benchmark this, we see that it is faster than the mutating version with Rosenbrock23 (about 83 μs versus about 101 μs above):
@btime sol = solve(prob, Rosenbrock23());
82.839 μs (807 allocations: 46.66 KiB)
Because the state is an immutable static array that the function never mutates, this version is also very amenable to multithreading and other forms of parallelism.
Example Accelerating Linear Algebra PDE Semi-Discretization
In this tutorial, we will optimize the right-hand side definition of a PDE semi-discretization.
We highly recommend looking at the Solving Large Stiff Equations tutorial for details on customizing DifferentialEquations.jl for more efficient large-scale stiff ODE solving. This section will only focus on the user-side code.
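Let’s optimize the solution of a reaction-diffusion PDE’s discretization. In its discretized form, this is the ODE:
$$\begin{aligned}
\frac{du}{dt} &= D_1 (A_y u + u A_x) + \frac{a u^2}{v} + \bar{u} - \alpha u \\
\frac{dv}{dt} &= D_2 (A_y v + v A_x) + a u^2 - \beta v
\end{aligned}$$
where $u$, $v$, $A_x$, and $A_y$ are $N \times N$ matrices and $u^2$ and $u^2/v$ are taken elementwise. Here, we will use the simplified version where $A_x$ and $A_y$ are the tridiagonal stencil $[1, -2, 1]$ (with the boundary entries adjusted as in the code below), i.e. $A_y u + u A_x$ is the 2D discretization of the Laplacian. The naive code would be something along the lines of: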
using DifferentialEquations, LinearAlgebra, BenchmarkTools
# Generate the constants
p = (1.0, 1.0, 1.0, 10.0, 0.001, 100.0) # a,α,ubar,β,D1,D2
N = 100
Ax = Array(Tridiagonal([1.0 for i in 1:(N - 1)], [-2.0 for i in 1:N],
[1.0 for i in 1:(N - 1)]))
Ay = copy(Ax)
Ax[2, 1] = 2.0
Ax[end - 1, end] = 2.0
Ay[1, 2] = 2.0
Ay[end, end - 1] = 2.0
function basic_version!(dr, r, p, t)
a, α, ubar, β, D1, D2 = p
u = r[:, :, 1]
v = r[:, :, 2]
Du = D1 * (Ay * u + u * Ax)
Dv = D2 * (Ay * v + v * Ax)
dr[:, :, 1] = Du .+ a .* u .* u ./ v .+ ubar .- α * u
dr[:, :, 2] = Dv .+ a .* u .* u .- β * v
end
a, α, ubar, β, D1, D2 = p
uss = (ubar + β) / α
vss = (a / β) * uss^2
r0 = zeros(100, 100, 2)
r0[:, :, 1] .= uss .+ 0.1 .* rand.()
r0[:, :, 2] .= vss
prob = ODEProblem(basic_version!, r0, (0.0, 0.1), p)
ODEProblem with uType Array{Float64, 3} and tType Float64. In-place: true
timespan: (0.0, 0.1)
u0: 100×100×2 Array{Float64, 3}:
[:, :, 1] =
11.0297 11.0356 11.0428 11.0467 … 11.0225 11.0233 11.038 11.0872
11.0549 11.0961 11.0168 11.096 11.0164 11.0972 11.0107 11.0099
11.015 11.0057 11.0514 11.0409 11.0082 11.023 11.0189 11.0118
11.0015 11.0915 11.0032 11.003 11.0114 11.0573 11.0665 11.0138
11.0602 11.0724 11.0952 11.0322 11.0306 11.0269 11.0808 11.0369
11.0913 11.011 11.0755 11.0714 … 11.0866 11.0974 11.0936 11.0343
11.0029 11.0355 11.093 11.0127 11.0059 11.0543 11.0804 11.096
11.0688 11.0521 11.0789 11.0599 11.0039 11.0845 11.0969 11.0574
11.0029 11.0872 11.0116 11.0584 11.0567 11.0728 11.0656 11.0137
11.0224 11.0414 11.0153 11.0092 11.0163 11.0839 11.0493 11.018
⋮ ⋱
11.0642 11.0425 11.0113 11.0711 11.0482 11.0105 11.0953 11.0753
11.0039 11.0969 11.0645 11.0514 11.023 11.0244 11.0764 11.0472
11.0903 11.0955 11.0155 11.0446 11.048 11.0918 11.0061 11.0577
11.0058 11.0153 11.0618 11.0181 11.0111 11.0618 11.0591 11.0679
11.0696 11.0355 11.0647 11.0514 … 11.0199 11.0414 11.0463 11.0582
11.0742 11.0107 11.0072 11.0221 11.0272 11.0425 11.0576 11.0433
11.091 11.0258 11.065 11.0152 11.0767 11.0518 11.0037 11.0915
11.1 11.0123 11.0879 11.0601 11.0283 11.0435 11.0621 11.0687
11.0969 11.071 11.0128 11.05 11.0859 11.0866 11.0665 11.057
[:, :, 2] =
12.1 12.1 12.1 12.1 12.1 12.1 … 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 … 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
⋮ ⋮ ⋱ ⋮
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 … 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1 12.1
In this version, we have encoded our initial condition as a 3-dimensional array, with r[:,:,1] holding the u part and r[:,:,2] holding the v part.
@btime solve(prob, Tsit5());
64.948 ms (7349 allocations: 186.83 MiB)
While this version isn’t very efficient, we recommend writing the “high-level” code first and iteratively optimizing it!
The first thing that we can do is get rid of the slicing allocations. The operation r[:,:,1] creates a temporary array instead of a “view”, i.e. a pointer to the already existing memory. To make it a view, add @view. Note that we have to be careful with views because they point to the same memory, and thus changing a view changes the original values:
A = rand(4)
@show A
B = @view A[1:3]
B[2] = 2
@show A
4-element Vector{Float64}:
0.233875535898035
2.0
0.08848357049758093
0.41287083506350053
Notice that changing B changed A. This is something to be careful of, but at the same time we want to use this, since we want to modify the output dr. Additionally, the last two statements are purely element-wise operations, and thus we can make use of broadcast fusion there. Let’s rewrite basic_version! to avoid slicing allocations and to use broadcast fusion:
function gm2!(dr, r, p, t)
a, α, ubar, β, D1, D2 = p
u = @view r[:, :, 1]
v = @view r[:, :, 2]
du = @view dr[:, :, 1]
dv = @view dr[:, :, 2]
Du = D1 * (Ay * u + u * Ax)
Dv = D2 * (Ay * v + v * Ax)
@. du = Du + a .* u .* u ./ v + ubar - α * u
@. dv = Dv + a .* u .* u - β * v
end
prob = ODEProblem(gm2!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());
59.937 ms (5879 allocations: 119.71 MiB)
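The two changes in gm2! can be checked in isolation. A minimal sketch with hypothetical helper names (first_slice_sum, first_view_sum, rhs_unfused!, rhs_fused!, report): slicing r[:, :, 1] copies the data while a @view does not, and a single undotted operator such as α * u breaks fusion and allocates a temporary array, as in basic_version!.

```julia
# Slicing vs. viewing the first "species" of a 3-D state array
first_slice_sum(r) = sum(r[:, :, 1])        # r[:, :, 1] copies into a new Matrix
first_view_sum(r) = sum(@view r[:, :, 1])   # the view points into r's memory

# Reaction part of the u equation: one undotted `α * u` vs. fully fused @.
function rhs_unfused!(du, u, v, a, ubar, α)
    du .= a .* u .* u ./ v .+ ubar .- α * u   # `α * u` allocates before the broadcast
    return du
end

function rhs_fused!(du, u, v, a, ubar, α)
    @. du = a * u * u / v + ubar - α * u      # one fused loop, no temporaries
    return du
end

# Measure inside a function so that untyped globals do not add allocations
function report(N)
    r = rand(N, N, 2) .+ 1.0
    u = @view r[:, :, 1]
    v = @view r[:, :, 2]
    du1 = zeros(N, N)
    du2 = zeros(N, N)
    first_slice_sum(r); first_view_sum(r)                 # compile first
    rhs_unfused!(du1, u, v, 1.0, 1.0, 1.0)
    rhs_fused!(du2, u, v, 1.0, 1.0, 1.0)
    return (slice = @allocated(first_slice_sum(r)),
            view = @allocated(first_view_sum(r)),
            unfused = @allocated(rhs_unfused!(du1, u, v, 1.0, 1.0, 1.0)),
            fused = @allocated(rhs_fused!(du2, u, v, 1.0, 1.0, 1.0)),
            same = du1 ≈ du2)
end

println(report(100))
```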
Now, most of the allocations are taking place in Du = D1*(Ay*u + u*Ax), since those operations are vectorized and not mutating. We should instead replace the matrix multiplications with mul!. When doing so, we will need to have cache variables to write into. This looks like:
Ayu = zeros(N, N)
uAx = zeros(N, N)
Du = zeros(N, N)
Ayv = zeros(N, N)
vAx = zeros(N, N)
Dv = zeros(N, N)
function gm3!(dr, r, p, t)
a, α, ubar, β, D1, D2 = p
u = @view r[:, :, 1]
v = @view r[:, :, 2]
du = @view dr[:, :, 1]
dv = @view dr[:, :, 2]
mul!(Ayu, Ay, u)
mul!(uAx, u, Ax)
mul!(Ayv, Ay, v)
mul!(vAx, v, Ax)
@. Du = D1 * (Ayu + uAx)
@. Dv = D2 * (Ayv + vAx)
@. du = Du + a * u * u ./ v + ubar - α * u
@. dv = Dv + a * u * u - β * v
end
prob = ODEProblem(gm3!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());
44.045 ms (4703 allocations: 29.97 MiB)
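The scaling by D1 and the sum can also be folded into the multiplications themselves, which removes the need for the Ayu and uAx caches. LinearAlgebra’s five-argument mul!(C, A, B, α, β) overwrites C with A*B*α + C*β (available since Julia 1.3). A minimal sketch, where diffusion! is a hypothetical helper and the operators are rebuilt as in the tutorial, checked against the allocating expression:

```julia
using LinearAlgebra

# Hypothetical helper: overwrite Du with D1 * (Ay * u + u * Ax) using no extra caches
function diffusion!(Du, Ay, Ax, u, D1)
    mul!(Du, Ay, u, D1, 0.0)   # Du = D1 * (Ay * u); β = 0 so the old contents do not contribute
    mul!(Du, u, Ax, D1, 1.0)   # Du = D1 * (u * Ax) + Du
    return Du
end

# Same operators as in the tutorial
N = 100
Ax = Array(Tridiagonal([1.0 for i in 1:(N - 1)], [-2.0 for i in 1:N],
    [1.0 for i in 1:(N - 1)]))
Ay = copy(Ax)
Ax[2, 1] = 2.0
Ax[end - 1, end] = 2.0
Ay[1, 2] = 2.0
Ay[end, end - 1] = 2.0

r = rand(N, N, 2)
u = @view r[:, :, 1]
D1 = 0.001
Du = zeros(N, N)
diffusion!(Du, Ay, Ax, u, D1)
println("matches D1 * (Ay * u + u * Ax): ", Du ≈ D1 * (Ay * u + u * Ax))
```

This trades two caches and one broadcast per species for the α and β arguments of mul!; whether it is measurably faster than gm3! would need to be benchmarked.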
But our temporary variables are global variables. We need to either declare the caches as const or localize them. We can localize them by adding them to the parameters, p. It’s easier for the compiler to reason about local variables than global variables, and localizing variables helps to ensure type stability.
p = (1.0, 1.0, 1.0, 10.0, 0.001, 100.0, Ayu, uAx, Du, Ayv, vAx, Dv) # a,α,ubar,β,D1,D2, then the caches
function gm4!(dr, r, p, t)
a, α, ubar, β, D1, D2, Ayu, uAx, Du, Ayv, vAx, Dv = p
u = @view r[:, :, 1]
v = @view r[:, :, 2]
du = @view dr[:, :, 1]
dv = @view dr[:, :, 2]
mul!(Ayu, Ay, u)
mul!(uAx, u, Ax)
mul!(Ayv, Ay, v)
mul!(vAx, v, Ax)
@. Du = D1 * (Ayu + uAx)
@. Dv = D2 * (Ayv + vAx)
@. du = Du + a * u * u ./ v + ubar - α * u
@. dv = Dv + a * u * u - β * v
end
prob = ODEProblem(gm4!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());
37.362 ms (1028 allocations: 29.66 MiB)
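The difference between the global caches of gm3! and the localized ones of gm4! comes down to type inference. A minimal sketch with hypothetical names: the compiler cannot assume anything about the type of a non-const global, while a const global or a function argument (which is how gm4! receives its caches through p) has a known type.

```julia
# A non-const global: the compiler cannot assume its type stays the same
cache = zeros(3)
# A const global: its type is fixed, so code reading it can be specialized
const CONST_CACHE = zeros(3)

# Hypothetical helpers, one per access pattern
read_global() = cache[1]
read_const() = CONST_CACHE[1]
read_arg(c) = c[1]                # like the caches gm4! unpacks from p

# Base.return_types shows what inference concludes about each function's result
println("non-const global: ", Base.return_types(read_global, ()))
println("const global:     ", Base.return_types(read_const, ()))
println("argument:         ", Base.return_types(read_arg, (Vector{Float64},)))
```

Only the non-const global is inferred as Any, so every use of it needs a dynamic dispatch. Ax and Ay are still read as non-const globals in gm4!, so passing them through p as well, or declaring them const, would likely remove that remaining uncertainty.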
We could then lean on the BLAS gemm routine more directly to optimize the matrix multiplications some more, but instead let’s devectorize the stencil.
p = (1.0, 1.0, 1.0, 10.0, 0.001, 100.0, N)
function fast_gm!(du, u, p, t)
a, α, ubar, β, D1, D2, N = p
@inbounds for j in 2:(N - 1), i in 2:(N - 1)
du[i, j, 1] = D1 *
(u[i - 1, j, 1] + u[i + 1, j, 1] + u[i, j + 1, 1] + u[i, j - 1, 1] -
4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
end
@inbounds for j in 2:(N - 1), i in 2:(N - 1)
du[i, j, 2] = D2 *
(u[i - 1, j, 2] + u[i + 1, j, 2] + u[i, j + 1, 2] + u[i, j - 1, 2] -
4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
end
@inbounds for j in 2:(N - 1)
i = 1
du[1, j, 1] = D1 *
(2u[i + 1, j, 1] + u[i, j + 1, 1] + u[i, j - 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
end
@inbounds for j in 2:(N - 1)
i = 1
du[1, j, 2] = D2 *
(2u[i + 1, j, 2] + u[i, j + 1, 2] + u[i, j - 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
end
@inbounds for j in 2:(N - 1)
i = N
du[end, j, 1] = D1 *
(2u[i - 1, j, 1] + u[i, j + 1, 1] + u[i, j - 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
end
@inbounds for j in 2:(N - 1)
i = N
du[end, j, 2] = D2 *
(2u[i - 1, j, 2] + u[i, j + 1, 2] + u[i, j - 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
end
@inbounds for i in 2:(N - 1)
j = 1
du[i, 1, 1] = D1 *
(u[i - 1, j, 1] + u[i + 1, j, 1] + 2u[i, j + 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
end
@inbounds for i in 2:(N - 1)
j = 1
du[i, 1, 2] = D2 *
(u[i - 1, j, 2] + u[i + 1, j, 2] + 2u[i, j + 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
end
@inbounds for i in 2:(N - 1)
j = N
du[i, end, 1] = D1 *
(u[i - 1, j, 1] + u[i + 1, j, 1] + 2u[i, j - 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
end
@inbounds for i in 2:(N - 1)
j = N
du[i, end, 2] = D2 *
(u[i - 1, j, 2] + u[i + 1, j, 2] + 2u[i, j - 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
end
@inbounds begin
i = 1
j = 1
du[1, 1, 1] = D1 * (2u[i + 1, j, 1] + 2u[i, j + 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
du[1, 1, 2] = D2 * (2u[i + 1, j, 2] + 2u[i, j + 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
i = 1
j = N
du[1, N, 1] = D1 * (2u[i + 1, j, 1] + 2u[i, j - 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
du[1, N, 2] = D2 * (2u[i + 1, j, 2] + 2u[i, j - 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
i = N
j = 1
du[N, 1, 1] = D1 * (2u[i - 1, j, 1] + 2u[i, j + 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
du[N, 1, 2] = D2 * (2u[i - 1, j, 2] + 2u[i, j + 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
i = N
j = N
du[end, end, 1] = D1 * (2u[i - 1, j, 1] + 2u[i, j - 1, 1] - 4u[i, j, 1]) +
a * u[i, j, 1]^2 / u[i, j, 2] + ubar - α * u[i, j, 1]
du[end, end, 2] = D2 * (2u[i - 1, j, 2] + 2u[i, j - 1, 2] - 4u[i, j, 2]) +
a * u[i, j, 1]^2 - β * u[i, j, 2]
end
end
prob = ODEProblem(fast_gm!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());
7.813 ms (440 allocations: 29.62 MiB)
Notice that in this case fusing the loops and avoiding the linear operators is a major improvement: about 5x faster than gm4! and roughly 8x faster than our original MATLAB/SciPy/R vectorized-style code!
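The edge loops and corner blocks in fast_gm! all implement the same boundary treatment as the modified rows of Ax and Ay. As a more compact sketch (an alternative formulation, not the one benchmarked above), the boundary can be handled by mirroring the neighbor index, which reproduces Ay*u + u*Ax on a square grid; prev_idx, next_idx, and laplacian! are hypothetical names. The branches may cost some vectorization compared with the split loops, so this is a readability trade-off rather than a guaranteed speedup.

```julia
using LinearAlgebra

# Mirrored neighbors: the missing neighbor of cell 1 is cell 2, and of cell N is cell N - 1
prev_idx(i) = i == 1 ? 2 : i - 1
next_idx(i, N) = i == N ? N - 1 : i + 1

# Hypothetical helper: loop form of Ay*u + u*Ax on a square N×N grid
function laplacian!(L, u)
    N = size(u, 1)
    @inbounds for j in 1:N, i in 1:N       # column-major order: i varies fastest
        L[i, j] = u[prev_idx(i), j] + u[next_idx(i, N), j] +
                  u[i, prev_idx(j)] + u[i, next_idx(j, N)] - 4u[i, j]
    end
    return L
end

# Compare with the matrix form used by basic_version! (same boundary entries)
N = 8
Ax = Array(Tridiagonal([1.0 for i in 1:(N - 1)], [-2.0 for i in 1:N],
    [1.0 for i in 1:(N - 1)]))
Ay = copy(Ax)
Ax[2, 1] = 2.0
Ax[end - 1, end] = 2.0
Ay[1, 2] = 2.0
Ay[end, end - 1] = 2.0

u = rand(N, N)
L = laplacian!(zeros(N, N), u)
println("loop form matches Ay*u + u*Ax: ", L ≈ Ay * u + u * Ax)
```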
Since this is tedious to do by hand, we note that ModelingToolkit.jl’s symbolic code generation can do this automatically from the basic version:
using ModelingToolkit
function basic_version!(dr, r, p, t)
a, α, ubar, β, D1, D2 = p
u = r[:, :, 1]
v = r[:, :, 2]
Du = D1 * (Ay * u + u * Ax)
Dv = D2 * (Ay * v + v * Ax)
dr[:, :, 1] = Du .+ a .* u .* u ./ v .+ ubar .- α * u
dr[:, :, 2] = Dv .+ a .* u .* u .- β * v
end
a, α, ubar, β, D1, D2 = p
uss = (ubar + β) / α
vss = (a / β) * uss^2
r0 = zeros(100, 100, 2)
r0[:, :, 1] .= uss .+ 0.1 .* rand.()
r0[:, :, 2] .= vss
prob = ODEProblem(basic_version!, r0, (0.0, 0.1), p)
de = modelingtoolkitize(prob)
# Note jac=true,sparse=true makes it automatically build sparse Jacobian code
# as well!
fastprob = ODEProblem(de, [], (0.0, 0.1), jac = true, sparse = true)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 0.1)
u0: 20000-element Vector{Float64}:
11.00511096022425
11.02154866401249
11.093170305404836
11.054832027492262
11.07510929869822
11.080173590728513
11.02244672066665
11.074092575697362
11.073488249322759
11.034784443024972
⋮
12.100000000000001
12.100000000000001
12.100000000000001
12.100000000000001
12.100000000000001
12.100000000000001
12.100000000000001
12.100000000000001
12.100000000000001
Lastly, we can do other things like multithread the main loops. LoopVectorization.jl provides the @turbo macro for doing a lot of SIMD enhancements, and @tturbo is the multithreaded version.
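As a swapped-in alternative that needs only Base, Threads.@threads can split the outer column loop across threads. A minimal sketch of the full reaction-diffusion right-hand side in that style, with mirrored-index boundaries and the parameter order a, α, ubar, β, D1, D2 from above; threaded_gm!, prev_idx, and next_idx are hypothetical names. Julia has to be started with several threads (for example julia --threads=4) for the loop to run in parallel; with one thread the result is the same, just serial.

```julia
using Base.Threads
using LinearAlgebra

# Mirrored neighbors at the edges of the grid
prev_idx(i) = i == 1 ? 2 : i - 1
next_idx(i, N) = i == N ? N - 1 : i + 1

# Hypothetical threaded right-hand side; p = (a, α, ubar, β, D1, D2) as in the tutorial
function threaded_gm!(du, u, p, t)
    a, α, ubar, β, D1, D2 = p
    N = size(u, 1)
    @threads for j in 1:N                      # columns are split across threads
        jprev, jnext = prev_idx(j), next_idx(j, N)
        @inbounds for i in 1:N
            iprev, inext = prev_idx(i), next_idx(i, N)
            uij, vij = u[i, j, 1], u[i, j, 2]
            Lu = u[iprev, j, 1] + u[inext, j, 1] + u[i, jprev, 1] + u[i, jnext, 1] - 4uij
            Lv = u[iprev, j, 2] + u[inext, j, 2] + u[i, jprev, 2] + u[i, jnext, 2] - 4vij
            du[i, j, 1] = D1 * Lu + a * uij^2 / vij + ubar - α * uij
            du[i, j, 2] = D2 * Lv + a * uij^2 - β * vij
        end
    end
    return nothing
end

# Check against the matrix form on a small grid
N = 10
Ax = Array(Tridiagonal(ones(N - 1), fill(-2.0, N), ones(N - 1)))
Ay = copy(Ax)
Ax[2, 1] = 2.0; Ax[end - 1, end] = 2.0
Ay[1, 2] = 2.0; Ay[end, end - 1] = 2.0

p = (1.0, 1.0, 1.0, 10.0, 0.001, 100.0)
r = 1.0 .+ rand(N, N, 2)
dr = similar(r)
threaded_gm!(dr, r, p, 0.0)

a, α, ubar, β, D1, D2 = p
uu, vv = r[:, :, 1], r[:, :, 2]
ref_u = D1 * (Ay * uu + uu * Ax) .+ a .* uu .* uu ./ vv .+ ubar .- α .* uu
ref_v = D2 * (Ay * vv + vv * Ax) .+ a .* uu .* uu .- β .* vv
println("threads: ", nthreads(), ", matches matrix form: ",
        dr[:, :, 1] ≈ ref_u && dr[:, :, 2] ≈ ref_v)
```

Each thread writes to its own columns of du and only reads u, so no locking is needed.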
Optimizing Algorithm Choices
The last thing to do is then optimize our algorithm choice. We have been using Tsit5() as our test algorithm, but in reality this problem is a stiff PDE discretization, and thus one recommendation is to use CVODE_BDF(). However, instead of using the default dense Jacobian, we should make use of the sparsity of the Jacobian afforded by the problem. The Jacobian is the matrix $\frac{\partial f_i}{\partial r_j}$, where $r$ is read by its linear index (i.e. down columns). But since each $u$ variable depends on the $v$ variable at the same grid point, which is $N^2$ entries away in the linear index, the band size here is large, and thus this will not do well with a banded Jacobian solver. Instead, we utilize sparse Jacobian algorithms. CVODE_BDF allows us to use a Newton-Krylov method by setting linear_solver = :GMRES, which works with Jacobian-vector products instead of factorizing the full Jacobian.
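The bandwidth argument can be checked directly. As a sketch (rd_sparsity is a hypothetical helper, and the dependency pattern is read off the equations above: a five-point stencil within each species plus the pointwise u and v coupling), the pattern is built with the SparseArrays standard library and its bandwidth measured in the linear index of r.

```julia
using SparseArrays

# Hypothetical helper: which entries of r each output entry depends on,
# indexed by the linear index of the N×N×2 state array r
function rd_sparsity(N)
    L = LinearIndices((N, N, 2))
    prev_idx(i) = i == 1 ? 2 : i - 1      # mirrored neighbor at the boundary
    next_idx(i) = i == N ? N - 1 : i + 1
    rows, cols = Int[], Int[]
    for c in 1:2, j in 1:N, i in 1:N
        row = L[i, j, c]
        # five-point stencil within the same species
        for (ii, jj) in ((i, j), (prev_idx(i), j), (next_idx(i), j),
                         (i, prev_idx(j)), (i, next_idx(j)))
            push!(rows, row)
            push!(cols, L[ii, jj, c])
        end
        # reaction coupling: the u equation uses v and the v equation uses u at the same point
        push!(rows, row)
        push!(cols, L[i, j, 3 - c])
    end
    return sparse(rows, cols, ones(length(rows)), 2N^2, 2N^2)  # duplicates are summed
end

N = 100
S = rd_sparsity(N)
ri, ci, _ = findnz(S)
bandwidth = maximum(abs.(ri .- ci))
println("size: ", size(S), ", stored nonzeros: ", nnz(S), ", bandwidth: ", bandwidth)
```

With N = 100 each row has at most six nonzeros, yet the bandwidth is N² = 10000 on either side of the diagonal, so a band that wide covers essentially the whole 20000-column matrix.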
The Solving Large Stiff Equations tutorial goes through these details. This is simply to give a taste of how much optimization opportunity is left on the table!
Let’s see how our fast right-hand side scales as we increase the integration time.
prob = ODEProblem(fast_gm!, r0, (0.0, 10.0), p)
@btime solve(prob, Tsit5());
279.728 s (39314 allocations: 2.76 GiB)
using Sundials
@btime solve(prob, CVODE_BDF(linear_solver = :GMRES));
1.230 s (13791 allocations: 121.90 MiB)
prob = ODEProblem(fast_gm!, r0, (0.0, 100.0), p)
# Will go out of memory if we don't turn off `save_everystep`!
@btime solve(prob, Tsit5(), save_everystep = false);
7.607 s (68 allocations: 2.90 MiB)
@btime solve(prob, CVODE_BDF(linear_solver = :GMRES), save_everystep = false);
3.360 s (33712 allocations: 2.39 MiB)
prob = ODEProblem(fast_gm!, r0, (0.0, 500.0), p)
@btime solve(prob, CVODE_BDF(linear_solver = :GMRES), save_everystep = false);
6.055 s (55128 allocations: 3.33 MiB)
Notice that we’ve eliminated almost all allocations, allowing the code to grow without hitting garbage collection and slowing down.
Why is CVODE_BDF doing well? What’s happening is that, because the problem is stiff, the number of steps required by the explicit Runge-Kutta method grows rapidly, whereas CVODE_BDF is taking large steps. Additionally, the GMRES linear solver form is quite an efficient way to solve the implicit system in this case. This is problem-dependent, and in many cases using a Krylov method effectively requires a preconditioner, so you need to play around with testing other algorithms and linear solvers to find out what works best with your problem.
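A rough way to quantify that step growth is through the eigenvalues of the diffusion operator. As a sketch under stated assumptions (unit grid spacing, the boundary-modified matrices built earlier, and D2 = 100 from p), the 1-D operator Ay has eigenvalues in [-4, 0]; Ax is the transpose of Ay, so the eigenvalues of Ay*v + v*Ax are sums of two of these and reach down to -8, and after scaling by D2 to -800. The forward Euler bound h ≤ 2/|λ| is used only as a simple stand-in for an explicit method’s stability limit; it is not Tsit5’s exact stability region.

```julia
using LinearAlgebra

N = 100
A = Array(Tridiagonal([1.0 for i in 1:(N - 1)], [-2.0 for i in 1:N],
    [1.0 for i in 1:(N - 1)]))
Ay = copy(A)
Ay[1, 2] = 2.0            # boundary rows as in the tutorial
Ay[end, end - 1] = 2.0    # (Ax in the tutorial is the transpose of this Ay)

λ1d = real.(eigvals(Ay))          # the eigenvalues of this matrix are real
λmin_2d = 2 * minimum(λ1d)        # eigenvalues of Ay*v + v*Ax are sums λ_i + λ_j
D2 = 100.0
λstiff = D2 * λmin_2d
h_euler = 2 / abs(λstiff)         # forward Euler stability bound (illustrative)

println("1-D eigenvalue range: ", extrema(λ1d))
println("most negative diffusion eigenvalue for v: ", λstiff)
println("forward Euler step bound: ", h_euler)
println("steps needed to reach t = 500: ", ceil(Int, 500 / h_euler))
```

An explicit method has to respect a limit of this kind no matter how smooth the solution is, while BDF methods remain stable for negative real eigenvalues and can choose their step size based on accuracy alone.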
Now continue to the Solving Large Stiff Equations tutorial for more details on optimizing the algorithm choice for such codes.