Использование перехватчиков
Что такое перехватчики
Во время взаимодействия между агентами и средами часто необходимо собирать полезную информацию. Прямолинейным подходом является использование императивного программирования. Мы пишем код, который будет выполняться пошагово внутри цикла.
while true
    action = plan!(policy, env)
    act!(env, action)
    # прописываем здесь логику,
    # например сохранение параметров, регистрацию функции потерь, вычисление политики и т. д.
    check!(stop_condition, env, policy) && break
    is_terminated(env) && reset!(env)
endПреимущество такого подхода — его полная ясность. Вы отвечаете за все выполняемые действия. И такой подход рекомендуется для новых пользователей, которые хотят опробовать различные компоненты пакета.
Другой подход — декларативное программирование. Вы описываете, что должно делаться во время эксперимента и когда. Затем вы совмещаете эту информацию с агентом и средой. Наконец, вы выполняете команду run, чтобы провести эксперимент. Это позволяет повторно использовать стандартные перехватчики и конвейеры выполнения, а не писать один и тот же код много раз. Во многих существующих пакетах Python для обучения с подкреплением для определения конвейера выполнения обычно применяется набор файлов конфигурации. Однако, на наш взгляд, в Julia это лишнее. Подход на основе декларативного программирования дает гораздо большую гибкость.
Следующий вопрос в том, как построить обработчик. Естественным путем представляется заключение прокомментированной части из приведенного выше псевдокода в функцию:
while true
    action = plan!(policy, env)
    act!(env, action)
    push!(hook, policy, env)
    check!(stop_condition, env, policy) && break
    is_terminated(env) && reset!(env)
endНо иногда требуется более детальный контроль. Поэтому вызов обработчиков разделяется на несколько отдельных этапов:
Определение настраиваемого обработчика
По умолчанию при вызове push!(hook::AbstractHook, ::AbstractStage, policy, env) экземпляр AbstractHook ничего не делает. Поэтому при написании настраиваемого обработчика достаточно реализовать необходимую логику времени выполнения.
Например, предположим, что нужно регистрировать фактическое время каждого эпизода.
julia> using ReinforcementLearning
julia> import Base.push!
julia> Base.@kwdef mutable struct TimeCostPerEpisode <: AbstractHook
           t::UInt64 = time_ns()
           time_costs::Vector{UInt64} = []
       end
Main.TimeCostPerEpisode
julia> Base.push!(h::TimeCostPerEpisode, ::PreEpisodeStage, policy, env) = h.t = time_ns()
julia> Base.push!(h::TimeCostPerEpisode, ::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t)
julia> h = TimeCostPerEpisode()
Main.TimeCostPerEpisode(0x00000133269cf1fa, UInt64[])
julia> run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h)
ERROR: MethodError: push!(::Main.TimeCostPerEpisode, ::PreEpisodeStage, ::RandomPolicy{Nothing, TaskLocalRNG}, ::CartPoleEnv{Float64, Int64}) is ambiguous.
Candidates:
  push!(h::Main.TimeCostPerEpisode, ::PreEpisodeStage, policy, env)
    @ Main REPL[4]:1
  push!(::AbstractHook, ::AbstractStage, ::AbstractPolicy, ::AbstractEnv)
    @ ReinforcementLearningCore ~/.julia/packages/ReinforcementLearningCore/BYdWk/src/core/hooks.jl:35
Possible fix, define
  push!(::Main.TimeCostPerEpisode, ::PreEpisodeStage, ::AbstractPolicy, ::AbstractEnv)
julia> h.time_costs
UInt64[]Периодические задания
Иногда требуется выполнять некоторые функции периодически. Для этой задачи имеются два удобных обработчика:
Ниже представлены типичные способы их использования.
Вычисление политики во время обучения
julia> using Statistics: mean
julia> policy = RandomPolicy()
[33m[39m::[34mRandomPolicy[39m	[90m[39m
├─ [33maction_space[39m::[34mNothing[39m[35m => [39m[32mnothing[39m	[90m[39m
└─ [33mrng[39m::[34mTaskLocalRNG[39m[35m => [39m[32mTaskLocalRNG()[39m	[90m[39m
julia> run(
           policy,
           CartPoleEnv(),
           StopAfterNEpisodes(100),
           DoEveryNEpisodes(;n=10) do t, policy, env
               # В реальных сценариях политика обычно инкапсулируется в агенте.
               # Нам необходимо извлечь внутреннюю политику для ее выполнения в режиме *исполнителя*.
               # Здесь для иллюстрации просто используется исходная политика.
               # Обратите внимание: чтобы исходная среда не засорялась, здесь создается
               # новый экземпляр CartPoleEnv.
               hook = TotalRewardPerEpisode(;is_display_on_exit=false)
               run(policy, CartPoleEnv(), StopAfterNEpisodes(10), hook)
               # Теперь можно вывести результат выполнения обработчика.
               println("avg reward at episode $t is: $(mean(hook.rewards))")
           end
       )
avg reward at episode 10 is: 21.6
avg reward at episode 20 is: 21.4
avg reward at episode 30 is: 28.5
avg reward at episode 40 is: 19.7
avg reward at episode 50 is: 17.5
avg reward at episode 60 is: 23.5
avg reward at episode 70 is: 20.2
avg reward at episode 80 is: 18.2
avg reward at episode 90 is: 19.4
avg reward at episode 100 is: 29.8
DoEveryNEpisodes{PostEpisodeStage, Main.var"#2#3"}(Main.var"#2#3"(), 10, 100)Сохранение параметров
Для сохранения параметров политики рекомендуется использовать JLD2.jl.
julia> using ReinforcementLearning
julia> using JLD2
ERROR: ArgumentError: Package JLD2 not found in current path.
- Run `import Pkg; Pkg.add("JLD2")` to install the JLD2 package.
julia> env = RandomWalk1D()
# RandomWalk1D
## Traits
| Trait Type        |                Value |
|:----------------- | --------------------:|
| NumAgentStyle     |        SingleAgent() |
| DynamicStyle      |         Sequential() |
| InformationStyle  | PerfectInformation() |
| ChanceStyle       |      Deterministic() |
| RewardStyle       |     TerminalReward() |
| UtilityStyle      |         GeneralSum() |
| ActionStyle       |   MinimalActionSet() |
| StateStyle        | Observation{Int64}() |
| DefaultStateStyle | Observation{Int64}() |
| EpisodeStyle      |           Episodic() |
## Is Environment Terminated?
No
## State Space
`Base.OneTo(7)`
## Action Space
`Base.OneTo(2)`
## Current State
```
4
```
julia> ns, na = length(state_space(env)), length(action_space(env))
(7, 2)
julia> policy = Agent(
           QBasedPolicy(;
               learner = TDLearner(
                   TabularQApproximator(n_state = ns, n_action = na),
                   :SARS;
               ),
               explorer = EpsilonGreedyExplorer(ϵ_stable=0.01),
           ),
           Trajectory(
               CircularArraySARTSTraces(;
                   capacity = 1,
                   state = Int64 => (),
                   action = Int64 => (),
                   reward = Float64 => (),
                   terminal = Bool => (),
               ),
               DummySampler(),
               InsertSampleRatioController(),
           ),
       )
Agent{QBasedPolicy{TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}, EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}}, Trajectory{EpisodesBuffer{(:state, :next_state, :action, :reward, :terminal), Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, CircularArraySARTSTraces{Tuple{MultiplexTraces{(:state, :next_state), Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Int64}, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, 5, Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}}, DataStructures.CircularBuffer{Int64}, DataStructures.CircularBuffer{Bool}}, DummySampler, InsertSampleRatioController, typeof(identity)}}(QBasedPolicy{TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}, EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}}(TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}(TabularQApproximator{Matrix{Float64}}([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]), 1.0, 0.01, 0), EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}(0.01, 1.0, 0, 0, 1, TaskLocalRNG())), Trajectory{EpisodesBuffer{(:state, :next_state, :action, :reward, :terminal), Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, CircularArraySARTSTraces{Tuple{MultiplexTraces{(:state, :next_state), Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Int64}, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, 5, Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}}, DataStructures.CircularBuffer{Int64}, DataStructures.CircularBuffer{Bool}}, DummySampler, InsertSampleRatioController, typeof(identity)}(@NamedTuple{state::Int64, next_state::Int64, action::Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, reward::Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, terminal::Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}[], DummySampler(), InsertSampleRatioController(1.0, 1, 0, 0), identity))
julia> parameters_dir = mktempdir()
"/tmp/jl_mL3XYN"
julia> run(
           policy,
           env,
           StopAfterNSteps(10_000),
           DoEveryNSteps(n=1_000) do t, p, e
               ps = policy.policy.learner.approximator
               f = joinpath(parameters_dir, "parameters_at_step_$t.jld2")
               JLD2.@save f ps
               println("parameters at step $t saved to $f")
           end
       )
ERROR: LoadError: UndefVarError: `JLD2` not defined
in expression starting at REPL[7]:8