Сохранение и загрузка моделей
Вы можете сохранить модели, чтобы их можно было загрузить и запустить в следующем сеансе. Во Flux существует несколько способов выполнения этих задач. Рекомендуемый способ, который является наиболее надежным для длительного хранения, заключается в использовании Flux.state
в сочетании с форматом сериализации, например JLD2.jl или BSON.jl.
Сохраним модель:
julia> using Flux
julia> struct MyModel
net
end
julia> Flux.@functor MyModel
julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2)));
julia> model = MyModel()
MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2)))
julia> model_state = Flux.state(model);
julia> using JLD2
julia> jldsave("mymodel.jld2"; model_state)
Загрузим ее снова в новом сеансе с помощью Flux.loadmodel!
:
julia> using Flux, JLD2
julia> model_state = JLD2.load("mymodel.jld2", "model_state");
julia> model = MyModel(); # Должно быть доступно определение MyModel
julia> Flux.loadmodel!(model, model_state);
Если параметры сохраненной модели хранятся в GPU, модель не будет загружаться в дальнейшем, если отсутствует поддержка GPU. Перед сохранением рекомендуется переместить модель в CPU с помощью |
Копирование в контрольных точках
В ходе длительного обучения полезно периодически сохранять модель, чтобы можно было возобновить обучение в случае его прерывания (например, при отключении электричества).
julia> using Flux: throttle
julia> using JLD2
julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
Chain(
Dense(10 => 5, relu), # 55 параметров
Dense(5 => 2), # 12 параметров
) # Всего: 4 массива, 67 параметров, 524 байта.
julia> for epoch in 1:10
# ... обучение модели ...
jldsave("model-checkpoint.jld2", model_state = Flux.state(m))
end;
"model-checkpoint.jld2"
будет обновляться каждую эпоху.
Вы можете повысить свой уровень работы, сохраняя ряд моделей в процессе обучения. Например,
jldsave("model-$(now()).jld2", model_state = Flux.state(m))
создаст ряд моделей типа "model-2018-03-06T02:57:10.41.jld2"
. Можно также сохранить текущие потери тестового набора, чтобы легко (например) вернуться к более старой копии модели, если она начнет переобучаться.
jldsave("model-$(now()).jld2", model_state = Flux.state(m), loss = testloss())
Обратите внимание, что для возобновления обучения модели может потребоваться восстановить другие части цикла обучения с сохранением состояния. Возможные примеры: состояние оптимизатора и случайность, используемая для разделения исходных данных на обучающий и проверочный наборы.
Вы можете сохранить состояние оптимизатора вместе с моделью, чтобы возобновить обучение с того места, на котором остановились:
model = MyModel()
opt_state = Flux.setup(AdamW(), model)
# ... обучение модели ...
model_state = Flux.state(model)
jldsave("checkpoint_epoch=42.jld2"; model_state, opt_state)
Сохранение моделей в виде структур Julia
Модели — это обычные структуры Julia, поэтому можно использовать любой формат хранения Julia, чтобы сохранить структуру как она есть, а не сохранять состояние, возвращаемое Flux.state
. Для этого отлично подходит BSON.jl, поскольку он также может сохранять любые анонимные функции, которые иногда являются частью определения модели.
Сохраним модель:
julia> using Flux
julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2));
julia> using BSON: @save
julia> @save "mymodel.bson" model
Загрузим ее снова в новом сеансе:
julia> using Flux, BSON
julia> BSON.@load "mymodel.bson" model
julia> model
Chain(
Dense(10 => 5, relu), # 55 параметров
Dense(5 => 2), # 12 параметров
) # Всего: 4 массива, 67 параметров, 524 байта.
Сохранение моделей таким образом может привести к проблемам совместимости между версиями Julia и между версиями Flux при изменении некоторых внутренних компонентов слоев Flux. Поэтому этот способ не рекомендуется использовать для длительного хранения. Применяйте вместо него |
В предыдущих версиях Flux предлагалось сохранять только веса моделей с помощью |