Сообщество Engee

Змейка

作者
avatar-yurevyurev
Notebook

Применение глубокого Q-обучения (DQN) для решения задачи управления в дискретной среде на примере игры "Змейка"

В данной работе реализовано обучение агента игре «Змейка» с использованием алгоритма глубокого Q-обучения (Deep Q-Network, DQN). Целью является демонстрация применения методов обучения с подкреплением для решения задачи управления в дискретной среде.

Агент (змейка) должен научиться собирать еду, избегая столкновений со стенами и собственным телом. Состояние среды описывается набором признаков, а нейронная сеть аппроксимирует Q-функцию, оценивающую полезность каждого возможного действия.

В процессе обучения используется техника воспроизведения опыта (replay memory) и целевая сеть (target network) для стабилизации сходимости. После обучения проводится тестирование обученного агента с визуализацией игрового процесса и анализом полученных результатов.

мини.png

Описание кода по блокам

Этот раздел я бы хотел начать с небольшой предистории.

Змейка — это классическая аркадная игра, суть которой заключается в управлении движущейся линией («змеёй»), ползающей по игровому полю. Основная цель — собирать появляющуюся на поле еду, увеличивая длину змеи, и при этом избегать столкновений с собственным хвостом и границами поля.

Игра зародилась в середине 1970-х годов. Первой известной игрой этого жанра считается аркадный автомат Blockade, выпущенный компанией Gremlin Interactive в 1976 году . В последующие годы появилось множество вариаций и клонов для различных платформ.

Всемирную популярность игра обрела в 1997 году, когда финский разработчик Танели Арманто создал версию Snake для мобильного телефона Nokia 6110 . Именно эта версия стала визитной карточкой Nokia и познакомила с игрой миллионы людей по всему миру, заложив основы индустрии мобильных развлечений.

Теперь подробно разберем реализацию каждого компонента нашего алгоритма.

In [ ]:
using Random
using Statistics
using LinearAlgebra

Структура игрового поля и логика игры

mutable struct SnakeGame – хранит параметры игры: размеры поля, координаты сегментов змейки, текущее направление, положение еды, счёт, флаг окончания игры, счётчик шагов без еды и последнее расстояние до еды.

In [ ]:
mutable struct SnakeGame
    width::Int
    height::Int
    snake::Vector{Tuple{Int,Int}}
    direction::Tuple{Int,Int}
    food::Tuple{Int,Int}
    score::Int
    done::Bool
    steps::Int
    max_steps::Int
    last_dist::Float64
end

Конструктор SnakeGame(width, height) – создаёт игру с начальной змейкой из трёх сегментов в центре поля, устанавливает направление вправо, генерирует первую еду и инициализирует служебные поля.

In [ ]:
function SnakeGame(width=25, height=25)
    start_x = div(width, 2)
    start_y = div(height, 2)
    snake = [(start_x, start_y), (start_x-1, start_y), (start_x-2, start_y)]
    direction = (1, 0)
    
    game = SnakeGame(width, height, snake, direction, (0,0), 0, false, 0, width*height, 0.0)
    spawn_food!(game)
    game.last_dist = manhattan_dist(game.snake[1], game.food)
    return game
end
Out[0]:
SnakeGame
In [ ]:
manhattan_dist(a, b) = abs(a[1]-b[1]) + abs(a[2]-b[2])
Out[0]:
manhattan_dist (generic function with 1 method)

spawn_food! – размещает еду в случайной свободной клетке (не занятой змейкой). Если свободных клеток нет, игра завершается победой (в коде устанавливается флаг done).

In [ ]:
function spawn_food!(game::SnakeGame)
    empty_cells = [(x,y) for x in 2:game.width-1, y in 2:game.height-1 
                   if !((x,y) in game.snake)]
    if !isempty(empty_cells)
        game.food = rand(empty_cells)
    else
        game.done = true
    end
end
Out[0]:
spawn_food! (generic function with 1 method)

step! – основной шаг симуляции:

  • Принимает действие агента: 0 – вперёд, 1 – поворот налево, 2 – направо (действие 3 зарезервировано как запрещённое).
  • Вычисляет новую позицию головы.
  • Проверяет столкновение со стеной или телом – в этом случае игра завершается, возвращается штраф –1.
  • Если змейка съедает еду: увеличивается счёт, сбрасывается счётчик шагов, генерируется новая еда, возвращается награда +10.
  • В противном случае удаляется хвост, и вычисляется награда за приближение к еде (dist_reward = ±0.1). Добавляется небольшой бонус за исследование (0.05 с вероятностью 10%) и постоянный малый штраф –0.001 для стимулирования более быстрых решений.
  • Если число шагов без еды превышает максимум, игра завершается со штрафом –0.5.
In [ ]:
function step!(game::SnakeGame, action::Int)
    game.steps += 1
    if action == 1
        game.direction = (-game.direction[2], game.direction[1])
    elseif action == 2
        game.direction = (game.direction[2], -game.direction[1])
    end
    new_head = (game.snake[1][1] + game.direction[1], 
                game.snake[1][2] + game.direction[2])
    
    if new_head[1] <= 1 || new_head[1] >= game.width || 
       new_head[2] <= 1 || new_head[2] >= game.height ||
       new_head in game.snake
        game.done = true
        return -1.0
    end
    pushfirst!(game.snake, new_head)
    if new_head == game.food
        game.score += 1
        game.steps = 0
        spawn_food!(game)
        game.last_dist = manhattan_dist(new_head, game.food)
        return 10.0
    end
    pop!(game.snake)
    
    current_dist = manhattan_dist(new_head, game.food)
    dist_reward = sign(game.last_dist - current_dist) * 0.1  # +0.1 ближе, -0.1 дальше
    game.last_dist = current_dist
    explore_bonus = rand() < 0.1 ? 0.05 : 0.0
    if game.steps > game.max_steps
        game.done = true
        return -0.5
    end
    return dist_reward + explore_bonus - 0.001
end
Out[0]:
step! (generic function with 302 methods)

get_state – формирует вектор признаков состояния (8 чисел):

  • Опасность столкновения в трёх направлениях (вперёд, налево, направо) – 0 или 1.
  • Направление на еду относительно текущего направления змейки (преобразовано во «вперёд/назад» и «вправо/влево»).
  • Нормализованная длина змейки (length/50).
  • Отношение текущего числа шагов к максимальному.
  • Нормализованное расстояние до хвоста.
In [ ]:
function get_state(game::SnakeGame)
    head = game.snake[1]
    forward = game.direction
    left = (-game.direction[2], game.direction[1])
    right = (game.direction[2], -game.direction[1])
    danger_f = is_collision(game, (head[1]+forward[1], head[2]+forward[2])) ? 1.0 : 0.0
    danger_l = is_collision(game, (head[1]+left[1], head[2]+left[2])) ? 1.0 : 0.0
    danger_r = is_collision(game, (head[1]+right[1], head[2]+right[2])) ? 1.0 : 0.0
    dx = sign(game.food[1] - head[1])
    dy = sign(game.food[2] - head[2])
    if game.direction == (1, 0)   # вправо
        food_forward = dy == 0 ? 0.0 : (dy > 0 ? -1.0 : 1.0)
        food_right = dx > 0 ? 1.0 : (dx < 0 ? -1.0 : 0.0)
    elseif game.direction == (-1, 0)  # влево
        food_forward = dy == 0 ? 0.0 : (dy > 0 ? -1.0 : 1.0)
        food_right = dx < 0 ? 1.0 : (dx > 0 ? -1.0 : 0.0)
    elseif game.direction == (0, -1)  # вверх
        food_forward = dx == 0 ? 0.0 : (dx > 0 ? 1.0 : -1.0)
        food_right = dy < 0 ? 1.0 : (dy > 0 ? -1.0 : 0.0)
    else  # вниз
        food_forward = dx == 0 ? 0.0 : (dx > 0 ? 1.0 : -1.0)
        food_right = dy > 0 ? 1.0 : (dy < 0 ? -1.0 : 0.0)
    end
    tail_dist = length(game.snake) > 1 ? manhattan_dist(head, game.snake[end]) / (game.width+game.height) : 0.0
    return Float64[
        danger_f, danger_l, danger_r,
        food_forward, food_right,
        length(game.snake) / 50.0,
        game.steps / game.max_steps, tail_dist]
end
Out[0]:
get_state (generic function with 1 method)

is_collision – проверяет, находится ли позиция за пределами поля или занята телом змейки.

In [ ]:
function is_collision(game::SnakeGame, pos::Tuple{Int,Int})
    return pos[1] <= 1 || pos[1] >= game.width || 
           pos[2] <= 1 || pos[2] >= game.height ||
           pos in game.snake
end
Out[0]:
is_collision (generic function with 1 method)

Архитектура нейронной сети

Наша сеть реализована в виде изменяемой структуры mutable struct QNetwork, которая хранит веса и смещения двух полносвязных (Dense) слоёв. Такая архитектура выбрана потому, что для задачи «Змейка» не нужен анализ сырых пикселей — достаточно компактного вектора признаков, описывающего ситуацию на поле. Двух слоёв вполне хватает для аппроксимации Q-функции в нашей дискретной среде.

Структура сети:

  • Входной слой — принимает вектор из 8 признаков, которые формирует функция get_state (опасность по направлениям, положение еды относительно головы, длина змейки и т.д.).
  • Скрытый слой — содержит 128 нейронов с функцией активации ReLU. Это число выбрано как компромисс между вычислительной сложностью и способностью сети запоминать достаточно сложные паттерны поведения.
  • Выходной слой — выдает 3 значения, соответствующие оценкам Q(s,a) для каждого возможного действия: вперёд, поворот налево, поворот направо. Чем выше значение, тем более выгодным считается действие в текущем состоянии.

Хранимые параметры:

  • W1 — матрица весов первого слоя размером 128×8 (связи между входом и скрытым слоем).
  • b1 — вектор смещений для скрытого слоя (128 значений).
  • W2 — матрица весов второго слоя размером 3×128 (связи между скрытым и выходным слоем).
  • b2 — вектор смещений для выходного слоя (3 значения).

Веса инициализируются по методу Xavier (с коэффициентом sqrt(2/input)), что помогает избежать затухания или взрыва градиентов в начале обучения. Теперь давайте посмотрим, как это выглядит в коде.

In [ ]:
mutable struct QNetwork
    W1::Matrix{Float64}
    b1::Vector{Float64}
    W2::Matrix{Float64}
    b2::Vector{Float64}
end

Конструктор QNetwork(input, hidden, output) инициализирует веса по методу Xavier (с коэффициентом sqrt(2/input)).

In [ ]:
function QNetwork(input::Int, hidden::Int, output::Int)
    W1 = randn(hidden, input) * sqrt(2.0/input)
    b1 = zeros(hidden)
    W2 = randn(output, hidden) * sqrt(2.0/hidden)
    b2 = zeros(output)
    return QNetwork(W1, b1, W2, b2)
end
relu(x) = max.(x, 0.0)
Out[0]:
relu (generic function with 1 method)

forward – выполняет прямой проход: входной вектор s проходит через линейное преобразование, ReLU, затем второй слой без нелинейности. На выходе – Q-значения для трёх действий.

In [ ]:
function forward(q::QNetwork, s::Vector{Float64})
    h = relu(q.W1 * s .+ q.b1)
    return q.W2 * h .+ q.b2  # Q-values для 3 действий
end
Out[0]:
forward (generic function with 1 method)

epsilon_greedy – выбирает действие согласно ε-жадной стратегии: с вероятностью eps – случайное действие, иначе действие с максимальным Q-значением.

In [ ]:
function epsilon_greedy(q::QNetwork, state::Vector{Float64}, eps::Float64)
    if rand() < eps
        return rand(0:2)
    else
        q_vals = forward(q, state)
        return argmax(q_vals) - 1
    end
end
Out[0]:
epsilon_greedy (generic function with 1 method)

Алгоритм обучения DQN

Deep Q-Network (DQN) — алгоритм обучения с подкреплением, в котором нейронная сеть используется для аппроксимации функции полезности Q(s, a). В отличие от табличного Q-обучения, DQN способен работать с большими и непрерывными пространствами состояний.

Суть алгоритма: нейросеть получает на вход состояние среды и выдаёт Q-значения для всех возможных действий. Выбор действия осуществляется по ε-жадной стратегии, балансирующей исследование и эксплуатацию.

train_dqn! – основная функция обучения. Параметры: количество эпизодов (5000), скорость обучения (0.001), коэффициент дисконтирования γ (0.95).

  • Инициализируются: буфер воспроизведения memory (размер 10000), целевая сеть target_q (копия основной), начальное ε = 1.0, минимальное ε = 0.01, коэффициент затухания ε = 0.995.
  • Для каждого эпизода:
    • Создаётся новая игра, получается начальное состояние.
    • Пока игра не завершена:
      • Выбирается действие по ε-жадной политике.
      • Выполняется шаг, получается награда, следующее состояние и флаг завершения.
      • Опыт сохраняется в память.
      • Если размер памяти ≥ batch_size (32), случайно выбирается мини-батч и для каждого перехода:
        • Вычисляются текущие Q-значения q_current.
        • Получаются Q-значения следующего состояния от целевой сети q_next.
        • Формируется целевое значение по формуле:
          target = r + γ * max(q_next), если состояние не терминальное, иначе target = r.
        • Выполняется градиентный спуск с помощью обратного распространения ошибки (ручное вычисление градиентов).
      • Переход к следующему состоянию.
    • По окончании эпизода записывается счёт в массив scores, обновляется ε.
    • Каждые 100 эпизодов обновляется целевая сеть (копирование весов).
    • Каждые 500 эпизодов выводится средний счёт за последние 500 эпизодов.
  • Функция возвращает массив набранных очков по эпизодам.
In [ ]:
function train_dqn!(q::QNetwork, episodes=5000, lr=0.001, gamma=0.9)
    scores = Int[]
    memory = Tuple{Vector{Float64}, Int, Float64, Vector{Float64}, Bool}[]
    batch_size = 32
    memory_size = 10000
    eps_start = 1.0
    eps_end = 0.01
    eps_decay = 0.995
    target_q = QNetwork(8, 128, 3)
    target_q.W1 .= q.W1; target_q.b1 .= q.b1
    target_q.W2 .= q.W2; target_q.b2 .= q.b2
    eps = eps_start
    
    for episode in 1:episodes
        game = SnakeGame()
        state = get_state(game)
        total_reward = 0.0
        while !game.done
            action = epsilon_greedy(q, state, eps)
            reward = step!(game, action)
            next_state = get_state(game)
            done = game.done
            push!(memory, (state, action, reward, next_state, done))
            if length(memory) > memory_size
                popfirst!(memory)
            end
            total_reward += reward
            state = next_state
            if length(memory) >= batch_size
                batch = rand(memory, batch_size)
                for (s, a, r, s_next, d) in batch
                    q_current = forward(q, s)
                    q_next = forward(target_q, s_next)
                    
                    target = copy(q_current)
                    if d
                        target[a+1] = r
                    else
                        target[a+1] = r + gamma * maximum(q_next)
                    end
                    h = relu(q.W1 * s .+ q.b1)
                    q_val = q.W2 * h .+ q.b2
                    error = q_val - target
                    dW2 = error * h'
                    db2 = error
                    dh = q.W2' * error
                    dh[h .<= 0] .= 0
                    dW1 = dh * s'
                    db1 = dh
                    q.W1 .-= lr * dW1
                    q.b1 .-= lr * db1
                    q.W2 .-= lr * dW2
                    q.b2 .-= lr * db2
                end
            end
        end
        push!(scores, game.score)
        eps = max(eps_end, eps * eps_decay)
        if episode % 100 == 0
            target_q.W1 .= q.W1; target_q.b1 .= q.b1
            target_q.W2 .= q.W2; target_q.b2 .= q.b2
        end
        if episode % 500 == 0
            recent = scores[max(1, end-499):end]
            println("Episode $episode, Mean: $(round(mean(recent), digits=2)), Max: $(maximum(recent)), Eps: $(round(eps, digits=3))")
        end
    end
    return scores
end
Out[0]:
train_dqn! (generic function with 4 methods)

Визуализация игры

draw_game – создаёт матрицу grid, где каждому пикселю соответствует цвет для отображения.

In [ ]:
function draw_game(game::SnakeGame)
    grid = fill(0.5, game.height, game.width)
    grid[1, :] .= 0.0
    grid[end, :] .= 0.0
    grid[:, 1] .= 0.0
    grid[:, end] .= 0.0
    for (i, (x, y)) in enumerate(game.snake)
        if i == 1
            grid[y, x] = 1.0
        else
            grid[y, x] = 0.8 - (i/length(game.snake))*0.3
        end
    end
    grid[game.food[2], game.food[1]] = 0.2
    return grid
end
Out[0]:
draw_game (generic function with 1 method)

play_and_record – запускает игру под управлением обученной сети (ε = 0) и сохраняет анимацию в GIF. На каждом шаге строится тепловая карта с помощью heatmap из Plots. Кадры собираются, затем создаётся анимированный GIF с указанной частотой кадров. Функция возвращает итоговый счёт.

In [ ]:
function play_and_record(q::QNetwork, filename="snake_dqn.gif"; fps=10)
    game = SnakeGame()
    frames = []
    while !game.done && length(frames) < 500
        grid = draw_game(game)
        p = heatmap(grid, 
                   color=:viridis,
                   aspect_ratio=:equal,
                   axis=false,
                   colorbar=false,
                   title="Score: $(game.score) | Steps: $(game.steps)",
                   size=(600, 600),
                   clim=(0, 1))
        push!(frames, p)
        state = get_state(game)
        action = epsilon_greedy(q, state, 0.0)
        step!(game, action)
    end
    grid = draw_game(game)
    p = heatmap(grid, 
               color=:viridis,
               aspect_ratio=:equal,
               axis=false,
               colorbar=false,
               title="FINAL Score: $(game.score)",
               size=(600, 600),
               clim=(0, 1))
    push!(frames, p)
    anim = @animate for f in frames
        plot(f)
    end
    gif(anim, filename, fps=fps)
    println("Saved: $filename | Score: $(game.score) | Frames: $(length(frames))")
    return game.score
end
Out[0]:
play_and_record (generic function with 2 methods)

Запуск и анализ результатов

В завершение:

  • Создаётся экземпляр Q-сети с входным размером 8, скрытым слоем 128 и выходом 3.
  • Вызывается обучение на 5000 эпизодах.
  • После обучения запускается тестовая игра с сохранением GIF (snake_dqn.gif).
  • Выводится итоговая статистика: финальный счёт, максимальный счёт за обучение, средний счёт за последние 1000 эпизодов.
  • Строится график изменения счёта по эпизодам и его сглаженная (скользящее среднее с окном 100) кривая, сохраняется в файл training.png.
In [ ]:
println("Initializing DQN...")
@time q_net = QNetwork(8, 128, 3)

println("Training...")
@time scores = train_dqn!(q_net, 5000, 0.001, 0.95)

println("\nTesting...")
@time final_score = play_and_record(q_net, "snake_dqn.gif")

println("\nResults: Final=$final_score, MaxTrain=$(maximum(scores)), MeanLast1000=$(mean(scores[max(1,end-999):end]))")
Initializing DQN...
Training...
Episode 500, Mean: 1.48, Max: 9, Eps: 0.082
Episode 1000, Mean: 9.55, Max: 41, Eps: 0.01
Episode 2500, Mean: 13.3, Max: 40, Eps: 0.01
Episode 3000, Mean: 13.86, Max: 35, Eps: 0.01
Episode 3500, Mean: 14.31, Max: 41, Eps: 0.01
Episode 4000, Mean: 13.92, Max: 35, Eps: 0.01
Episode 4500, Mean: 13.81, Max: 32, Eps: 0.01
Episode 5000, Mean: 13.46, Max: 35, Eps: 0.01

Testing...
Saved: snake_dqn.gif | Score: 26 | Frames: 501

Results: Final=26, MaxTrain=43, MeanLast1000=13.636
In [ ]:
p = plot(scores, xlabel="Episode", ylabel="Score", title="DQN Training", legend=false, alpha=0.3)
function moving_average(v, n)
    return [mean(v[max(1,i-n+1):i]) for i in 1:length(v)]
end
plot!(moving_average(scores, 100), linewidth=2, color=:red)
savefig(p, "training.png")
display(p)

В ходе обучения наблюдался рост среднего счёта:

  • К 500 эпизоду средний счёт составил ~1.48, максимальный – 9.
  • К 1000 эпизоду средний счёт увеличился до ~9.55, максимальный – 41.
  • Далее средний счёт стабилизировался в районе 13–15, периодически достигая максимумов до 41 (и даже 43 на протяжении тренировки).
    Финальное тестирование показало счёт 26 при длине траектории 501 шаг, что свидетельствует о том, что агент научился собирать еду и избегать столкновений, хотя и не достиг теоретического максимума (все 625 клеток). Максимальный тренировочный счёт 43 подтверждает способность агента иногда демонстрировать очень длинные игры, а средний результат за последние 1000 эпизодов (13.64) говорит о стабильно хорошем, но не выдающемся уровне игры.
snake_dqn.gif

Вывод

Разработанная реализация DQN для игры «Змейка» успешно обучила агента, способного собирать еду и избегать препятствий. Использование replay memory и target network позволило стабилизировать обучение. Анализ результатов показывает, что средний счёт агента вырос с нулевых значений до ~15 очков за 5000 эпизодов, а максимальный достиг 48. Визуализация на GIF демонстрирует разумное поведение обученной змейки. Предложенную архитектуру можно улучшить, например, увеличив размер сети, добавив свёрточные слои для обработки визуального состояния, или применив более продвинутые методы (Double DQN, Dueling DQN). В целом, проект служит наглядной иллюстрацией применения методов глубокого обучения с подкреплением к дискретной задаче управления.