Использование перехватчиков
Что такое перехватчики
Во время взаимодействия между агентами и средами часто необходимо собирать полезную информацию. Прямолинейным подходом является использование императивного программирования. Мы пишем код, который будет выполняться пошагово внутри цикла.
while true
action = plan!(policy, env)
act!(env, action)
# прописываем здесь логику,
# например сохранение параметров, регистрацию функции потерь, вычисление политики и т. д.
check!(stop_condition, env, policy) && break
is_terminated(env) && reset!(env)
end
Преимущество такого подхода — его полная ясность. Вы отвечаете за все выполняемые действия. И такой подход рекомендуется для новых пользователей, которые хотят опробовать различные компоненты пакета.
Другой подход — декларативное программирование. Вы описываете, что должно делаться во время эксперимента и когда. Затем вы совмещаете эту информацию с агентом и средой. Наконец, вы выполняете команду run
, чтобы провести эксперимент. Это позволяет повторно использовать стандартные перехватчики и конвейеры выполнения, а не писать один и тот же код много раз. Во многих существующих пакетах Python для обучения с подкреплением для определения конвейера выполнения обычно применяется набор файлов конфигурации. Однако, на наш взгляд, в Julia это лишнее. Подход на основе декларативного программирования дает гораздо большую гибкость.
Следующий вопрос в том, как построить обработчик. Естественным путем представляется заключение прокомментированной части из приведенного выше псевдокода в функцию:
while true
action = plan!(policy, env)
act!(env, action)
push!(hook, policy, env)
check!(stop_condition, env, policy) && break
is_terminated(env) && reset!(env)
end
Но иногда требуется более детальный контроль. Поэтому вызов обработчиков разделяется на несколько отдельных этапов:
Определение настраиваемого обработчика
По умолчанию при вызове push!(hook::AbstractHook, ::AbstractStage, policy, env)
экземпляр AbstractHook
ничего не делает. Поэтому при написании настраиваемого обработчика достаточно реализовать необходимую логику времени выполнения.
Например, предположим, что нужно регистрировать фактическое время каждого эпизода.
julia> using ReinforcementLearning
julia> import Base.push!
julia> Base.@kwdef mutable struct TimeCostPerEpisode <: AbstractHook
t::UInt64 = time_ns()
time_costs::Vector{UInt64} = []
end
Main.TimeCostPerEpisode
julia> Base.push!(h::TimeCostPerEpisode, ::PreEpisodeStage, policy, env) = h.t = time_ns()
julia> Base.push!(h::TimeCostPerEpisode, ::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t)
julia> h = TimeCostPerEpisode()
Main.TimeCostPerEpisode(0x00000133269cf1fa, UInt64[])
julia> run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h)
ERROR: MethodError: push!(::Main.TimeCostPerEpisode, ::PreEpisodeStage, ::RandomPolicy{Nothing, TaskLocalRNG}, ::CartPoleEnv{Float64, Int64}) is ambiguous.
Candidates:
push!(h::Main.TimeCostPerEpisode, ::PreEpisodeStage, policy, env)
@ Main REPL[4]:1
push!(::AbstractHook, ::AbstractStage, ::AbstractPolicy, ::AbstractEnv)
@ ReinforcementLearningCore ~/.julia/packages/ReinforcementLearningCore/BYdWk/src/core/hooks.jl:35
Possible fix, define
push!(::Main.TimeCostPerEpisode, ::PreEpisodeStage, ::AbstractPolicy, ::AbstractEnv)
julia> h.time_costs
UInt64[]
Периодические задания
Иногда требуется выполнять некоторые функции периодически. Для этой задачи имеются два удобных обработчика:
Ниже представлены типичные способы их использования.
Вычисление политики во время обучения
julia> using Statistics: mean
julia> policy = RandomPolicy()
[33m[39m::[34mRandomPolicy[39m [90m[39m
├─ [33maction_space[39m::[34mNothing[39m[35m => [39m[32mnothing[39m [90m[39m
└─ [33mrng[39m::[34mTaskLocalRNG[39m[35m => [39m[32mTaskLocalRNG()[39m [90m[39m
julia> run(
policy,
CartPoleEnv(),
StopAfterNEpisodes(100),
DoEveryNEpisodes(;n=10) do t, policy, env
# В реальных сценариях политика обычно инкапсулируется в агенте.
# Нам необходимо извлечь внутреннюю политику для ее выполнения в режиме *исполнителя*.
# Здесь для иллюстрации просто используется исходная политика.
# Обратите внимание: чтобы исходная среда не засорялась, здесь создается
# новый экземпляр CartPoleEnv.
hook = TotalRewardPerEpisode(;is_display_on_exit=false)
run(policy, CartPoleEnv(), StopAfterNEpisodes(10), hook)
# Теперь можно вывести результат выполнения обработчика.
println("avg reward at episode $t is: $(mean(hook.rewards))")
end
)
avg reward at episode 10 is: 21.6
avg reward at episode 20 is: 21.4
avg reward at episode 30 is: 28.5
avg reward at episode 40 is: 19.7
avg reward at episode 50 is: 17.5
avg reward at episode 60 is: 23.5
avg reward at episode 70 is: 20.2
avg reward at episode 80 is: 18.2
avg reward at episode 90 is: 19.4
avg reward at episode 100 is: 29.8
DoEveryNEpisodes{PostEpisodeStage, Main.var"#2#3"}(Main.var"#2#3"(), 10, 100)
Сохранение параметров
Для сохранения параметров политики рекомендуется использовать JLD2.jl.
julia> using ReinforcementLearning
julia> using JLD2
ERROR: ArgumentError: Package JLD2 not found in current path.
- Run `import Pkg; Pkg.add("JLD2")` to install the JLD2 package.
julia> env = RandomWalk1D()
# RandomWalk1D
## Traits
| Trait Type | Value |
|:----------------- | --------------------:|
| NumAgentStyle | SingleAgent() |
| DynamicStyle | Sequential() |
| InformationStyle | PerfectInformation() |
| ChanceStyle | Deterministic() |
| RewardStyle | TerminalReward() |
| UtilityStyle | GeneralSum() |
| ActionStyle | MinimalActionSet() |
| StateStyle | Observation{Int64}() |
| DefaultStateStyle | Observation{Int64}() |
| EpisodeStyle | Episodic() |
## Is Environment Terminated?
No
## State Space
`Base.OneTo(7)`
## Action Space
`Base.OneTo(2)`
## Current State
```
4
```
julia> ns, na = length(state_space(env)), length(action_space(env))
(7, 2)
julia> policy = Agent(
QBasedPolicy(;
learner = TDLearner(
TabularQApproximator(n_state = ns, n_action = na),
:SARS;
),
explorer = EpsilonGreedyExplorer(ϵ_stable=0.01),
),
Trajectory(
CircularArraySARTSTraces(;
capacity = 1,
state = Int64 => (),
action = Int64 => (),
reward = Float64 => (),
terminal = Bool => (),
),
DummySampler(),
InsertSampleRatioController(),
),
)
Agent{QBasedPolicy{TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}, EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}}, Trajectory{EpisodesBuffer{(:state, :next_state, :action, :reward, :terminal), Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, CircularArraySARTSTraces{Tuple{MultiplexTraces{(:state, :next_state), Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Int64}, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, 5, Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}}, DataStructures.CircularBuffer{Int64}, DataStructures.CircularBuffer{Bool}}, DummySampler, InsertSampleRatioController, typeof(identity)}}(QBasedPolicy{TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}, EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}}(TDLearner{:SARS, TabularQApproximator{Matrix{Float64}}}(TabularQApproximator{Matrix{Float64}}([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]), 1.0, 0.01, 0), EpsilonGreedyExplorer{:linear, false, TaskLocalRNG}(0.01, 1.0, 0, 0, 1, TaskLocalRNG())), Trajectory{EpisodesBuffer{(:state, :next_state, :action, :reward, :terminal), Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, CircularArraySARTSTraces{Tuple{MultiplexTraces{(:state, :next_state), Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Int64}, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}, 5, Tuple{Int64, Int64, Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}}, DataStructures.CircularBuffer{Int64}, DataStructures.CircularBuffer{Bool}}, DummySampler, InsertSampleRatioController, typeof(identity)}(@NamedTuple{state::Int64, next_state::Int64, action::Trace{CircularVectorBuffer{Int64, Vector{Int64}}, SubArray{Int64, 0, CircularVectorBuffer{Int64, Vector{Int64}}, Tuple{Int64}, true}}, reward::Trace{CircularVectorBuffer{Float64, Vector{Float64}}, SubArray{Float64, 0, CircularVectorBuffer{Float64, Vector{Float64}}, Tuple{Int64}, true}}, terminal::Trace{CircularVectorBuffer{Bool, Vector{Bool}}, SubArray{Bool, 0, CircularVectorBuffer{Bool, Vector{Bool}}, Tuple{Int64}, true}}}[], DummySampler(), InsertSampleRatioController(1.0, 1, 0, 0), identity))
julia> parameters_dir = mktempdir()
"/tmp/jl_mL3XYN"
julia> run(
policy,
env,
StopAfterNSteps(10_000),
DoEveryNSteps(n=1_000) do t, p, e
ps = policy.policy.learner.approximator
f = joinpath(parameters_dir, "parameters_at_step_$t.jld2")
JLD2.@save f ps
println("parameters at step $t saved to $f")
end
)
ERROR: LoadError: UndefVarError: `JLD2` not defined
in expression starting at REPL[7]:8