Документация Engee

Сохранение и загрузка моделей

Вы можете сохранить модели, чтобы их можно было загрузить и запустить в следующем сеансе. Во Flux существует несколько способов выполнения этих задач. Рекомендуемый способ, который является наиболее надежным для длительного хранения, заключается в использовании Flux.state в сочетании с форматом сериализации, например JLD2.jl или BSON.jl.

Сохраним модель:

julia> using Flux

julia> struct MyModel
           net
       end

julia> Flux.@functor MyModel

julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2)));

julia> model = MyModel()
MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))

julia> model_state = Flux.state(model);

julia> using JLD2

julia> jldsave("mymodel.jld2"; model_state)

Загрузим ее снова в новом сеансе с помощью Flux.loadmodel!:

julia> using Flux, JLD2

julia> model_state = JLD2.load("mymodel.jld2", "model_state");

julia> model = MyModel(); # Должно быть доступно определение MyModel

julia> Flux.loadmodel!(model, model_state);

Если параметры сохраненной модели хранятся в GPU, модель не будет загружаться в дальнейшем, если отсутствует поддержка GPU. Перед сохранением рекомендуется переместить модель в CPU с помощью cpu(model).

Копирование в контрольных точках

В ходе длительного обучения полезно периодически сохранять модель, чтобы можно было возобновить обучение в случае его прерывания (например, при отключении электричества).

julia> using Flux: throttle

julia> using JLD2

julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Chain(
  Dense(10 => 5, relu),                 # 55 параметров
  Dense(5 => 2),                        # 12 параметров
)                   # Всего: 4 массива, 67 параметров, 524 байта.

julia> for epoch in 1:10
          # ... обучение модели ...
          jldsave("model-checkpoint.jld2", model_state = Flux.state(m))
       end;

"model-checkpoint.jld2" будет обновляться каждую эпоху.

Вы можете повысить свой уровень работы, сохраняя ряд моделей в процессе обучения. Например,

jldsave("model-$(now()).jld2", model_state = Flux.state(m))

создаст ряд моделей типа "model-2018-03-06T02:57:10.41.jld2". Можно также сохранить текущие потери тестового набора, чтобы легко (например) вернуться к более старой копии модели, если она начнет переобучаться.

jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss())

Обратите внимание, что для возобновления обучения модели может потребоваться восстановить другие части цикла обучения с сохранением состояния. Возможные примеры: состояние оптимизатора и случайность, используемая для разделения исходных данных на обучающий и проверочный наборы.

Вы можете сохранить состояние оптимизатора вместе с моделью, чтобы возобновить обучение с того места, на котором остановились:

model = MyModel()
opt_state = Flux.setup(AdamW(), model)

# ... обучение модели ...

model_state = Flux.state(model)
jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state)

Сохранение моделей в виде структур Julia

Модели — это обычные структуры Julia, поэтому можно использовать любой формат хранения Julia, чтобы сохранить структуру как она есть, а не сохранять состояние, возвращаемое Flux.state. Для этого отлично подходит BSON.jl, поскольку он также может сохранять любые анонимные функции, которые иногда являются частью определения модели.

Сохраним модель:

julia> using Flux

julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2));

julia> using BSON: @save

julia> @save "mymodel.bson" model

Загрузим ее снова в новом сеансе:

julia> using Flux, BSON

julia> BSON.@load "mymodel.bson" model

julia> model
Chain(
  Dense(10 => 5, relu),                 # 55 параметров
  Dense(5 => 2),                        # 12 параметров
)                   # Всего: 4 массива, 67 параметров, 524 байта.

Сохранение моделей таким образом может привести к проблемам совместимости между версиями Julia и между версиями Flux при изменении некоторых внутренних компонентов слоев Flux. Поэтому этот способ не рекомендуется использовать для длительного хранения. Применяйте вместо него Flux.state.

В предыдущих версиях Flux предлагалось сохранять только веса моделей с помощью @save "mymodel.bson" params(model). Теперь это делать не рекомендуется и даже крайне нежелательно. При сохранении моделей таким образом будут сохраняться только обучаемые параметры, что приведет к некорректному поведению для таких слоев, как BatchNorm.