Плоские и вложенные структуры

Модели Flux имеют вложенную структуру с хранением параметров на множестве уровней. Иногда может требоваться их плоское представление, чтобы можно было использовать функции, принимающие только один вектор. Для этого служит функция destructure:

julia> model = Chain(Dense(2=>1, tanh), Dense(1=>1))
Chain(
  Dense(2 => 1, tanh),                  # 3 параметра
  Dense(1 => 1),                        # 2 параметра
)                   # Всего: 4 массива, 5 параметров, 276 байтов.

julia> flat, rebuild = Flux.destructure(model)
(Float32[0.863101, 1.2454957, 0.0, -1.6345707, 0.0], Restructure(Chain, ..., 5))

julia> rebuild(zeros(5))  # та же структура с новыми параметрами
Chain(
  Dense(2 => 1, tanh),                  # 3 параметра (все нулевые)
  Dense(1 => 1),                        # 2 параметра (все нулевые)
)                   # Всего: 4 массива, 5 параметров, 276 байтов.

Обе функции destructure и Restructure можно использовать в градиентных вычислениях. Например, следующий код вычисляет гессиан ∂²L/∂θᵢ∂θⱼ некоторой функции потерь с учетом всех параметров модели Flux. Полученная матрица имеет внедиагональные элементы и поэтому не может быть представлена в виде вложенной структуры:

julia> x = rand(Float32, 2, 16);

julia> grad = gradient(m -> sum(abs2, m(x)), model)  # вложенный градиент
((layers = ((weight = Float32[10.339018 11.379145], bias = Float32[22.845667], σ = nothing), (weight = Float32[-29.565302;;], bias = Float32[-37.644184], σ = nothing)),),)

julia> function loss(v::Vector)
         m = rebuild(v)
         y = m(x)
         sum(abs2, y)
       end;

julia> gradient(loss, flat)  # малый градиент, те же числа
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184],)

julia> Zygote.hessian(loss, flat)  # вторая производная
5×5 Matrix{Float32}:
  -7.13131   -5.54714  -11.1393  -12.6504   -8.13492
  -5.54714   -7.11092  -11.0208  -13.9231   -9.36316
 -11.1393   -11.0208   -13.7126  -27.9531  -22.741
 -12.6504   -13.9231   -27.9531   18.0875   23.03
  -8.13492   -9.36316  -22.741    23.03     32.0

julia> Flux.destructure(grad)  # работает не только с моделями
(Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
Совместимость: Flux ≤ 0.12

В прежних версиях Flux была совершенно другая реализация destructure, в которой было множество ошибок (и которая почти не тестировалась). Множество комментариев в Интернете относятся именно к этой все еще существующей версии функции.

Все параметры

Функция destructure теперь находится в пакете Optimisers.jl. (Имейте в виду, что этот пакет не связан с подмодулем Flux.Optimisers! Эта путаница будет в дальнейшем устранена.)

# Optimisers.destructureFunction

destructure(model) -> vector, reconstructor

Копирует все параметры trainable, isnumeric в модели в вектор и также возвращает функцию, которая обращает это преобразование. Поддается дифференцированию.

Пример

julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))

julia> re([3, 5, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))

Если в model имеются разные числовые типы, они продвигаются для создания vector и обычно восстанавливаются посредством Restructure. Такое восстановление выполняется по правилам ChainRulesCore.ProjectTo, поэтому предполагает восстановление точности чисел с плавающей запятой, но допускает использование и более экзотических чисел, таких как ForwardDiff.Dual.

Если model содержит только массивы в GPU, то vector будет также размещаться в GPU. В настоящее время одновременное размещение массивов в GPU и ЦП приводит к неопределенному поведению.

# Optimisers.trainableFunction

trainable(x::Layer) -> NamedTuple

Его можно перегрузить, чтобы оптимизаторы игнорировали некоторые поля каждого слоя Layer, которые иначе содержали бы обучаемые параметры.

Делать это требуется очень редко. Поля struct Layer, которые содержат функции или целые числа, например размеры, все равно всегда игнорируются. Перегружать trainable необходимо только в том случае, если следует оптимизировать лишь некоторые числовые массивы, но не все.

Значение по умолчанию — Functors.children(x). Обычно это именованный кортеж всех полей, а в trainable(x) должно содержаться их подмножество.

# Optimisers.isnumericFunction

isnumeric(x) -> Bool

Возвращает true для любых параметров, которые должны корректироваться пакетом Optimisers.jl, в частности для массивов нецелых чисел. Для всех остальных типов возвращает false.

Также требует выполнения условия Functors.isleaf(x) == true, чтобы обрабатывался, например, родительский объект транспонированной матрицы, а не оболочка.

Все слои

Другая разновидность плоского представления вложенной модели предоставляется командой modules. Она извлекает список всех слоев:

# Flux.modulesFunction

modules(m)

Возвращает итератор по неконечным объектам, к которым можно получить доступ путем рекурсивного выполнения m для дочерних объектов, заданных с помощью functor.

Полезно для применения функции (например, регуляризатора) к определенным модулям или подмножествам параметров (например, весам, но не смещениям).

Примеры

julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu));

julia> m2 = Chain(m1, Dense(64, 10))
Chain(
  Chain(
    Dense(784 => 64),                   # 50_240 параметров
    BatchNorm(64, relu),                # 128 параметров плюс 128
  ),
  Dense(64 => 10),                      # 650 параметров
)         # Всего: 6 обучаемых массивов, 51_018 параметров
          # плюс 2 необучаемых, 128 параметров, суммарный размер 200,312 КиБ.

julia> Flux.modules(m2)
7-element Vector{Any}:
 Chain(Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))  # 51_018 параметров плюс 128 необучаемых
 (Chain(Dense(784 => 64), BatchNorm(64, relu)), Dense(64 => 10))
 Chain(Dense(784 => 64), BatchNorm(64, relu))  # 50_368 параметров плюс 128 необучаемых
 (Dense(784 => 64), BatchNorm(64, relu))
 Dense(784 => 64)    # 50_240 параметров
 BatchNorm(64, relu)  # 128 параметров плюс 128 необучаемых
 Dense(64 => 10)     # 650 параметров

julia> L2(m) = sum(sum(abs2, l.weight) for l in Flux.modules(m) if l isa Dense)
L2 (generic function with 1 method)

julia> L2(m2) isa Float32
true

Сохранение и загрузка

# Flux.stateFunction

state(x)

Возвращает объект с той же вложенной структурой, что и у x, в соответствии с Functors.children, но составленный только из базовых контейнеров (например, именованных кортежей, кортежей, массивов и словарей).

Помимо обучаемых и необучаемых массивов, состояние будет содержать конечные узлы, не являющиеся массивами, такие как числа, символы, строки и значения nothing. Типы конечных элементов, попадающих в состояние, в будущем, возможно, будут расширены.

Этот метод особенно полезен для сохранения и загрузки моделей, так как состояние содержит только простые типы данных, которые можно легко сериализировать.

Состояние можно передать в loadmodel! для восстановления модели.

Примеры

Копирование состояния в другую модель

julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones));

julia> s = Flux.state(m1)
(layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = ()), (weight = [1.0 1.0], bias = [0.0], σ = ())),)

julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false));  # веса представляют собой случайные числа

julia> Flux.loadmodel!(m2, s);

julia> m2[1].weight   # теперь веса у m2 те же, что и у m1
2×1 Matrix{Float32}:
 1.0
 1.0

julia> Flux.state(trainmode!(Dropout(0.2)))  # содержит p и активность, но не состояние генератора случайных чисел
(p = 0.2, dims = (), active = true, rng = ())

julia> Flux.state(BatchNorm(1))  # содержит необучаемые массивы μ, σ²
(λ = (), β = Float32[0.0], γ = Float32[1.0], μ = Float32[0.0], σ² = Float32[1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 1)

Сохранение и загрузка с помощью BSON

julia> using BSON

julia> BSON.@save "checkpoint.bson" model_state = s

julia> Flux.loadmodel!(m2, BSON.load("checkpoint.bson")[:model_state])

Сохранение и загрузка с помощью JLD2

julia> using JLD2

julia> JLD2.jldsave("checkpoint.jld2", model_state = s)

julia> Flux.loadmodel!(m2, JLD2.load("checkpoint.jld2", "model_state"))

# Flux.loadmodel!Function

loadmodel!(dst, src)

Копирует все параметры (обучаемые и необучаемые) из src в dst.

Выполняет рекурсивный обход одновременно dst и src с помощью Functors.children и вызывает copyto! для массивов параметров или выдает ошибку в случае несоответствия. Отличные от массивов элементы (например, функции активации) не копируются и не требуют соответствия. Нулевые векторы смещений и параметр bias=false считаются эквивалентными (дополнительные сведения см. в расширенной справке).

См. также описание Flux.state.

Примеры

julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0]))
Chain(
  Dense(5 => 2, tanh),                  # 12 параметров
  Dense(2 => 1),                        # 3 параметра
)                   # Всего: 4 массива, 15 параметров, 316 байтов.

julia> dst[1].weight ≈ ones(2, 5)  # по общему правилу
true

julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));

julia> Flux.loadmodel!(dst, src);

julia> dst[1].weight ≈ ones(2, 5)  # значения изменились
false

julia> iszero(dst[2].bias)
true

Расширенная справка

Выдает ошибку в следующих случаях:

  • у dst и src разные поля (на любом уровне);

  • конечные узлы в dst и src не совпадают по размеру;

  • копирование отличных от массивов значений в параметр-массив или из него (кроме описанных ниже неактивных параметров);

  • dst является «связанным» параметром (то есть ссылается на другой параметр), и в него несколько раз загружаются несоответствующие исходные значения.

Неактивные параметры можно кодировать с помощью логического значения false вместо массива. Если dst == false, а src представляет собой массив со всеми нулевыми элементами, ошибка не возникает (и значения не копируются); однако при попытке скопировать ненулевой массив в неактивный параметр произойдет ошибка. Аналогичным образом, копирование значения src false в любой массив dst допустимо, но копирование значения src true приведет к ошибке.