Документация Engee

Обзор

В этом разделе описывается текущее устройство «компилятора» моделей Turing, который позволяет Turing выполнять различные виды байесовского вывода, не изменяя определение модели. «Компилятор» — это, по сути, просто макрос, который преобразует определение модели пользователя в функцию, создающую структуру Model, с которой может работать механизм диспетчеризации Julia и для которой компилятор Julia может успешно выполнять вывод типов с целью эффективного генерирования машинного кода.

В этом разделе будет использоваться следующая терминология:

  • D: наблюдаемые переменные данных, зависящие от апостериорных значений;

  • P: переменные-параметры, распределяемые в соответствии с априорным распределением; будут также называться случайными переменными;

  • Model: полностью определенная вероятностная модель с входными данными.

Макрос @model в Turing преобразует предоставленное пользователем определение функции так, что его можно использовать для создания экземпляра Model путем передачи наблюдаемых данных D.

Макрос @model выполняет следующие основные задачи:

  1. анализ строк ~ и .~, например y .~ Normal.(c*x, 1.0);

  2. определение того, относится ли переменная к данным D и (или) к параметрам P;

  3. обеспечение обработки отсутствующих переменных данных в D при определении Model и их интерпретация как переменных-параметров в P;

  4. обеспечение отслеживания случайных переменных с помощью структур данных VarName и VarInfo;

  5. изменение строк ~ и .~ с переменной в P в LHS на вызов tilde_assume или dot_tilde_assume;

  6. изменение строк ~ и .~ с переменной в D в LHS на вызов tilde_observe или dot_tilde_observe;

  7. обеспечение стабильного по типу автоматического дифференцирования модели с помощью параметров типов.

Модель

model::Model представляет собой вызываемую структуру, из которой можно осуществлять выборку путем вызова

(model::Model)([rng, varinfo, sampler, context])

где rng — генератор случайных чисел (по умолчанию Random.default_rng()), varinfo — структура данных, в которой хранится информация о случайных переменных (по умолчанию DynamicPPL.VarInfo()), sampler — алгоритм выборки (по умолчанию DynamicPPL.SampleFromPrior()), а context — контекст выборки, который может, например, изменять способ наращивания логарифмической вероятности (по умолчанию DynamicPPL.DefaultContext()).

При выборе сбрасывается логарифмическая совместная вероятность varinfo и увеличивается счетчик вычисления sampler. Если context имеет значение LikelihoodContext, наращивается только логарифмическая вероятность. При значении DefaultContext наращивается логарифмическая вероятность P и D.

Структура Model содержит три внутренних поля: f, args и defaults. При вызове model::Model вызывается внутренняя функция model.f в виде model.f(rng, varinfo, sampler, context, model.args...) (в случае многопоточной выборки вместо varinfo в model.f передается потокобезопасная оболочка). Позиционные и именованные аргументы, переданные в пользовательскую функцию модели при создании модели, сохраняются как NamedTuple в model.args. Значения по умолчанию позиционных и именованных аргументов пользовательских функций моделей при их наличии сохраняются как NamedTuple в model.defaults. Они служат для создания экземпляров модели с различными аргументами с помощью строковых макросов logprob и prob.

Пример

Возьмем для примера следующую модель:

@model function gauss(x = missing, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
    if x === missing
        x = TV(undef, 3)
    end
    p = TV(undef, 2)
    p[1] ~ InverseGamma(2, 3)
    p[2] ~ Normal(0, 1.0)
    @. x[1:2] ~ Normal(p[2], sqrt(p[1]))
    x[3] ~ Normal()
    y ~ Normal(p[2], sqrt(p[1]))
end

Представленный выше вызов макроса @model определяет функцию gauss с позиционными аргументами x, y и ::Type{TV}, преобразованную таким образом, что каждый ее вызов возвращает model::Model. Обратите внимание, что макрос @model изменяет только тело функции, а ее сигнатура остается без изменений. Можно также реализовывать модели с именованными аргументами, например:

@model function gauss(::Type{TV} = Vector{Float64}; x = missing, y = 1.0) where {TV<:AbstractVector}
    ...
end

Это позволяет создавать модель путем вызова gauss(; x = rand(3)).

Если у аргумента значение по умолчанию missing, он интерпретируется как случайная переменная. Для переменных, требующих инициализации из-за необходимости циклического перебора элементов или трансляции, таких как x выше, следует сделать следующее.

if x === missing
    x = ...
end

Обратите внимание: так как gauss работает как обычная функция, вторым шагом можно также определить дополнительные операции диспетчеризации. Например, такое же поведение можно обеспечить следующим образом.

@model function gauss(x, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
    p = TV(undef, 2)
    ...
end

function gauss(::Missing, y = 1.0, ::Type{TV} = Vector{Float64}) where {TV<:AbstractVector}
    return gauss(TV(undef, 3), y, TV)
end

Если переменная x выбирается целиком из распределения и не индексируется, например x ~ Normal(...) или x ~ MvNormal(...), инициализировать ее в блоке if не требуется.

Шаг 1. Разбиение определения модели

Сначала макрос @model разбивает пользовательское определение модели на части с помощью DynamicPPL.build_model_info. Эта функция возвращает словарь, состоящий из следующих элементов.

  • allargs_exprs: выражения позиционных и именованных аргументов без значений по умолчанию.

  • allargs_syms: имена позиционных и именованных аргументов, например [:x, :y, :TV] выше.

  • allargs_namedtuple: выражение, которое создает кортеж NamedTuple позиционных и именованных аргументов, например :x = x, y = y, TV = TV выше.

  • defaults_namedtuple: выражение, которое создает кортеж NamedTuple позиционных и именованных аргументов по умолчанию при их наличии, например :((x = missing, y = 1, TV = Vector{Float64})) выше.

  • modeldef: словарь, возвращаемый MacroTools.splitdef и содержащий имя, аргументы и тело функции из определения модели.

Шаг 2. Создание тела внутренней функции модели

На втором этапе DynamicPPL.generate_mainbody генерирует основную часть тела преобразуемой функции на основе предоставленного пользователем тела функции и ее аргументов без значений по умолчанию; при этом определяется, соответствует ли переменная наблюдению или случайной переменной. Для этого функция DynamicPPL.generate_tilde заменяет строки L ~ R в модели, а функция DynamicPPL.generate_dot_tilde — строки @. L ~ R и L .~ R.

Строка p[1] ~ InverseGamma(2, 3) из приведенного выше примера заменяется кодом наподобие следующего:

#= REPL[25]:6 =#
begin
    var"##tmpright#323" = InverseGamma(2, 3)
    var"##tmpright#323" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
    var"##vn#325" = (DynamicPPL.VarName)(:p, ((1,),))
    var"##inds#326" = ((1,),)
    p[1] = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#323", var"##vn#325", var"##inds#326", _varinfo)
end

Здесь первая строка — это так называемый узел номера строки, который позволяет выдавать более информативные сообщения об ошибках, в которых точно указывается место в определении модели, где возникла проблема. Затем правая часть (RHS) оператора ~ присваивается переменной (с автоматически сгенерированным именем). При этом проверяется, содержит ли правая часть распределение или массив распределений. Если это не так, выдается ошибка. Далее извлекается краткое представление переменной с ее именем и индексом (или индексами). Наконец, выражение ~ заменяется вызовом DynamicPPL.tilde_assume, так как компилятор с помощью следующего эвристического механизма определил, что p[1] — это случайная переменная:

  1. Если символа в левой части оператора ~ (в данном случае :p) нет в числе аргументов модели (в данном случае (:x, :y, :T)), это случайная переменная.

  2. Если символ в левой части оператора ~ (в данном случае :p) есть в числе аргументов модели, но имеет значение missing, это случайная переменная.

  3. Если значением в левой части оператора ~ (в данном случае p[1]) является missing, это случайная переменная.

  4. В противном случае переменная интерпретируется как наблюдение.

Функция DynamicPPL.tilde_assume при необходимости сама выполняет выборку случайной переменной и обновляет ее значение и накопленную логарифмическую совместную вероятность в объекте _varinfo. Если L ~ R — это наблюдение, DynamicPPL.tilde_observe вызывается с теми же аргументами, кроме генератора случайных чисел _rng (так как выборка наблюдений не производится).

Аналогичное преобразование осуществляется для выражений вида @. L ~ R и L .~ R. Например, @. x[1:2] ~ Normal(p[2], sqrt(p[1])) заменяется на

#= REPL[25]:8 =#
begin
    var"##tmpright#331" = Normal.(p[2], sqrt.(p[1]))
    var"##tmpright#331" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
    var"##vn#333" = (DynamicPPL.VarName)(:x, ((1:2,),))
    var"##inds#334" = ((1:2,),)
    var"##isassumption#335" = begin
        let var"##vn#336" = (DynamicPPL.VarName)(:x, ((1:2,),))
            if !((DynamicPPL.inargnames)(var"##vn#336", _model)) || (DynamicPPL.inmissings)(var"##vn#336", _model)
                true
            else
                x[1:2] === missing
            end
        end
    end
    if var"##isassumption#335"
        x[1:2] .= (DynamicPPL.dot_tilde_assume)(_rng, _context, _sampler, var"##tmpright#331", x[1:2], var"##vn#333", var"##inds#334", _varinfo)
    else
        (DynamicPPL.dot_tilde_observe)(_context, _sampler, var"##tmpright#331", x[1:2], var"##vn#333", var"##inds#334", _varinfo)
    end
end

Главным различием развернутого кода для L ~ R и @. L ~ R является то, что в первом случае не предполагается наличие определения L: это может быть новая переменная Julia в данной области. Во втором же случае предполагается, что L уже существует. Кроме того, вместо DynamicPPL.tilde_assume и DynamicPPL.tilde_observe вызываются функции DynamicPPL.dot_tilde_assume и DynamicPPL.dot_tilde_observe.

Шаг 3. Замена предоставленного пользователем тела функции

Наконец, предоставленное пользователем тело функции заменяется с помощью DynamicPPL.build_output. Эта функция использует MacroTools.combinedef для повторной сборки предоставленной пользователем функции с новым телом. В измененном теле функции создается анонимная функция, тело которой было сгенерировано на шаге 2 выше и которая имеет следующие аргументы:

  • генератор случайных чисел _rng;

  • модель _model;

  • структура данных _varinfo;

  • сэмплер _sampler;

  • контекст выборки _context;

  • все позиционные и именованные аргументы предоставленной пользователем функции модели в виде позиционных аргументов без значений по умолчанию. Наконец, в новом теле функции возвращается model::Model с этой анонимной функцией в качестве внутренней.

VarName

Для отслеживания случайных переменных в процессе выборки в Turing применяется структура VarName, которая выступает в роли идентификатора случайной переменной, генерируемого во время выполнения. VarName случайной переменной генерируется на основе выражения в левой части оператора ~, когда символ в левой части входит в набор P ненаблюдаемых случайных переменных. У каждого экземпляра VarName есть параметр типа sym. Это символ переменной Julia в модели, к которой относится случайная переменная. Например, x[1] ~ Normal() создает экземпляр VarName{:x} при условии, что x — ненаблюдаемая случайная переменная. У каждого VarName также есть поле indexing, в котором хранятся индексы, необходимые для доступа к случайной переменной из переменной Julia. Они указываются в sym в виде кортежа кортежей. Каждый элемент кортежа содержит индексы для одной операции индексирования (VarName также поддерживает иерархические массивы и индексирование диапазонов). Некоторые примеры:

  • x ~ Normal() генерирует VarName(:x, ()).

  • x[1] ~ Normal() генерирует VarName(:x, 1,),.

  • x[:,1] ~ MvNormal(zeros(2), I) генерирует VarName(:x, Colon(), 1),.

  • x[:,1][1+1] ~ Normal() генерирует VarName(:x, Colon(), 1), (2,).

Самый простой способ вручную создать VarName — воспользоваться в выражении индексирования макросом @varname, который берет значение sym из фактического имени переменной и надлежащим образом помещает значения индексов в конструктор.

VarInfo

Обзор

VarInfo — это структура данных в Turing, которая упрощает отслеживание случайных переменных и некоторых их метаданных, необходимых для выборки. Например, в VarInfo хранится распределение каждой случайной переменной, так как, например, при выборке с использованием HMC необходимо знать, поддерживается ли оно. Случайные переменные, распределения которых имеют ограниченную поддержку, преобразуются с помощью биектора из пакета Bijectors.jl так, что выборка происходит в неограниченном пространстве. Разным сэмплерам требуются разные метаданные случайных переменных.

VarInfo имеет в Turing следующее определение:

struct VarInfo{Tmeta, Tlogp} <: AbstractVarInfo
    metadata::Tmeta
    logp::Base.RefValue{Tlogp}
    num_produce::Base.RefValue{Int}
end

В зависимости от типа metadata объект VarInfo имеет псевдоним UntypedVarInfo или TypedVarInfo. metadata может быть либо подтипом типа объединения Metadata, либо кортежем NamedTuple нескольких таких подтипов. Допустим, vi — это экземпляр VarInfo. Если vi isa VarInfo{<:Metadata}, он будет называться UntypedVarInfo. Если vi isa VarInfo{<:NamedTuple}, то vi.metadata будет кортежем NamedTuple, в котором каждый символ из P сопоставляется с экземпляром Metadata. В таком случае vi будет называться TypedVarInfo. В число других полей VarInfo входит поле logp, которое служит для накопления логарифмической вероятности или логарифмической плотности вероятности переменных в P и D. num_produce отслеживает то, сколько наблюдений уже было сделано в модели. Значение увеличивается при выполнении оператора ~, если символ в левой части входит в D.

Metadata

В структуре Metadata хранятся некоторые метаданные случайных переменных, выборка которых производится. Это помогает запрашивать определенную информацию о переменной, а именно ее распределение, используемые для выборки сэмплеры, ее значение и то, преобразуется ли это значение в вещественное пространство. Допустим, md — это экземпляр Metadata:

  • md.vns — это вектор всех экземпляров VarName. Допустим, vn — это произвольный элемент md.vns.

  • md.idcs — это словарь, в котором каждый экземпляр VarName сопоставляется с индексом в

md.vns, md.ranges, md.dists, md.orders и md.flags.

  • md.vns[md.idcs[vn]] == vn.

  • md.dists[md.idcs[vn]] — это распределение vn.

  • md.gids[md.idcs[vn]] — это набор алгоритмов, используемых для выборки vn. Применяется в

процессе выборки Гиббса.

  • md.orders[md.idcs[vn]] — это количество операторов observe до выборки vn.

  • md.ranges[md.idcs[vn]] — это диапазон индексов vn в md.vals.

  • md.vals[md.ranges[md.idcs[vn]]] — это линеаризованный вектор значений, соответствующих vn.

  • md.flags — это словарь флагов true/false. md.flags[flag][md.idcs[vn]] — это

значение flag, соответствующее vn.

Обратите внимание: для обеспечения стабильности md::Metadata по типу все md.vns должны иметь один и тот же символ и тип распределения. Однако отдельная переменная Julia, например x, может представлять собой матрицу или иерархический массив, выборка которых производится по частям, например x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I). Символ x может по-прежнему управляться одним объектом md::Metadata без ущерба для стабильности по типу, так как все распределения в правой части оператора ~ одного и того же типа.

Однако в моделях Turing такого ограничения быть не может, поэтому необходимо использовать нестабильный тип Metadata, если один экземпляр Metadata должен применяться для всей модели. Для этого служит UntypedVarInfo. Нестабильный тип Metadata по-прежнему будет работать, но с более низкой производительностью.

Чтобы найти баланс между гибкостью и производительностью при создании экземпляра spl::Sampler, модель сначала запускается с выборкой параметров в P из априорных распределений с помощью UntypedVarInfo, то есть для всех переменных используется нестабильный по типу объект Metadata. После определения всех символов и типов распределений создается vi::TypedVarInfo, где vi.metadata — это кортеж NamedTuple, в котором каждый символ из P сопоставляется со специальным экземпляром Metadata. Поэтому при условии, что выборка для каждого символа в P производится только из одного типа распределений, vi::TypedVarInfo будет иметь полностью конкретно типизированные поля, благодаря чему в Julia достигается максимальная производительность.