应用深度Q学习(dqn)以蛇游戏为例解决离散环境中的控制问题
在本文中,使用深度Q网络(Dqn)算法在Snake游戏中训练代理。 目的是演示如何使用强化学习方法来解决离散环境中的控制问题。
代理(蛇)必须学习如何收集食物,避免与墙壁和自己的身体碰撞。 环境的状态由一组特征描述,并且神经网络近似于评估每个可能动作的有用性的Q函数。
在学习过程中,使用重放记忆技术和目标网络来稳定收敛。 训练结束后,训练有素的代理进行测试,可视化的游戏和分析得到的结果。
代码块描述
我想从这一节开始讲一点背景。
蛇是一款经典的街机游戏,其实质是控制一条移动的线("蛇")在运动场上爬行。 主要目标是收集出现在球场上的食物,增加蛇的长度,同时避免与自己的尾巴和场地的边界碰撞。
游戏起源于20世纪70年代中期,这种类型的第一个已知游戏被认为是街机封锁,由Gremlin Interactive于1976年发布。 在接下来的几年中,各种平台出现了许多变体和克隆。
这款游戏在全世界广受欢迎。1997,当芬兰开发者Taneli Armanto创建了一个版本 Snake 对于手机诺基亚6110。 正是这个版本成为诺基亚的名片,并向全球数百万人介绍了这款游戏,为移动娱乐业奠定了基础。
现在让我们仔细看看我们算法的每个组件的实现。
using Random
using Statistics
using LinearAlgebra
比赛场地的结构和游戏的逻辑
mutable struct SnakeGame-存储游戏参数:字段大小,蛇段坐标,当前方向,食物位置,分数,游戏结束标志,没有食物的步数计数器以及食物的最后距离。
mutable struct SnakeGame
width::Int
height::Int
snake::Vector{Tuple{Int,Int}}
direction::Tuple{Int,Int}
food::Tuple{Int,Int}
score::Int
done::Bool
steps::Int
max_steps::Int
last_dist::Float64
end
构造函数 SnakeGame(width, height)-创建一个游戏,在该领域的中心三段的初始蛇,设置方向向右,生成第一餐和初始化服务领域。
function SnakeGame(width=25, height=25)
start_x = div(width, 2)
start_y = div(height, 2)
snake = [(start_x, start_y), (start_x-1, start_y), (start_x-2, start_y)]
direction = (1, 0)
game = SnakeGame(width, height, snake, direction, (0,0), 0, false, 0, width*height, 0.0)
spawn_food!(game)
game.last_dist = manhattan_dist(game.snake[1], game.food)
return game
end
manhattan_dist(a, b) = abs(a[1]-b[1]) + abs(a[2]-b[2])
spawn_food!-将食物放在一个随机的空笼子里(不被蛇占据)。 如果没有空闲单元格,游戏以胜利结束(代码中设置了标志)。 done).
function spawn_food!(game::SnakeGame)
empty_cells = [(x,y) for x in 2:game.width-1, y in 2:game.height-1
if !((x,y) in game.snake)]
if !isempty(empty_cells)
game.food = rand(empty_cells)
else
game.done = true
end
end
step!-模拟的主要步骤:
-接受代理的动作:0-前进,1-左转,2-右转(动作3保留为禁止)。
-计算新的头部位置。
-检查与墙壁或身体碰撞-在这种情况下,游戏结束,并返回-1罚款。
-如果蛇吃了食物:分数增加,步数计数器重置,产生新的食物,奖励返回+10。
-否则,尾巴被移除,并计算接近食物的奖励(dist_reward = ±0.1). 增加了一个小的研究奖金(0.05,10%概率)和-0.001的永久小罚金,以鼓励更快的解决方案。
-如果没有食物的步数超过最大值,游戏以-0.5的罚款结束。
function step!(game::SnakeGame, action::Int)
game.steps += 1
if action == 1
game.direction = (-game.direction[2], game.direction[1])
elseif action == 2
game.direction = (game.direction[2], -game.direction[1])
end
new_head = (game.snake[1][1] + game.direction[1],
game.snake[1][2] + game.direction[2])
if new_head[1] <= 1 || new_head[1] >= game.width ||
new_head[2] <= 1 || new_head[2] >= game.height ||
new_head in game.snake
game.done = true
return -1.0
end
pushfirst!(game.snake, new_head)
if new_head == game.food
game.score += 1
game.steps = 0
spawn_food!(game)
game.last_dist = manhattan_dist(new_head, game.food)
return 10.0
end
pop!(game.snake)
current_dist = manhattan_dist(new_head, game.food)
dist_reward = sign(game.last_dist - current_dist) * 0.1 # +0.1更近,-0.1更远
game.last_dist = current_dist
explore_bonus = rand() < 0.1 ? 0.05 : 0.0
if game.steps > game.max_steps
game.done = true
return -0.5
end
return dist_reward + explore_bonus - 0.001
end
get_state-生成条件符号的向量(8个数字):
-三个方向(向前,向左,向右)碰撞的危险-0或1。
-食物的方向相对于蛇的当前方向(转换为"前进/后退"和"右/左")。
-归一化蛇长度(length/50).
是当前步数与最大值的比值。
-到尾部的归一化距离。
function get_state(game::SnakeGame)
head = game.snake[1]
forward = game.direction
left = (-game.direction[2], game.direction[1])
right = (game.direction[2], -game.direction[1])
danger_f = is_collision(game, (head[1]+forward[1], head[2]+forward[2])) ? 1.0 : 0.0
danger_l = is_collision(game, (head[1]+left[1], head[2]+left[2])) ? 1.0 : 0.0
danger_r = is_collision(game, (head[1]+right[1], head[2]+right[2])) ? 1.0 : 0.0
dx = sign(game.food[1] - head[1])
dy = sign(game.food[2] - head[2])
if game.direction == (1, 0) # 向右
food_forward = dy == 0 ? 0.0 : (dy > 0 ? -1.0 : 1.0)
food_right = dx > 0 ? 1.0 : (dx < 0 ? -1.0 : 0.0)
elseif game.direction == (-1, 0) # 向左
food_forward = dy == 0 ? 0.0 : (dy > 0 ? -1.0 : 1.0)
food_right = dx < 0 ? 1.0 : (dx > 0 ? -1.0 : 0.0)
elseif game.direction == (0, -1) # 向上
food_forward = dx == 0 ? 0.0 : (dx > 0 ? 1.0 : -1.0)
food_right = dy < 0 ? 1.0 : (dy > 0 ? -1.0 : 0.0)
else # 向下
food_forward = dx == 0 ? 0.0 : (dx > 0 ? 1.0 : -1.0)
food_right = dy > 0 ? 1.0 : (dy < 0 ? -1.0 : 0.0)
end
tail_dist = length(game.snake) > 1 ? manhattan_dist(head, game.snake[end]) / (game.width+game.height) : 0.0
return Float64[
danger_f, danger_l, danger_r,
food_forward, food_right,
length(game.snake) / 50.0,
game.steps / game.max_steps, tail_dist]
end
is_collision-检查位置是否在场外或被蛇的身体占据。
function is_collision(game::SnakeGame, pos::Tuple{Int,Int})
return pos[1] <= 1 || pos[1] >= game.width ||
pos[2] <= 1 || pos[2] >= game.height ||
pos in game.snake
end
神经网络架构
我们的网络是作为一个多变的结构来实现的 mutable struct QNetwork,其中存储两个完全连接(密集)层的权重和偏移量。 选择这种架构是因为"Snake"任务不需要对原始像素进行分析-描述场上情况的紧凑特征向量就足够了。 两层足以近似离散环境中的Q函数。
网络结构:
输入层-接受8个特征的向量,该向量由函数形成 get_state (方向上的危险,食物相对于头部的位置,蛇的长度等。).
隐藏层-包含128个神经元具有ReLU激活功能。 这个数字被选择作为计算复杂性和网络记忆相当复杂的行为模式的能力之间的折衷。
输出层输出3个值*对应于每个可能动作的Q(s,a)估计值:前进,左转,右转。 值越高,则认为动作越有利于处于当前状态。
存储的参数:
W1-第一层大小的权重矩阵128×8(输入和隐藏层之间的链接)。b1-隐藏层的位移矢量(128个值)。W2-大小第二层的权重矩阵3×128(隐藏层和输出层之间的链接)。b2-输出层的位移矢量(3个值)。
权重使用Xavier方法初始化(系数为 sqrt(2/input)),这有助于避免在训练开始时渐变褪色或爆炸。 现在让我们看看它在代码中的外观。
mutable struct QNetwork
W1::Matrix{Float64}
b1::Vector{Float64}
W2::Matrix{Float64}
b2::Vector{Float64}
end
设计师 QNetwork(input, hidden, output) 使用Xavier方法初始化权重(系数为 sqrt(2/input)).
function QNetwork(input::Int, hidden::Int, output::Int)
W1 = randn(hidden, input) * sqrt(2.0/input)
b1 = zeros(hidden)
W2 = randn(output, hidden) * sqrt(2.0/hidden)
b2 = zeros(output)
return QNetwork(W1, b1, W2, b2)
end
relu(x) = max.(x, 0.0)
forward-执行前向传递:输入向量 s 它经过线性变换,ReLU,然后是没有非线性的第二层。 输出是三个动作的Q值。
function forward(q::QNetwork, s::Vector{Float64})
h = relu(q.W1 * s .+ q.b1)
return q.W2 * h .+ q.b2 # 3个动作的Q值
end
epsilon_greedy-根据e-greedy策略选择一个动作:具有概率 eps -随机动作,否则具有最大Q值的动作。
function epsilon_greedy(q::QNetwork, state::Vector{Float64}, eps::Float64)
if rand() < eps
return rand(0:2)
else
q_vals = forward(q, state)
return argmax(q_vals) - 1
end
end
DQN学习算法
深度Q网络(dqn)是一种强化学习算法,其中使用神经网络来近似效用函数Q(s,a)。 与表格式Q学习不同,DQN能够处理大而连续的状态空间。
算法的本质:神经网络接收环境的输入状态并输出所有可能动作的Q值。 行动的选择基于平衡勘探和开发的电子贪婪策略。
**train_dqn!**是主要的学习功能。 参数:集数(5000),学习率(0.001),折扣因子γ(0.95)。
-初始化:播放缓冲区 memory (尺寸10000),目标网络 target_q (主副本),初始ε=1.0,最小ε=0.01,衰减系数ε=0.995。
-每集:
-创建新游戏,获得初始状态。
-游戏还没有结束:
-基于电子贪婪策略选择操作。
-一个步骤完成,获得奖励,下一个状态和一个完成标志。
-经验存储在内存中。
-如果内存大小≥batch_size(32),则随机选择一个mini-batch,并针对每个转换:
-计算当前的Q值 q_current.
-从目标网络获得下一个状态的Q值 q_next.
-使用公式生成目标值:
target = r + γ * max(q_next) 如果状态不是终端,否则 target = r.
-梯度下降使用误差反向传播(梯度的手动计算)进行。
-过渡到下一个状态。
-在剧集结束时,分数记录在阵列中 scores,每日更新。
-每100集,目标网络更新(复制尺度)。
-每500集,显示最后500集的平均得分。
-该函数返回按情节得分的点数数组。
function train_dqn!(q::QNetwork, episodes=5000, lr=0.001, gamma=0.9)
scores = Int[]
memory = Tuple{Vector{Float64}, Int, Float64, Vector{Float64}, Bool}[]
batch_size = 32
memory_size = 10000
eps_start = 1.0
eps_end = 0.01
eps_decay = 0.995
target_q = QNetwork(8, 128, 3)
target_q.W1 .= q.W1; target_q.b1 .= q.b1
target_q.W2 .= q.W2; target_q.b2 .= q.b2
eps = eps_start
for episode in 1:episodes
game = SnakeGame()
state = get_state(game)
total_reward = 0.0
while !game.done
action = epsilon_greedy(q, state, eps)
reward = step!(game, action)
next_state = get_state(game)
done = game.done
push!(memory, (state, action, reward, next_state, done))
if length(memory) > memory_size
popfirst!(memory)
end
total_reward += reward
state = next_state
if length(memory) >= batch_size
batch = rand(memory, batch_size)
for (s, a, r, s_next, d) in batch
q_current = forward(q, s)
q_next = forward(target_q, s_next)
target = copy(q_current)
if d
target[a+1] = r
else
target[a+1] = r + gamma * maximum(q_next)
end
h = relu(q.W1 * s .+ q.b1)
q_val = q.W2 * h .+ q.b2
error = q_val - target
dW2 = error * h'
db2 = error
dh = q.W2' * error
dh[h .<= 0] .= 0
dW1 = dh * s'
db1 = dh
q.W1 .-= lr * dW1
q.b1 .-= lr * db1
q.W2 .-= lr * dW2
q.b2 .-= lr * db2
end
end
end
push!(scores, game.score)
eps = max(eps_end, eps * eps_decay)
if episode % 100 == 0
target_q.W1 .= q.W1; target_q.b1 .= q.b1
target_q.W2 .= q.W2; target_q.b2 .= q.b2
end
if episode % 500 == 0
recent = scores[max(1, end-499):end]
println("Episode $episode, Mean: $(round(mean(recent), digits=2)), Max: $(maximum(recent)), Eps: $(round(eps, digits=3))")
end
end
return scores
end
游戏的可视化
draw_game-创建矩阵 grid,其中每个像素对应于用于显示的颜色。
function draw_game(game::SnakeGame)
grid = fill(0.5, game.height, game.width)
grid[1, :] .= 0.0
grid[end, :] .= 0.0
grid[:, 1] .= 0.0
grid[:, end] .= 0.0
for (i, (x, y)) in enumerate(game.snake)
if i == 1
grid[y, x] = 1.0
else
grid[y, x] = 0.8 - (i/length(game.snake))*0.3
end
end
grid[game.food[2], game.food[1]] = 0.2
return grid
end
play_and_record-在训练有素的网络(ε=0)的控制下运行游戏,并将动画保存为GIF。 在每个步骤中,使用以下方法构建热图 heatmap 从情节。 收集帧,然后以指定的帧速率创建动画GIF。 函数返回总分。
function play_and_record(q::QNetwork, filename="snake_dqn.gif"; fps=10)
game = SnakeGame()
frames = []
while !game.done && length(frames) < 500
grid = draw_game(game)
p = heatmap(grid,
color=:viridis,
aspect_ratio=:equal,
axis=false,
colorbar=false,
title="Score: $(game.score) | Steps: $(game.steps)",
size=(600, 600),
clim=(0, 1))
push!(frames, p)
state = get_state(game)
action = epsilon_greedy(q, state, 0.0)
step!(game, action)
end
grid = draw_game(game)
p = heatmap(grid,
color=:viridis,
aspect_ratio=:equal,
axis=false,
colorbar=false,
title="FINAL Score: $(game.score)",
size=(600, 600),
clim=(0, 1))
push!(frames, p)
anim = @animate for f in frames
plot(f)
end
gif(anim, filename, fps=fps)
println("Saved: $filename | Score: $(game.score) | Frames: $(length(frames))")
return game.score
end
启动和分析结果
最后:
-创建一个Q-network实例,输入大小为8,隐藏层为128,输出为3。
-培训需要5000集。
-学习后,一个测试游戏开始保存Gif(snake_dqn.gif).
-显示最终统计数据:最终得分,最高学费账单,最后1000集的平均得分。
-绘制按情节及其平滑(窗口为100的移动平均线)曲线的分数变化图,并保存到文件中 training.png.
println("Initializing DQN...")
@time q_net = QNetwork(8, 128, 3)
println("Training...")
@time scores = train_dqn!(q_net, 5000, 0.001, 0.95)
println("\nTesting...")
@time final_score = play_and_record(q_net, "snake_dqn.gif")
println("\nResults: Final=$final_score, MaxTrain=$(maximum(scores)), MeanLast1000=$(mean(scores[max(1,end-999):end]))")
p = plot(scores, xlabel="Episode", ylabel="Score", title="DQN Training", legend=false, alpha=0.3)
function moving_average(v, n)
return [mean(v[max(1,i-n+1):i]) for i in 1:length(v)]
end
plot!(moving_average(scores, 100), linewidth=2, color=:red)
savefig(p, "training.png")
display(p)
在训练期间,平均得分有所增加:
-到第500集,平均得分为~1.48,最高得分为9。
-到第1000集,平均得分增加到~9.55,最高得分为41。
-此外,平均得分稳定在13-15左右,定期达到41的高点(甚至在训练期间达到43)。
最终测试显示得分26 轨迹长度为501步,这表明代理已经学会了如何收集食物并避免碰撞,尽管它还没有达到理论最大值(所有625个细胞)。 最高训练分数43 它证实了代理人有时玩很长时间游戏的能力,并且过去1,000集(13.64)的平均得分表明一贯良好,但不是出色的游戏水平。
结论
为蛇游戏开发的DQN实现成功地训练了一个能够收集食物和避开障碍物的代理人。 使用重放记忆和目标网络使我们能够稳定训练。 对结果的分析表明,代理的平均得分从零增加到每5,000集~15分,最高达到48。 GIF上的可视化展示了训练有素的蛇的智能行为。 所提出的体系结构可以改进,例如,通过增加网络的大小,添加用于处理视觉状态的卷积层,或者使用更高级的方法(Double DQN,Dueling DQN)。 总的来说,该项目清楚地说明了深度学习方法在离散控制任务中的应用。