Автоматизированный разреженный параллелизм функций Julia за счет трассировки

Поскольку выражения Symbolics.jl следуют семантике Julia, можно напрямую преобразовывать существующие функции Julia в символьные представления функций Symbolics.jl, просто вводя символьные значения в функцию и используя возвращаемый результат. Например, возьмем следующую численную дискретизацию PDE:

using Symbolics, LinearAlgebra, SparseArrays, Plots

# Определение констант для PDE
const α₂ = 1.0
const α₃ = 1.0
const β₁ = 1.0
const β₂ = 1.0
const β₃ = 1.0
const r₁ = 1.0
const r₂ = 1.0
const _DD = 100.0
const γ₁ = 0.1
const γ₂ = 0.1
const γ₃ = 0.1
const N = 32
const X = reshape([i for i in 1:N for j in 1:N], N, N)
const Y = reshape([j for i in 1:N for j in 1:N], N, N)
const α₁ = 1.0 .* (X .>= 4*N/5)

const Mx = Array(Tridiagonal([1.0 for i in 1:N-1], [-2.0 for i in 1:N], [1.0 for i in 1:N-1]))
const My = copy(Mx)
Mx[2, 1] = 2.0
Mx[end-1,end] = 2.0
My[1, 2] = 2.0
My[end,end-1] = 2.0

# Определение дискретизированного PDE как функции ODE
function f(u, p, t)
    A = u[:,:,1]
    B = u[:,:,2]
    C = u[:,:,3]
    MyA = My*A
    AMx = A*Mx
    DA = @. _DD*(MyA + AMx)
    dA = @. DA + α₁ - β₁*A - r₁*A*B + r₂*C
    dB = @. α₂ - β₂*B - r₁*A*B + r₂*C
    dC = @. α₃ - β₃*C + r₁*A*B - r₂*C
    cat(dA, dB, dC, dims=3)
end
f (generic function with 1 method)

Мы можем создать версию этой модели в Symbolics, выполнив трассировку функции модели:

# Определение начального условия как обычных массивов
@variables u[1:N, 1:N, 1:3]
du = simplify.(f(collect(u), nothing, 0.0))
vec(du)[1:10]
10-element Vector{Num}:
  -u[1, 1, 1] + 100.0(-4.0u[1, 1, 1] + 2.0(u[1, 2, 1] + u[2, 1, 1])) + u[1, 1, 3] - u[1, 1, 1]*u[1, 1, 2]
                100.0(u[1, 1, 1] - 4.0u[2, 1, 1] + 2.0u[2, 2, 1] + u[3, 1, 1]) - u[2, 1, 1] + u[2, 1, 3] - u[2, 1, 1]*u[2, 1, 2]
                100.0(u[2, 1, 1] - 4.0u[3, 1, 1] + 2.0u[3, 2, 1] + u[4, 1, 1]) - u[3, 1, 1] + u[3, 1, 3] - u[3, 1, 1]*u[3, 1, 2]
                100.0(u[3, 1, 1] - 4.0u[4, 1, 1] + 2.0u[4, 2, 1] + u[5, 1, 1]) - u[4, 1, 1] + u[4, 1, 3] - u[4, 1, 1]*u[4, 1, 2]
                100.0(u[4, 1, 1] - 4.0u[5, 1, 1] + 2.0u[5, 2, 1] + u[6, 1, 1]) - u[5, 1, 1] + u[5, 1, 3] - u[5, 1, 1]*u[5, 1, 2]
                100.0(u[5, 1, 1] - 4.0u[6, 1, 1] + 2.0u[6, 2, 1] + u[7, 1, 1]) - u[6, 1, 1] + u[6, 1, 3] - u[6, 1, 1]*u[6, 1, 2]
                100.0(u[6, 1, 1] - 4.0u[7, 1, 1] + 2.0u[7, 2, 1] + u[8, 1, 1]) - u[7, 1, 1] + u[7, 1, 3] - u[7, 1, 1]*u[7, 1, 2]
                100.0(u[7, 1, 1] - 4.0u[8, 1, 1] + 2.0u[8, 2, 1] + u[9, 1, 1]) - u[8, 1, 1] + u[8, 1, 3] - u[8, 1, 1]*u[8, 1, 2]
                100.0(u[10, 1, 1] + u[8, 1, 1] - 4.0u[9, 1, 1] + 2.0u[9, 2, 1]) - u[9, 1, 1] + u[9, 1, 3] - u[9, 1, 1]*u[9, 1, 2]
 -u[10, 1, 1] + 100.0(-4.0u[10, 1, 1] + 2.0u[10, 2, 1] + u[11, 1, 1] + u[9, 1, 1]) + u[10, 1, 3] - u[10, 1, 1]*u[10, 1, 2]

Вывод, здесь измененный на месте du, является символьным представлением каждого вывода функции. Затем его можно использовать в функционале Symbolics. Например, сначала построим параллельную версию f:

fastf = eval(Symbolics.build_function(du,u,
            parallel=Symbolics.MultithreadedForm())[2])
#13 (generic function with 1 method)

Теперь вычислим разреженную функцию Якоби и скомпилируем быструю многопоточную версию:

jac = Symbolics.sparsejacobian(vec(du), vec(u))
row,col,val = findnz(jac)
scatter(row,col,legend=false,ms=1,c=:black)

fjac = eval(Symbolics.build_function(jac,u,
            parallel=Symbolics.MultithreadedForm())[2])
#15 (generic function with 1 method)

Это займет некоторое время, но результаты того стоят. Теперь зададим параболическое PDE, которое будет решаться с помощью DifferentialEquations.jl. Мы зададим простую версию и разреженную многопоточную версию:

using OrdinaryDiffEq
u0 = zeros(N, N, 3)
MyA = zeros(N, N);
AMx = zeros(N, N);
DA = zeros(N, N);
prob = ODEProblem(f, u0, (0.0, 10.0))
fastprob = ODEProblem(ODEFunction((du, u, p, t) -> fastf(du, u),
                                   jac = (du, u, p, t) -> fjac(du, u),
                                   jac_prototype = similar(jac, Float64)),
                                   u0, (0.0, 10.0))
ODEProblem with uType Array{Float64, 3} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: 32×32×3 Array{Float64, 3}:
[:, :, 1] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱  ⋮                        ⋮
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

[:, :, 2] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱  ⋮                        ⋮
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

[:, :, 3] =
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                        ⋮              ⋱  ⋮                        ⋮
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

Посмотрим на разницу во времени:

using BenchmarkTools
#@btime solve(prob, TRBDF2()); # 33,073 с (895404 выделения: 23,87 ГиБ)
#Предупреждение: следующее решение компилируется долго, но после этого работает очень быстро.
#@btime solve(fastprob, TRBDF2()); # 209,670 мс (8208 выделений: 109,25 МиБ)

Бум! Автоматическое 157-кратное ускорение, которое растет по мере увеличения размера задачи.