Документация Engee

Справка по API обучения

Новая версия обучающего кода Flux была написана как независимый пакет Optimisers.jl. Flux принадлежит только функция train!.

Пакет Optimisers предназначен для работы с неизменяемыми объектами. Но в настоящее время все модели Flux содержат массивы параметров (например, Array и CuArray), которые могут быть обновлены на месте. Благодаря этому

  • объекты, возвращаемые Optimisers.update!, можно игнорировать;

  • Flux определяет свою собственную версию setup, которая проверяет это предположение. (Вместо нее также можно использовать Optimisers.setup, так как они возвращают одно и то же.)

Новая реализация таких правил, как Adam, в Optimisers значительно отличается от старой реализации в Flux.Optimise. Во Flux 0.14 Flux.Adam() возвращает старый вариант с супертипом Flux.Optimise.AbstractOptimiser, но setup автоматически переводит его в новый аналог. Доступные правила приведены на странице Правила оптимизации. Подробные сведения о работе новых правил см. в документации по Optimisers.

opt_state = setup(rule, model)

Это версия функции Optimisers.setup, и это первый шаг перед использованием функции train!. Она отличается от функции Optimisers.setup тем, что

  • имеет одну дополнительную проверку на изменяемость (поскольку Flux ожидает, что модель будет изменяться на месте, тогда как Optimisers.jl предназначен для возврата обновленной модели);

  • имеет методы, которые принимают старые оптимизаторы Flux и преобразуют их. (Старый Flux.Optimise.Adam и новый Optimisers.Adam являются разными типами.)

Совместимость: New

Эта функция появилась во Flux 0.13.9. Она не использовалась в старом «неявном» интерфейсе, использующем модуль Flux.Optimise и Flux.params.

Пример

julia> model = Dense(2=>1, leakyrelu; init=ones);

julia> opt_state = Flux.setup(Momentum(0.1), model)  # кодирует оптимизатор и его состояние
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ())

julia> x1, y1 = [0.2, -0.3], [0.4];  # использует те же данные для двух шагов:

julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt_state) do m, x, y
         sum(abs.(m(x) .- y)) * 100
       end

julia> model.bias  # был равен нулю, изменен Flux.train!
1-element Vector{Float64}:
 10.19

julia> opt_state  # изменен Flux.train!
(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ())
train!(loss, model, data, opt_state)

Использует функцию потерь (loss) и обучающие данные (data) для улучшения параметров модели (model) в соответствии с определенным правилом оптимизации, закодированным в opt_state. Итерирует данные (data) один раз, вычисляя для каждого d in data либо loss(model, d...), если d isa Tuple, либо loss(model, d) для других d.

Например, с этими определениями…​

data = [(x1, y1), (x2, y2), (x3, y3)]

loss3(m, x, y) = norm(m(x) .- y)        # модель является первым аргументом

opt_state = Flux.setup(Adam(), model)   # явная настройка импульсов оптимизатора

…​вызов Flux.train!(loss3, model, data, opt_state) запускает цикл, подобный этому:

for d in data
    ∂L∂m = gradient(loss3, model, d...)[1]
    update!(opt_state, model, ∂L∂m)
end

Вы также можете написать этот цикл самостоятельно, если вам нужна большая гибкость. По этой причине функция train! не отличается высокой расширяемостью. Она добавляет лишь несколько возможностей в приведенный выше цикл:

  • остановка с выводом ошибки DomainError, если потери бесконечны, или с NaN в любой точке;

  • отображение индикатора хода выполнения с помощью @withprogress.

compat "New" Этот метод появился во Flux 0.13.9. Он существенно отличается от метода, используемого в версиях Flux не выше 0.13: * Теперь он принимает саму модель (model), а не результат Flux.params. (Это делается для того, чтобы отказаться от «неявной» обработки параметров в Zygote с помощью Grads.) * Вместо того чтобы функция loss принимала только данные, теперь она должна принимать и саму модель (model) в качестве первого аргумента. * opt_state должен быть результатом Flux.setup. При использовании оптимизатора, такого как Adam(), без этого шага должно выводиться предупреждение. * Функции обратных вызовов не поддерживаются. (Но в приведенный выше цикл for можно включить любой код.)

Optimisers.update!(tree, model, gradient) -> (tree, model)

Использует оптимизатор и градиент для изменения обучаемых параметров в модели. Возвращает улучшенную модель и состояния оптимизатора, необходимые для следующего обновления. Источником начального дерева состояний является setup.

Эта функция используется точно так же, как и update, но поскольку она может изменять массивы внутри старой модели (и старое состояние), она будет быстрее для моделей обычных Array или CuArray. Однако не стоит полагаться на то, что старая модель будет полностью обновлена, поэтому стоит использовать возвращенную модель. (Исходное дерево состояний всегда изменяется, поскольку каждый лист (Leaf) является изменяемым.)

Пример

julia> using StaticArrays, Zygote, Optimisers

julia> m = (x = [1f0, 2f0], y = SA[4f0, 5f0]);  # частично изменяемая модель

julia> t = Optimisers.setup(Momentum(1/30, 0.9), m)  # дерево состояний
(x = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]), y = Leaf(Momentum(0.0333333, 0.9), Float32[0.0, 0.0]))

julia> g = gradient(m -> sum(abs2.(m.x .+ m.y)), m)[1]  # структурный градиент
(x = Float32[10.0, 14.0], y = Float32[10.0, 14.0])

julia> t2, m2 = Optimisers.update!(t, m, g);

julia> m2  # после выполнения функций update или update! это новая модель
(x = Float32[0.6666666, 1.5333333], y = Float32[3.6666667, 4.5333333])

julia> m2.x === m.x  # update! повторно использует этот массив для повышения эффективности
true

julia> m  # исходное дерево должно быть отброшено, может быть изменено, но гарантии нет
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])

julia> t == t2  # исходное дерево состояний гарантированно будет изменено
true

Функция train! использует @progress, что должно автоматически отображать индикатор хода выполнения в VSCode. Чтобы увидеть это в терминале, нужно установить TerminalLoggers.jl и следовать инструкциям по настройке.

Модификаторы оптимизации

Состояние, возвращаемое функцией setup, можно изменить, чтобы временно запретить обучение некоторых частей модели, или изменить скорость обучения или другой гиперпараметр. Для этого предназначены функции Flux.freeze!, Flux.thaw! и Flux.adjust!. Все они изменяют состояние (или его часть) и возвращают nothing.

Optimisers.adjust!(tree, η)

Изменяет состояние tree = setup(rule, model), чтобы изменить параметры правила оптимизации, не разрушая его сохраненное состояние. Обычно используется в середине обучения.

Может применяться к части модели, воздействуя только на соответствующую часть состояния tree.

Чтобы изменить только скорость обучения, укажите число η::Real.

Пример

julia> m = (vec = rand(Float32, 2), fun = sin);

julia> st = Optimisers.setup(Nesterov(), m)  # сохраненный импульс инициализируется нулем
(vec = Leaf(Nesterov(0.001, 0.9), Float32[0.0, 0.0]), fun = ())

julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing));  # с ненастоящим градиентом

julia> st
(vec = Leaf(Nesterov(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())

julia> Optimisers.adjust!(st, 0.123)  # изменим скорости обучения, сохраненный импульс остается нетронутым

julia> st
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())

Чтобы изменить другие параметры, функция adjust! также принимает именованные аргументы, соответствующие именам полей типа правила оптимизации.

julia> fieldnames(Adam)
(:eta, :beta, :epsilon)

julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
(vec = Leaf(OptimiserChain(ClipGrad(10.0), Adam(0.001, (0.9, 0.999), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1)  # дельта работает с ClipGrad
(vec = Leaf(OptimiserChain(ClipGrad(11.1), Adam(0.001, (0.777, 0.909), 1.0e-8)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())

julia> Optimisers.adjust(st; beta = "no such field")  # автоматически игнорируется!
(vec = Leaf(Nesterov(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
Optimisers.freeze!(tree)

Временно изменяет состояние tree = setup(rule, model) так, что параметры не будут обновляться. Действие отменяется с помощью thaw!.

Может применяться к части модели, воздействуя только на соответствующую часть модели, например с помощью model::Chain. Для заморозки model.layers[1] следует вызвать freeze!(tree.layers[1]).

Пример

julia> m = (x = ([1.0], 2.0), y = [3.0]);

julia> s = Optimisers.setup(Momentum(), m);

julia> Optimisers.freeze!(s.x)

julia> Optimisers.update!(s, m, (x = ([pi], 10pi), y = [100pi]));  # с ненастоящим градиентом

julia> m
(x = ([1.0], 2.0), y = [-0.14159265358979312])

julia> s
(x = (Leaf(Momentum(0.01, 0.9), [0.0], frozen = true), ()), y = Leaf(Momentum(0.01, 0.9), [3.14159]))

julia> Optimisers.thaw!(s)

julia> s.x
(Leaf(Momentum(0.01, 0.9), [0.0]), ())
Optimisers.thaw!(tree)

Обратная функция для функции freeze!. Применяется ко всем параметрам, изменяя каждый Leaf(rule, state, frozen = true) на Leaf(rule, state, frozen = false).

Неявный стиль (Flux версий не выше 0.14)

Раньше работа с градиентами, обучением и правилами оптимизации во Flux происходила совсем по-другому. Новый стиль, описанный выше, в Zygote называется «явным», а старый — «неявным». Flux 0.13 и 0.14 являются переходными версиями, которые поддерживают оба варианта. Старый стиль будет удален во Flux 0.15.

Совместимость: How to upgrade

В сине-зеленых блоках в разделе, посвященном обучению, описываются изменения, необходимые для обновления старого кода.

Подробную информацию об интерфейсе для оптимизаторов в неявном стиле можно найти в руководстве по Flux 0.13.6.

Совместимость: Flux ≤ 0.12

Ранние версии Flux экспортировали параметры (params), что позволяло использовать неквалифицированную params(model) после using Flux. При этом возникал конфликт с большим количеством других пакетов, поэтому этот вариант был удален во Flux 0.13. Если вы получаете ошибку UndefVarError: params not defined, это, скорее всего, означает, что вы используете код для Flux 0.12 или более ранней версии в более поздней версии.

params(model)
params(layers...)

При наличии модели или определенных слоев из модели создает объект Params, указывающий на ее обучаемые параметры.

Можно использовать с функцией gradient (см. раздел, посвященный обучению) или как входные данные для функции Flux.train!.

Поведение params в пользовательских типах можно настроить с помощью Functors.@functor или Flux.trainable.

Примеры

julia> using Flux: params

julia> params(Chain(Dense(ones(2,3)), softmax))  # распаковывает модели Flux
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])

julia> bn = BatchNorm(2, relu)
BatchNorm(2, relu)  # 4 параметра, плюс 4 необучаемых

julia> params(bn)  # только обучаемые параметры
Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])

julia> params([1, 2, 3], [4])  # один или несколько массивов чисел
Params([[1, 2, 3], [4]])

julia> params([[1, 2, 3], [4]])  # распаковывает массив массивов
Params([[1, 2, 3], [4]])

julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin))  # игнорирует скаляры, распаковывает именованные кортежи
Params([[2 2], [3, 3, 3]])
update!(opt, p, g)
update!(opt, ps::Params, gs)

Выполните шаг обновления параметров ps (или одного параметра p) в соответствии с оптимизатором opt::AbstractOptimiser, и градиентов gs (градиента g).

В результате параметры изменяются, и внутреннее состояние оптимизатора может измениться. Градиент также может быть изменен.

Совместимость: Deprecated

Этот метод для неявных параметров (Params) (и AbstractOptimiser) будет удален из Flux 0.15. Явный метод update!(opt, model, grad) из Optimisers.jl останется.

train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
Использует функцию потерь (`loss`) и обучающие данные (`data`) для

улучшения параметров модели в соответствии с определенным правилом оптимизации opt.

Совместимость: Deprecated

Этот метод для неявных параметров (Params) будет удален из Flux 0.15. Он должен быть заменен явным методом train!(loss, model, data, opt).

Для каждого d in data сначала вычисляется градиент потерь (loss):

    gradient(() -> loss(d...), pars)  # если d является кортежем
    gradient(() -> loss(d), pars)     # в противном случае

Здесь pars создается при вызове функции Flux.params в модели. (Или просто в слоях, которые вы хотите обучить, например train!(loss, params(model[1:end-2]), data, opt).) Это «неявный» стиль работы с параметрами.

Затем этот градиент затем используется оптимизатором opt для обновления параметров:

    update!(opt, pars, grads)

Оптимизатор должен быть из модуля Flux.Optimise (см. Оптимизаторы). Различные оптимизаторы можно объединять с помощью Flux.Optimise.Optimiser.

Этот цикл обучения итерирует data один раз. Он остановится с выводом ошибки DomainError, если потеря является NaN или бесконечна.

Вы можете использовать train! внутри цикла, чтобы сделать это несколько раз, или, например, использовать Itertools.ncycle для создания более длинного итератора data.

Обратные вызовы

Обратные вызовы с помощью именованного аргумента cb. Например, следующее будет выводить training каждые 10 секунд (используя Flux.throttle):

    train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))

Несколько обратных вызовов можно передать в cb в виде массива.

Обратные вызовы

Неявная функция train! принимает дополнительный аргумент cb, который используется для обратных вызовов, чтобы вы могли наблюдать за процессом обучения. Например:

train!(objective, ps, data, opt, cb = () -> println("training"))

Обратные вызовы вызываются для каждого пакета обучающих данных. Этот процесс можно замедлить с помощью функции Flux.throttle(f, timeout), которая не позволяет вызывать f чаще, чем раз в timeout с.

Более типичный обратный вызов может выглядеть следующим образом:

test_x, test_y = # ... создание единого пакета тестовых данных ...
evalcb() = @show(loss(test_x, test_y))
throttled_cb = throttle(evalcb, 5)
for epoch in 1:20
  @info "Epoch $epoch"
  Flux.train!(objective, ps, data, opt, cb = throttled_cb)
end

Дополнительные сведения см. на странице о вспомогательных функциях обратных вызовов.