Сегментация изображений SAR
Введение
Этот демо-пример посвящён задаче сегментации кораблей на радиолокационных изображениях. Такие снимки отличаются высокой информативностью и позволяют наблюдать объекты независимо от времени суток и погодных условий, однако интерпретация данных вручную требует больших усилий и знаний. Использование нейросетевых методов позволяет существенно ускорить обработку и повысить точность обнаружения.
Архитектура U-Net с бэкбоном ResNet-18 выбрана благодаря её способности учитывать как локальные детали, так и глобальный контекст, что особенно важно при работе с морской поверхностью и небольшими по размеру объектами.
Импорт библиотек и настройка путей
using Pkg
# Pkg.add("DataLoaders")
using FileIO, Images
using ImageTransformations: imresize, Linear
using ProgressMeter, PyCall
using Printf, Dates
using CUDA
using Flux, Metalhead
using BSON: @save, @load
using Statistics, Printf
Зададим пути для тренировочных и тестовых данных: изображения и маски
train_dir_imgs = raw"data/train/imgs"
train_dir_masks = raw"data/train/masks"
test_dir_imgs = raw"data/test/imgs"
test_dir_masks = raw"data/test/masks";
Гиперпараметры обучения
Определяются размеры входа (256,256), размер батча, скорость обучения
res = (256, 256)
batch_size = 2
learning_rate = 1e-3
Датасет
Класс ImageDataset - структура для хранения путей к картинкам и маскам.
В методе getindex данные загружаются, масштабируются до нужного размера и преобразуются
function my_collate(samples)
xs = first.(samples)
ys = last.(samples)
return (cat(xs...; dims=4),
cat(ys...; dims=4))
end
struct ImageDataset
imgs::Vector{String}
masks::Vector{String}
end
ImageDataset(folder_imgs::String, folder_masks::String) =
ImageDataset(readdir(folder_imgs; join=true), readdir(folder_masks; join=true))
function Base.getindex(ds::ImageDataset, i::Int)
img = Gray.(load(ds.imgs[i]))
img = imresize(img, res; method=Linear())
x = Float32.(img)
H, W = size(x)
x = reshape(x, H, W, 1)
msk = Gray.(load(ds.masks[i]))
msk = imresize(msk, res)
m = Float32.(msk .> 0)
y = cat(1f0 .- m, m; dims=3)
return x, y
end
Base.length(ds::ImageDataset) = length(ds.imgs)
Создаются train_loader и test_loader, которые по батчам подают данные в модель
train_data = ImageDataset(train_dir_imgs, train_dir_masks)
test_data = ImageDataset(test_dir_imgs, test_dir_masks)
train_loader = Flux.DataLoader(train_data; batchsize=batch_size, collate=my_collate, parallel=false)
test_loader = Flux.DataLoader(test_data; batchsize=batch_size, collate=my_collate, parallel=false)
Визуализация данных
В ячейке ниже визуализируем данные для оценик того, с чем работаем
img, mask = train_data[2]
to_rgb(x) = Gray.(dropdims(x; dims=3))
rgb = to_rgb(img)
mask_vis = Gray.(mask[:, :, 2])
hcat(rgb, mask_vis)
Определение модели
Конструируется UNet с бэкбоном ResNet18, настраивается оптимизатор (Adam с weight decay), проверяется наличие GPU.
model = UNet(res, 1, 2, Metalhead.backbone(Metalhead.ResNet(18; inchannels=1)))
device = CUDA.functional() ? gpu : cpu
model = device(model)
θ = Flux.params(model)
opt = Flux.Optimiser(WeightDecay(1e-6), Adam(learning_rate))
Определение функций тренировки и валидации
train_step: делает прямой проход, считает logitcrossentropy, вычисляет градиенты и обновляет веса.
valid_step: оценивает loss на валидационных данных.
function train_step(model, θ, x, y, opt)
loss_cpu = 0f0
∇ = Flux.gradient(θ) do
ŷ = model(x)
l = Flux.logitcrossentropy(ŷ, y; dims=3)
loss_cpu = cpu(l)
l
end
Flux.Optimise.update!(opt, θ, ∇)
return loss_cpu
end
function valid_step(model, x, y)
ŷ = model(x)
l = Flux.logitcrossentropy(ŷ, y; dims=3)
return float(l)
end
Обучение модели
Тут описан основной цикл обучения модели. Изначально задаем нужное нам кол-во эпох
mkdir("model");
epochs = 50
for epoch in 1:epochs
println("epoch: ", epoch)
trainmode!(model)
train_loss = 0f0
for (x, y) in train_loader
train_loss += train_step(model, θ, device(x), device(y), opt)
end
train_loss /= length(train_loader)
@info "Epoch $epoch | Train Loss $train_loss"
testmode!(model)
validation_loss = 0f0
for (x, y) in test_loader
validation_loss += valid_step(model, device(x), device(y))
end
validation_loss /= length(test_loader)
@info "Epoch $epoch | Validation Loss $validation_loss"
# сохранение чекпоинта
fn = joinpath("model", @sprintf("model_epoch_%03d.bson", epoch))
@save fn model
@info " ↳ модель сохранена в $fn"
end
Сохранение финальной модели
После обучения модель переводится на CPU и сохраняется отдельно в файл model1.bson.
model = cpu(model)
best_path = joinpath("model", "model1.bson")
@info "Сохраняем модель в $best_path"
@save best_path model
Инференс
Вспомогательные функции для инференса
ship_probs: извлекает вероятность класса «корабль» из выхода сети.
predict_mask: по входному изображению возвращает бинарную маску или вероятностную карту.
save_prediction: сохраняет предсказанную маску как картинку.
scan_thresholds: помогает подобрать порог для бинаризации, выводя статистику по разным значениям.
model_on_gpu(m) = any(x -> x isa CuArray, Flux.params(m))
to_dev(x, m) = model_on_gpu(m) ? gpu(x) : x
function ship_probs(ŷ)
sz = size(ŷ)
if sz[3] == 2
p = softmax(ŷ; dims=3)
return Array(@view p[:, :, 2, 1])
else sz[1] == 2
p = softmax(ŷ; dims=1)
return Array(@view p[2, :, :, 1]) |> x -> permutedims(x, (2,1))
end
end
function predict_mask(model, img_path; thr=0.35, return_probs=false)
img = Gray.(load(img_path))
img = imresize(img, res; method=Linear())
x = Float32.(img)
H, W = size(x)
x = reshape(x, H, W, 1, 1)
ŷ = model(to_dev(x, model))
p_ship = ship_probs(ŷ)
@info "ship prob stats" min=minimum(p_ship) max=maximum(p_ship) mean=mean(p_ship)
mask = Float32.(p_ship .>= thr)
return return_probs ? (mask, p_ship) : mask
end
function save_prediction(model, in_path, out_path; thr=0.35)
m = predict_mask(model, in_path; thr=thr)
save(out_path, Gray.(m))
@info "saved" path=out_path positives=sum(m .> 0)
end
function scan_thresholds(model, img_path; ts=0.10:0.05:0.50)
_, p = predict_mask(model, img_path; return_probs=true)
for t in ts
m = p .>= t
@printf "thr=%.2f positives=%6d max=%.3f mean=%.3f\n" t count(m) maximum(p) mean(p)
end
end
Для выбранного SAR-снимка считается карта вероятностей, проводится подбор порога и сохраняется итоговая маска pred_mask.png.
img_path = raw"data/train/imgs/P0003_1200_2000_4200_5000.png"
scan_thresholds(model, img_path)
save_prediction(model, img_path, "pred_mask.png"; thr=0.50)
Выводы
В данной работе была обучения нейронная сеть UNet с бэкбоном в качестве ResNet для сегментации SAR изображений. Задача не самая простая, поскольку корабли на изображениях могут находится в городской среде. Для дальнейшего улучшения качества модели необходимо более тонко подойти к обучению