Обзор
В этом разделе описывается текущее устройство «компилятора» моделей Turing, который позволяет Turing выполнять различные виды байесовского вывода, не изменяя определение модели. «Компилятор» — это, по сути, просто макрос, который преобразует определение модели пользователя в функцию, создающую структуру Model, с которой может работать механизм диспетчеризации Julia и для которой компилятор Julia может успешно выполнять вывод типов с целью эффективного генерирования машинного кода.
В этом разделе будет использоваться следующая терминология:
-
D: наблюдаемые переменные данных, зависящие от апостериорных значений; -
P: переменные-параметры, распределяемые в соответствии с априорным распределением; будут также называться случайными переменными; -
Model: полностью определенная вероятностная модель с входными данными.
Макрос @model в Turing преобразует предоставленное пользователем определение функции так, что его можно использовать для создания экземпляра Model путем передачи наблюдаемых данных D.
Макрос @model выполняет следующие основные задачи:
-
анализ строк
~и.~, напримерy .~ Normal.(c*x, 1.0); -
определение того, относится ли переменная к данным
Dи (или) к параметрамP; -
обеспечение обработки отсутствующих переменных данных в
Dпри определенииModelи их интерпретация как переменных-параметров вP; -
обеспечение отслеживания случайных переменных с помощью структур данных
VarNameиVarInfo; -
изменение строк
~и.~с переменной вPв LHS на вызовtilde_assumeилиdot_tilde_assume; -
изменение строк
~и.~с переменной вDв LHS на вызовtilde_observeилиdot_tilde_observe; -
обеспечение стабильного по типу автоматического дифференцирования модели с помощью параметров типов.
Модель
model::Model представляет собой вызываемую структуру, из которой можно осуществлять выборку путем вызова
(model::Model)([rng, varinfo, sampler, context])
где rng — генератор случайных чисел (по умолчанию Random.default_rng()), varinfo — структура данных, в которой хранится информация о случайных переменных (по умолчанию DynamicPPL.VarInfo()), sampler — алгоритм выборки (по умолчанию DynamicPPL.SampleFromPrior()), а context — контекст выборки, который может, например, изменять способ наращивания логарифмической вероятности (по умолчанию DynamicPPL.DefaultContext()).
При выборе сбрасывается логарифмическая совместная вероятность varinfo и увеличивается счетчик вычисления sampler. Если context имеет значение LikelihoodContext, наращивается только логарифмическая вероятность. При значении DefaultContext наращивается логарифмическая вероятность P и D.
Структура Model содержит три внутренних поля: f, args и defaults. При вызове model::Model вызывается внутренняя функция model.f в виде model.f(rng, varinfo, sampler, context, model.args...) (в случае многопоточной выборки вместо varinfo в model.f передается потокобезопасная оболочка). Позиционные и именованные аргументы, переданные в пользовательскую функцию модели при создании модели, сохраняются как NamedTuple в model.args. Значения по умолчанию позиционных и именованных аргументов пользовательских функций моделей при их наличии сохраняются как NamedTuple в model.defaults. Они служат для создания экземпляров модели с различными аргументами с помощью строковых макросов logprob и prob.
Пример
Возьмем для примера следующую модель:
@model function gauss(x = missing, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
if x === missing
x = TV(undef, 3)
end
p = TV(undef, 2)
p[1] ~ InverseGamma(2, 3)
p[2] ~ Normal(0, 1.0)
@. x[1:2] ~ Normal(p[2], sqrt(p[1]))
x[3] ~ Normal()
y ~ Normal(p[2], sqrt(p[1]))
end
Представленный выше вызов макроса @model определяет функцию gauss с позиционными аргументами x, y и ::Type{TV}, преобразованную таким образом, что каждый ее вызов возвращает model::Model. Обратите внимание, что макрос @model изменяет только тело функции, а ее сигнатура остается без изменений. Можно также реализовывать модели с именованными аргументами, например:
@model function gauss(::Type{TV} = Vector{Float64}; x = missing, y = 1.0) where {TV<:AbstractVector}
...
end
Это позволяет создавать модель путем вызова gauss(; x = rand(3)).
Если у аргумента значение по умолчанию missing, он интерпретируется как случайная переменная. Для переменных, требующих инициализации из-за необходимости циклического перебора элементов или трансляции, таких как x выше, следует сделать следующее.
if x === missing
x = ...
end
Обратите внимание: так как gauss работает как обычная функция, вторым шагом можно также определить дополнительные операции диспетчеризации. Например, такое же поведение можно обеспечить следующим образом.
@model function gauss(x, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
p = TV(undef, 2)
...
end
function gauss(::Missing, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
return gauss(TV(undef, 3), y, TV)
end
Если переменная x выбирается целиком из распределения и не индексируется, например x ~ Normal(...) или x ~ MvNormal(...), инициализировать ее в блоке if не требуется.
Шаг 1. Разбиение определения модели
Сначала макрос @model разбивает пользовательское определение модели на части с помощью DynamicPPL.build_model_info. Эта функция возвращает словарь, состоящий из следующих элементов.
-
allargs_exprs: выражения позиционных и именованных аргументов без значений по умолчанию. -
allargs_syms: имена позиционных и именованных аргументов, например[:x, :y, :TV]выше. -
allargs_namedtuple: выражение, которое создает кортежNamedTupleпозиционных и именованных аргументов, например:x = x, y = y, TV = TVвыше. -
defaults_namedtuple: выражение, которое создает кортежNamedTupleпозиционных и именованных аргументов по умолчанию при их наличии, например:((x = missing, y = 1, TV = Vector{Float64}))выше. -
modeldef: словарь, возвращаемыйMacroTools.splitdefи содержащий имя, аргументы и тело функции из определения модели.
Шаг 2. Создание тела внутренней функции модели
На втором этапе DynamicPPL.generate_mainbody генерирует основную часть тела преобразуемой функции на основе предоставленного пользователем тела функции и ее аргументов без значений по умолчанию; при этом определяется, соответствует ли переменная наблюдению или случайной переменной. Для этого функция DynamicPPL.generate_tilde заменяет строки L ~ R в модели, а функция DynamicPPL.generate_dot_tilde — строки @. L ~ R и L .~ R.
Строка p[1] ~ InverseGamma(2, 3) из приведенного выше примера заменяется кодом наподобие следующего:
#= REPL[25]:6 =#
begin
var"##tmpright#323" = InverseGamma(2, 3)
var"##tmpright#323" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
var"##vn#325" = (DynamicPPL.VarName)(:p, ((1,),))
var"##inds#326" = ((1,),)
p[1] = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#323", var"##vn#325", var"##inds#326", _varinfo)
end
Здесь первая строка — это так называемый узел номера строки, который позволяет выдавать более информативные сообщения об ошибках, в которых точно указывается место в определении модели, где возникла проблема. Затем правая часть (RHS) оператора ~ присваивается переменной (с автоматически сгенерированным именем). При этом проверяется, содержит ли правая часть распределение или массив распределений. Если это не так, выдается ошибка. Далее извлекается краткое представление переменной с ее именем и индексом (или индексами). Наконец, выражение ~ заменяется вызовом DynamicPPL.tilde_assume, так как компилятор с помощью следующего эвристического механизма определил, что p[1] — это случайная переменная:
-
Если символа в левой части оператора
~(в данном случае:p) нет в числе аргументов модели (в данном случае(:x, :y, :T)), это случайная переменная. -
Если символ в левой части оператора
~(в данном случае:p) есть в числе аргументов модели, но имеет значениеmissing, это случайная переменная. -
Если значением в левой части оператора
~(в данном случаеp[1]) являетсяmissing, это случайная переменная. -
В противном случае переменная интерпретируется как наблюдение.
Функция DynamicPPL.tilde_assume при необходимости сама выполняет выборку случайной переменной и обновляет ее значение и накопленную логарифмическую совместную вероятность в объекте _varinfo. Если L ~ R — это наблюдение, DynamicPPL.tilde_observe вызывается с теми же аргументами, кроме генератора случайных чисел _rng (так как выборка наблюдений не производится).
Аналогичное преобразование осуществляется для выражений вида @. L ~ R и L .~ R. Например, @. x[1:2] ~ Normal(p[2], sqrt(p[1])) заменяется на
#= REPL[25]:8 =#
begin
var"##tmpright#331" = Normal.(p[2], sqrt.(p[1]))
var"##tmpright#331" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
var"##vn#333" = (DynamicPPL.VarName)(:x, ((1:2,),))
var"##inds#334" = ((1:2,),)
var"##isassumption#335" = begin
let var"##vn#336" = (DynamicPPL.VarName)(:x, ((1:2,),))
if !((DynamicPPL.inargnames)(var"##vn#336", _model)) || (DynamicPPL.inmissings)(var"##vn#336", _model)
true
else
x[1:2] === missing
end
end
end
if var"##isassumption#335"
x[1:2] .= (DynamicPPL.dot_tilde_assume)(_rng, _context, _sampler, var"##tmpright#331", x[1:2], var"##vn#333", var"##inds#334", _varinfo)
else
(DynamicPPL.dot_tilde_observe)(_context, _sampler, var"##tmpright#331", x[1:2], var"##vn#333", var"##inds#334", _varinfo)
end
end
Главным различием развернутого кода для L ~ R и @. L ~ R является то, что в первом случае не предполагается наличие определения L: это может быть новая переменная Julia в данной области. Во втором же случае предполагается, что L уже существует. Кроме того, вместо DynamicPPL.tilde_assume и DynamicPPL.tilde_observe вызываются функции DynamicPPL.dot_tilde_assume и DynamicPPL.dot_tilde_observe.
Шаг 3. Замена предоставленного пользователем тела функции
Наконец, предоставленное пользователем тело функции заменяется с помощью DynamicPPL.build_output. Эта функция использует MacroTools.combinedef для повторной сборки предоставленной пользователем функции с новым телом. В измененном теле функции создается анонимная функция, тело которой было сгенерировано на шаге 2 выше и которая имеет следующие аргументы:
-
генератор случайных чисел
_rng; -
модель
_model; -
структура данных
_varinfo; -
сэмплер
_sampler; -
контекст выборки
_context; -
все позиционные и именованные аргументы предоставленной пользователем функции модели в виде позиционных аргументов без значений по умолчанию. Наконец, в новом теле функции возвращается
model::Modelс этой анонимной функцией в качестве внутренней.
VarName
Для отслеживания случайных переменных в процессе выборки в Turing применяется структура VarName, которая выступает в роли идентификатора случайной переменной, генерируемого во время выполнения. VarName случайной переменной генерируется на основе выражения в левой части оператора ~, когда символ в левой части входит в набор P ненаблюдаемых случайных переменных. У каждого экземпляра VarName есть параметр типа sym. Это символ переменной Julia в модели, к которой относится случайная переменная. Например, x[1] ~ Normal() создает экземпляр VarName{:x} при условии, что x — ненаблюдаемая случайная переменная. У каждого VarName также есть поле indexing, в котором хранятся индексы, необходимые для доступа к случайной переменной из переменной Julia. Они указываются в sym в виде кортежа кортежей. Каждый элемент кортежа содержит индексы для одной операции индексирования (VarName также поддерживает иерархические массивы и индексирование диапазонов). Некоторые примеры:
-
x ~ Normal()генерируетVarName(:x, ()). -
x[1] ~ Normal()генерируетVarName(:x, 1,),. -
x[:,1] ~ MvNormal(zeros(2), I)генерируетVarName(:x, Colon(), 1),. -
x[:,1][1+1] ~ Normal()генерируетVarName(:x, Colon(), 1), (2,).
Самый простой способ вручную создать VarName — воспользоваться в выражении индексирования макросом @varname, который берет значение sym из фактического имени переменной и надлежащим образом помещает значения индексов в конструктор.
VarInfo
Обзор
VarInfo — это структура данных в Turing, которая упрощает отслеживание случайных переменных и некоторых их метаданных, необходимых для выборки. Например, в VarInfo хранится распределение каждой случайной переменной, так как, например, при выборке с использованием HMC необходимо знать, поддерживается ли оно. Случайные переменные, распределения которых имеют ограниченную поддержку, преобразуются с помощью биектора из пакета Bijectors.jl так, что выборка происходит в неограниченном пространстве. Разным сэмплерам требуются разные метаданные случайных переменных.
VarInfo имеет в Turing следующее определение:
struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo
metadata::Tmeta
logp::Base.RefValue{Tlogp}
num_produce::Base.RefValue{Int}
end
В зависимости от типа metadata объект VarInfo имеет псевдоним UntypedVarInfo или TypedVarInfo. metadata может быть либо подтипом типа объединения Metadata, либо кортежем NamedTuple нескольких таких подтипов. Допустим, vi — это экземпляр VarInfo. Если vi isa VarInfo{<:Metadata}, он будет называться UntypedVarInfo. Если vi isa VarInfo{<:NamedTuple}, то vi.metadata будет кортежем NamedTuple, в котором каждый символ из P сопоставляется с экземпляром Metadata. В таком случае vi будет называться TypedVarInfo. В число других полей VarInfo входит поле logp, которое служит для накопления логарифмической вероятности или логарифмической плотности вероятности переменных в P и D. num_produce отслеживает то, сколько наблюдений уже было сделано в модели. Значение увеличивается при выполнении оператора ~, если символ в левой части входит в D.
Metadata
В структуре Metadata хранятся некоторые метаданные случайных переменных, выборка которых производится. Это помогает запрашивать определенную информацию о переменной, а именно ее распределение, используемые для выборки сэмплеры, ее значение и то, преобразуется ли это значение в вещественное пространство. Допустим, md — это экземпляр Metadata:
-
md.vns— это вектор всех экземпляровVarName. Допустим,vn— это произвольный элементmd.vns. -
md.idcs— это словарь, в котором каждый экземплярVarNameсопоставляется с индексом в
md.vns, md.ranges, md.dists, md.orders и md.flags.
-
md.vns[md.idcs[vn]] == vn. -
md.dists[md.idcs[vn]]— это распределениеvn. -
md.gids[md.idcs[vn]]— это набор алгоритмов, используемых для выборкиvn. Применяется в
процессе выборки Гиббса.
-
md.orders[md.idcs[vn]]— это количество операторовobserveдо выборкиvn. -
md.ranges[md.idcs[vn]]— это диапазон индексовvnвmd.vals. -
md.vals[md.ranges[md.idcs[vn]]]— это линеаризованный вектор значений, соответствующихvn. -
md.flags— это словарь флагов true/false.md.flags[flag][md.idcs[vn]]— это
значение flag, соответствующее vn.
Обратите внимание: для обеспечения стабильности md::Metadata по типу все md.vns должны иметь один и тот же символ и тип распределения. Однако отдельная переменная Julia, например x, может представлять собой матрицу или иерархический массив, выборка которых производится по частям, например x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I). Символ x может по-прежнему управляться одним объектом md::Metadata без ущерба для стабильности по типу, так как все распределения в правой части оператора ~ одного и того же типа.
Однако в моделях Turing такого ограничения быть не может, поэтому необходимо использовать нестабильный тип Metadata, если один экземпляр Metadata должен применяться для всей модели. Для этого служит UntypedVarInfo. Нестабильный тип Metadata по-прежнему будет работать, но с более низкой производительностью.
Чтобы найти баланс между гибкостью и производительностью при создании экземпляра spl::Sampler, модель сначала запускается с выборкой параметров в P из априорных распределений с помощью UntypedVarInfo, то есть для всех переменных используется нестабильный по типу объект Metadata. После определения всех символов и типов распределений создается vi::TypedVarInfo, где vi.metadata — это кортеж NamedTuple, в котором каждый символ из P сопоставляется со специальным экземпляром Metadata. Поэтому при условии, что выборка для каждого символа в P производится только из одного типа распределений, vi::TypedVarInfo будет иметь полностью конкретно типизированные поля, благодаря чему в Julia достигается максимальная производительность.