Документация Engee

Использование перехватчиков

Что такое перехватчики

Во время взаимодействия между агентами и средами часто необходимо собирать полезную информацию. Прямолинейным подходом является использование императивного программирования. Мы пишем код, который будет выполняться пошагово внутри цикла.

while true
    action = plan!(policy, env)
    act!(env, action)

    # прописываем здесь логику,
    # например сохранение параметров, регистрацию функции потерь, вычисление политики и т. д.
    check!(stop_condition, env, policy) && break
    is_terminated(env) && reset!(env)
end

Преимущество такого подхода — его полная ясность. Вы отвечаете за все выполняемые действия. И такой подход рекомендуется для новых пользователей, которые хотят опробовать различные компоненты пакета.

Другой подход — декларативное программирование. Вы описываете, что должно делаться во время эксперимента и когда. Затем вы совмещаете эту информацию с агентом и средой. Наконец, вы выполняете команду run, чтобы провести эксперимент. Это позволяет повторно использовать стандартные перехватчики и конвейеры выполнения, а не писать один и тот же код много раз. Во многих существующих пакетах Python для обучения с подкреплением для определения конвейера выполнения обычно применяется набор файлов конфигурации. Однако, на наш взгляд, в Julia это лишнее. Подход на основе декларативного программирования дает гораздо большую гибкость.

Следующий вопрос в том, как построить обработчик. Естественным путем представляется заключение прокомментированной части из приведенного выше псевдокода в функцию:

while true
    action = plan!(policy, env)
    act!(env, action)
    push!(hook, policy, env)
    check!(stop_condition, env, policy) && break
    is_terminated(env) && reset!(env)
end

Но иногда требуется более детальный контроль. Поэтому вызов обработчиков разделяется на несколько отдельных этапов:

Определение настраиваемого обработчика

По умолчанию при вызове push!(hook::AbstractHook, ::AbstractStage, policy, env) экземпляр AbstractHook ничего не делает. Поэтому при написании настраиваемого обработчика достаточно реализовать необходимую логику времени выполнения.

Например, предположим, что нужно регистрировать фактическое время каждого эпизода.

julia> using ReinforcementLearning


julia> import Base.push!


julia> Base.@kwdef mutable struct TimeCostPerEpisode <: AbstractHook
           t::UInt64 = time_ns()
           time_costs::Vector{UInt64} = []
       end
Main.TimeCostPerEpisode

julia> Base.push!(h::TimeCostPerEpisode, ::PreEpisodeStage, policy, env) = h.t = time_ns()


julia> Base.push!(h::TimeCostPerEpisode, ::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t)


julia> h = TimeCostPerEpisode()
Main.TimeCostPerEpisode(0x00000133269cf1fa, UInt64[])

julia> run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h)
ERROR: MethodError: push!(::Main.TimeCostPerEpisode, ::PreEpisodeStage, ::RandomPolicy{Nothing, TaskLocalRNG}, ::CartPoleEnv{Float64, Int64}) is ambiguous.

Candidates:
  push!(h::Main.TimeCostPerEpisode, ::PreEpisodeStage, policy, env)
    @ Main REPL[4]:1
  push!(::AbstractHook, ::AbstractStage, ::AbstractPolicy, ::AbstractEnv)
    @ ReinforcementLearningCore ~/.julia/packages/ReinforcementLearningCore/BYdWk/src/core/hooks.jl:35

Possible fix, define
  push!(::Main.TimeCostPerEpisode, ::PreEpisodeStage, ::AbstractPolicy, ::AbstractEnv)

julia> h.time_costs
UInt64[]

Стандартные обработчики

Периодические задания

Иногда требуется выполнять некоторые функции периодически. Для этой задачи имеются два удобных обработчика:

Ниже представлены типичные способы их использования.

Вычисление политики во время обучения

julia> using Statistics: mean


julia> policy = RandomPolicy()
::RandomPolicy	
├─ action_space::Nothing => nothing	
└─ rng::TaskLocalRNG => TaskLocalRNG()	

julia> run(
           policy,
           CartPoleEnv(),
           StopAfterNEpisodes(100),
           DoEveryNEpisodes(;n=10) do t, policy, env
               # В реальных сценариях политика обычно инкапсулируется в агенте.
               # Нам необходимо извлечь внутреннюю политику для ее выполнения в режиме *исполнителя*.
               # Здесь для иллюстрации просто используется исходная политика.

               # Обратите внимание: чтобы исходная среда не засорялась, здесь создается
               # новый экземпляр CartPoleEnv.

               hook = TotalRewardPerEpisode(;is_display_on_exit=false)
               run(policy, CartPoleEnv(), StopAfterNEpisodes(10), hook)

               # Теперь можно вывести результат выполнения обработчика.
               println("avg reward at episode $t is: $(mean(hook.rewards))")
           end
       )
avg reward at episode 10 is: 21.6
avg reward at episode 20 is: 21.4
avg reward at episode 30 is: 28.5
avg reward at episode 40 is: 19.7
avg reward at episode 50 is: 17.5
avg reward at episode 60 is: 23.5
avg reward at episode 70 is: 20.2
avg reward at episode 80 is: 18.2
avg reward at episode 90 is: 19.4
avg reward at episode 100 is: 29.8
DoEveryNEpisodes{PostEpisodeStage, Main.var"#2#3"}(Main.var"#2#3"(), 10, 100)

Сохранение параметров

Для сохранения параметров политики рекомендуется использовать JLD2.jl.

julia> using ReinforcementLearning


julia> using JLD2
ERROR: ArgumentError: Package JLD2 not found in current path.
- Run `import Pkg; Pkg.add("JLD2")` to install the JLD2 package.

julia> env = RandomWalk1D()
# RandomWalk1D

## Traits

| Trait Type        |                Value |
|:----------------- | --------------------:|
| NumAgentStyle     |        SingleAgent() |
| DynamicStyle      |         Sequential() |
| InformationStyle  | PerfectInformation() |
| ChanceStyle       |      Deterministic() |
| RewardStyle       |     TerminalReward() |
| UtilityStyle      |         GeneralSum() |
| ActionStyle       |   MinimalActionSet() |
| StateStyle        | Observation{Int64}() |
| DefaultStateStyle | Observation{Int64}() |
| EpisodeStyle      |           Episodic() |

## Is Environment Terminated?

No

## State Space

`Base.OneTo(7)`

## Action Space

`Base.OneTo(2)`

## Current State

```
4
```

julia> ns, na = length(state_space(env)), length(action_space(env))
(7, 2)

julia> policy = Agent(
           QBasedPolicy(;
               learner = TDLearner(
                   TabularQApproximator(n_state = ns, n_action = na),
                   :SARS;
               ),
               explorer = EpsilonGreedyExplorer(ϵ_stable=0.01),
           ),
           Trajectory(
               CircularArraySARTSTraces(;
                   capacity = 1,
                   state = Int64 => (),
                   action = Int64 => (),
                   reward = Float64 => (),
                   terminal = Bool => (),
               ),
               DummySampler(),
               InsertSampleRatioController(),
           ),
       )
Agent{QBasedPolicy{TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}, EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}}, Trajectory{EpisodesBuffer{(:state, :next_state, :action, :reward, :terminal), Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, CircularArraySARTSTraces{Tuple{MultiplexTraces{(:state, :next_state), Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Int64}, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, 5, Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}}, DataStructures.CircularBuffer{Int64}, DataStructures.CircularBuffer{Bool}}, DummySampler, InsertSampleRatioController, typeof(identity)}}(QBasedPolicy{TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}, EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}}(TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}(TabularQApproximator{Matrix{Float64}}([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]), 1.0, 0.01, 0), EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}(0.01, 1.0, 0, 0, 1, TaskLocalRNG())), Trajectory{EpisodesBuffer{(:state, :next_state, :action, :reward, :terminal), Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, CircularArraySARTSTraces{Tuple{MultiplexTraces{(:state, :next_state), Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Int64}, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, 5, Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}}, DataStructures.CircularBuffer{Int64}, DataStructures.CircularBuffer{Bool}}, DummySampler, InsertSampleRatioController, typeof(identity)}(@NamedTuple{state::Int64, next_state::Int64, action::Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, reward::Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, terminal::Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}[], DummySampler(), InsertSampleRatioController(1.0, 1, 0, 0), identity))

julia> parameters_dir = mktempdir()
"/tmp/jl_mL3XYN"

julia> run(
           policy,
           env,
           StopAfterNSteps(10_000),
           DoEveryNSteps(n=1_000) do t, p, e
               ps = policy.policy.learner.approximator
               f = joinpath(parameters_dir, "parameters_at_step_$t.jld2")
               JLD2.@save f ps
               println("parameters at step $t saved to $f")
           end
       )
ERROR: LoadError: UndefVarError: `JLD2` not defined
in expression starting at REPL[7]:8