Обзор
В этом разделе описывается текущее устройство «компилятора» моделей Turing, который позволяет Turing выполнять различные виды байесовского вывода, не изменяя определение модели. «Компилятор» — это, по сути, просто макрос, который преобразует определение модели пользователя в функцию, создающую структуру Model
, с которой может работать механизм диспетчеризации Julia и для которой компилятор Julia может успешно выполнять вывод типов с целью эффективного генерирования машинного кода.
В этом разделе будет использоваться следующая терминология:
-
D
: наблюдаемые переменные данных, зависящие от апостериорных значений; -
P
: переменные-параметры, распределяемые в соответствии с априорным распределением; будут также называться случайными переменными; -
Model
: полностью определенная вероятностная модель с входными данными.
Макрос @model
в Turing
преобразует предоставленное пользователем определение функции так, что его можно использовать для создания экземпляра Model
путем передачи наблюдаемых данных D
.
Макрос @model
выполняет следующие основные задачи:
-
анализ строк
~
и.~
, напримерy .~ Normal.(c*x, 1.0)
; -
определение того, относится ли переменная к данным
D
и (или) к параметрамP
; -
обеспечение обработки отсутствующих переменных данных в
D
при определенииModel
и их интерпретация как переменных-параметров вP
; -
обеспечение отслеживания случайных переменных с помощью структур данных
VarName
иVarInfo
; -
изменение строк
~
и.~
с переменной вP
в LHS на вызовtilde_assume
илиdot_tilde_assume
; -
изменение строк
~
и.~
с переменной вD
в LHS на вызовtilde_observe
илиdot_tilde_observe
; -
обеспечение стабильного по типу автоматического дифференцирования модели с помощью параметров типов.
Модель
model::Model
представляет собой вызываемую структуру, из которой можно осуществлять выборку путем вызова
(model::Model)([rng, varinfo, sampler, context])
где rng
— генератор случайных чисел (по умолчанию Random.default_rng()
), varinfo
— структура данных, в которой хранится информация о случайных переменных (по умолчанию DynamicPPL.VarInfo()
), sampler
— алгоритм выборки (по умолчанию DynamicPPL.SampleFromPrior()
), а context
— контекст выборки, который может, например, изменять способ наращивания логарифмической вероятности (по умолчанию DynamicPPL.DefaultContext()
).
При выборе сбрасывается логарифмическая совместная вероятность varinfo
и увеличивается счетчик вычисления sampler
. Если context
имеет значение LikelihoodContext
, наращивается только логарифмическая вероятность. При значении DefaultContext
наращивается логарифмическая вероятность P
и D
.
Структура Model
содержит три внутренних поля: f
, args
и defaults
. При вызове model::Model
вызывается внутренняя функция model.f
в виде model.f(rng, varinfo, sampler, context, model.args...)
(в случае многопоточной выборки вместо varinfo
в model.f
передается потокобезопасная оболочка). Позиционные и именованные аргументы, переданные в пользовательскую функцию модели при создании модели, сохраняются как NamedTuple
в model.args
. Значения по умолчанию позиционных и именованных аргументов пользовательских функций моделей при их наличии сохраняются как NamedTuple
в model.defaults
. Они служат для создания экземпляров модели с различными аргументами с помощью строковых макросов logprob
и prob
.
Пример
Возьмем для примера следующую модель:
@model function gauss(x = missing, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
if x === missing
x = TV(undef, 3)
end
p = TV(undef, 2)
p[1] ~ InverseGamma(2, 3)
p[2] ~ Normal(0, 1.0)
@. x[1:2] ~ Normal(p[2], sqrt(p[1]))
x[3] ~ Normal()
y ~ Normal(p[2], sqrt(p[1]))
end
Представленный выше вызов макроса @model
определяет функцию gauss
с позиционными аргументами x
, y
и ::Type{TV}
, преобразованную таким образом, что каждый ее вызов возвращает model::Model
. Обратите внимание, что макрос @model
изменяет только тело функции, а ее сигнатура остается без изменений. Можно также реализовывать модели с именованными аргументами, например:
@model function gauss(::Type{TV} = Vector{Float64}; x = missing, y = 1.0) where {TV<:AbstractVector}
...
end
Это позволяет создавать модель путем вызова gauss(; x = rand(3))
.
Если у аргумента значение по умолчанию missing
, он интерпретируется как случайная переменная. Для переменных, требующих инициализации из-за необходимости циклического перебора элементов или трансляции, таких как x
выше, следует сделать следующее.
if x === missing
x = ...
end
Обратите внимание: так как gauss
работает как обычная функция, вторым шагом можно также определить дополнительные операции диспетчеризации. Например, такое же поведение можно обеспечить следующим образом.
@model function gauss(x, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
p = TV(undef, 2)
...
end
function gauss(::Missing, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
return gauss(TV(undef, 3), y, TV)
end
Если переменная x
выбирается целиком из распределения и не индексируется, например x ~ Normal(...)
или x ~ MvNormal(...)
, инициализировать ее в блоке if
не требуется.
Шаг 1. Разбиение определения модели
Сначала макрос @model
разбивает пользовательское определение модели на части с помощью DynamicPPL.build_model_info
. Эта функция возвращает словарь, состоящий из следующих элементов.
-
allargs_exprs
: выражения позиционных и именованных аргументов без значений по умолчанию. -
allargs_syms
: имена позиционных и именованных аргументов, например[:x, :y, :TV]
выше. -
allargs_namedtuple
: выражение, которое создает кортежNamedTuple
позиционных и именованных аргументов, например:x = x, y = y, TV = TV
выше. -
defaults_namedtuple
: выражение, которое создает кортежNamedTuple
позиционных и именованных аргументов по умолчанию при их наличии, например:((x = missing, y = 1, TV = Vector{Float64}))
выше. -
modeldef
: словарь, возвращаемыйMacroTools.splitdef
и содержащий имя, аргументы и тело функции из определения модели.
Шаг 2. Создание тела внутренней функции модели
На втором этапе DynamicPPL.generate_mainbody
генерирует основную часть тела преобразуемой функции на основе предоставленного пользователем тела функции и ее аргументов без значений по умолчанию; при этом определяется, соответствует ли переменная наблюдению или случайной переменной. Для этого функция DynamicPPL.generate_tilde
заменяет строки L ~ R
в модели, а функция DynamicPPL.generate_dot_tilde
— строки @. L ~ R
и L .~ R
.
Строка p[1] ~ InverseGamma(2, 3)
из приведенного выше примера заменяется кодом наподобие следующего:
#= REPL[25]:6 =#
begin
var"##tmpright#323" = InverseGamma(2, 3)
var"##tmpright#323" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
var"##vn#325" = (DynamicPPL.VarName)(:p, ((1,),))
var"##inds#326" = ((1,),)
p[1] = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#323", var"##vn#325", var"##inds#326", _varinfo)
end
Здесь первая строка — это так называемый узел номера строки, который позволяет выдавать более информативные сообщения об ошибках, в которых точно указывается место в определении модели, где возникла проблема. Затем правая часть (RHS) оператора ~
присваивается переменной (с автоматически сгенерированным именем). При этом проверяется, содержит ли правая часть распределение или массив распределений. Если это не так, выдается ошибка. Далее извлекается краткое представление переменной с ее именем и индексом (или индексами). Наконец, выражение ~
заменяется вызовом DynamicPPL.tilde_assume
, так как компилятор с помощью следующего эвристического механизма определил, что p[1]
— это случайная переменная:
-
Если символа в левой части оператора
~
(в данном случае:p
) нет в числе аргументов модели (в данном случае(:x, :y, :T)
), это случайная переменная. -
Если символ в левой части оператора
~
(в данном случае:p
) есть в числе аргументов модели, но имеет значениеmissing
, это случайная переменная. -
Если значением в левой части оператора
~
(в данном случаеp[1]
) являетсяmissing
, это случайная переменная. -
В противном случае переменная интерпретируется как наблюдение.
Функция DynamicPPL.tilde_assume
при необходимости сама выполняет выборку случайной переменной и обновляет ее значение и накопленную логарифмическую совместную вероятность в объекте _varinfo
. Если L ~ R
— это наблюдение, DynamicPPL.tilde_observe
вызывается с теми же аргументами, кроме генератора случайных чисел _rng
(так как выборка наблюдений не производится).
Аналогичное преобразование осуществляется для выражений вида @. L ~ R
и L .~ R
. Например, @. x[1:2] ~ Normal(p[2], sqrt(p[1]))
заменяется на
#= REPL[25]:8 =#
begin
var"##tmpright#331" = Normal.(p[2], sqrt.(p[1]))
var"##tmpright#331" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
var"##vn#333" = (DynamicPPL.VarName)(:x, ((1:2,),))
var"##inds#334" = ((1:2,),)
var"##isassumption#335" = begin
let var"##vn#336" = (DynamicPPL.VarName)(:x, ((1:2,),))
if !((DynamicPPL.inargnames)(var"##vn#336", _model)) || (DynamicPPL.inmissings)(var"##vn#336", _model)
true
else
x[1:2] === missing
end
end
end
if var"##isassumption#335"
x[1:2] .= (DynamicPPL.dot_tilde_assume)(_rng, _context, _sampler, var"##tmpright#331", x[1:2], var"##vn#333", var"##inds#334", _varinfo)
else
(DynamicPPL.dot_tilde_observe)(_context, _sampler, var"##tmpright#331", x[1:2], var"##vn#333", var"##inds#334", _varinfo)
end
end
Главным различием развернутого кода для L ~ R
и @. L ~ R
является то, что в первом случае не предполагается наличие определения L
: это может быть новая переменная Julia в данной области. Во втором же случае предполагается, что L
уже существует. Кроме того, вместо DynamicPPL.tilde_assume
и DynamicPPL.tilde_observe
вызываются функции DynamicPPL.dot_tilde_assume
и DynamicPPL.dot_tilde_observe
.
Шаг 3. Замена предоставленного пользователем тела функции
Наконец, предоставленное пользователем тело функции заменяется с помощью DynamicPPL.build_output
. Эта функция использует MacroTools.combinedef
для повторной сборки предоставленной пользователем функции с новым телом. В измененном теле функции создается анонимная функция, тело которой было сгенерировано на шаге 2 выше и которая имеет следующие аргументы:
-
генератор случайных чисел
_rng
; -
модель
_model
; -
структура данных
_varinfo
; -
сэмплер
_sampler
; -
контекст выборки
_context
; -
все позиционные и именованные аргументы предоставленной пользователем функции модели в виде позиционных аргументов без значений по умолчанию. Наконец, в новом теле функции возвращается
model::Model
с этой анонимной функцией в качестве внутренней.
VarName
Для отслеживания случайных переменных в процессе выборки в Turing
применяется структура VarName
, которая выступает в роли идентификатора случайной переменной, генерируемого во время выполнения. VarName
случайной переменной генерируется на основе выражения в левой части оператора ~
, когда символ в левой части входит в набор P
ненаблюдаемых случайных переменных. У каждого экземпляра VarName
есть параметр типа sym
. Это символ переменной Julia в модели, к которой относится случайная переменная. Например, x[1] ~ Normal()
создает экземпляр VarName{:x}
при условии, что x
— ненаблюдаемая случайная переменная. У каждого VarName
также есть поле indexing
, в котором хранятся индексы, необходимые для доступа к случайной переменной из переменной Julia. Они указываются в sym
в виде кортежа кортежей. Каждый элемент кортежа содержит индексы для одной операции индексирования (VarName
также поддерживает иерархические массивы и индексирование диапазонов). Некоторые примеры:
-
x ~ Normal()
генерируетVarName(:x, ())
. -
x[1] ~ Normal()
генерируетVarName(:x, 1,),
. -
x[:,1] ~ MvNormal(zeros(2), I)
генерируетVarName(:x, Colon(), 1),
. -
x[:,1][1+1] ~ Normal()
генерируетVarName(:x, Colon(), 1), (2,)
.
Самый простой способ вручную создать VarName
— воспользоваться в выражении индексирования макросом @varname
, который берет значение sym
из фактического имени переменной и надлежащим образом помещает значения индексов в конструктор.
VarInfo
Обзор
VarInfo
— это структура данных в Turing
, которая упрощает отслеживание случайных переменных и некоторых их метаданных, необходимых для выборки. Например, в VarInfo
хранится распределение каждой случайной переменной, так как, например, при выборке с использованием HMC необходимо знать, поддерживается ли оно. Случайные переменные, распределения которых имеют ограниченную поддержку, преобразуются с помощью биектора из пакета Bijectors.jl так, что выборка происходит в неограниченном пространстве. Разным сэмплерам требуются разные метаданные случайных переменных.
VarInfo
имеет в Turing
следующее определение:
struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo metadata::Tmeta logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end
В зависимости от типа metadata
объект VarInfo
имеет псевдоним UntypedVarInfo
или TypedVarInfo
. metadata
может быть либо подтипом типа объединения Metadata
, либо кортежем NamedTuple
нескольких таких подтипов. Допустим, vi
— это экземпляр VarInfo
. Если vi isa VarInfo{<:Metadata}
, он будет называться UntypedVarInfo
. Если vi isa VarInfo{<:NamedTuple}
, то vi.metadata
будет кортежем NamedTuple
, в котором каждый символ из P
сопоставляется с экземпляром Metadata
. В таком случае vi
будет называться TypedVarInfo
. В число других полей VarInfo
входит поле logp
, которое служит для накопления логарифмической вероятности или логарифмической плотности вероятности переменных в P
и D
. num_produce
отслеживает то, сколько наблюдений уже было сделано в модели. Значение увеличивается при выполнении оператора ~
, если символ в левой части входит в D
.
Metadata
В структуре Metadata
хранятся некоторые метаданные случайных переменных, выборка которых производится. Это помогает запрашивать определенную информацию о переменной, а именно ее распределение, используемые для выборки сэмплеры, ее значение и то, преобразуется ли это значение в вещественное пространство. Допустим, md
— это экземпляр Metadata
:
-
md.vns
— это вектор всех экземпляровVarName
. Допустим,vn
— это произвольный элементmd.vns
. -
md.idcs
— это словарь, в котором каждый экземплярVarName
сопоставляется с индексом в
md.vns
, md.ranges
, md.dists
, md.orders
и md.flags
.
-
md.vns[md.idcs[vn]] == vn
. -
md.dists[md.idcs[vn]]
— это распределениеvn
. -
md.gids[md.idcs[vn]]
— это набор алгоритмов, используемых для выборкиvn
. Применяется в
процессе выборки Гиббса.
-
md.orders[md.idcs[vn]]
— это количество операторовobserve
до выборкиvn
. -
md.ranges[md.idcs[vn]]
— это диапазон индексовvn
вmd.vals
. -
md.vals[md.ranges[md.idcs[vn]]]
— это линеаризованный вектор значений, соответствующихvn
. -
md.flags
— это словарь флагов true/false.md.flags[flag][md.idcs[vn]]
— это
значение flag
, соответствующее vn
.
Обратите внимание: для обеспечения стабильности md::Metadata
по типу все md.vns
должны иметь один и тот же символ и тип распределения. Однако отдельная переменная Julia, например x
, может представлять собой матрицу или иерархический массив, выборка которых производится по частям, например x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)
. Символ x
может по-прежнему управляться одним объектом md::Metadata
без ущерба для стабильности по типу, так как все распределения в правой части оператора ~
одного и того же типа.
Однако в моделях Turing
такого ограничения быть не может, поэтому необходимо использовать нестабильный тип Metadata
, если один экземпляр Metadata
должен применяться для всей модели. Для этого служит UntypedVarInfo
. Нестабильный тип Metadata
по-прежнему будет работать, но с более низкой производительностью.
Чтобы найти баланс между гибкостью и производительностью при создании экземпляра spl::Sampler
, модель сначала запускается с выборкой параметров в P
из априорных распределений с помощью UntypedVarInfo
, то есть для всех переменных используется нестабильный по типу объект Metadata
. После определения всех символов и типов распределений создается vi::TypedVarInfo
, где vi.metadata
— это кортеж NamedTuple
, в котором каждый символ из P
сопоставляется со специальным экземпляром Metadata
. Поэтому при условии, что выборка для каждого символа в P
производится только из одного типа распределений, vi::TypedVarInfo
будет иметь полностью конкретно типизированные поля, благодаря чему в Julia достигается максимальная производительность.