Сообщество Engee

Сегментация изображений SAR

Автор
avatar-aalexandrgorbunovaalexandrgorbunov
Notebook

Введение

Этот демо-пример посвящён задаче сегментации кораблей на радиолокационных изображениях. Такие снимки отличаются высокой информативностью и позволяют наблюдать объекты независимо от времени суток и погодных условий, однако интерпретация данных вручную требует больших усилий и знаний. Использование нейросетевых методов позволяет существенно ускорить обработку и повысить точность обнаружения.

Архитектура U-Net с бэкбоном ResNet-18 выбрана благодаря её способности учитывать как локальные детали, так и глобальный контекст, что особенно важно при работе с морской поверхностью и небольшими по размеру объектами.

Импорт библиотек и настройка путей

using Pkg
# Pkg.add("DataLoaders")
using FileIO, Images
using ImageTransformations: imresize, Linear
using ProgressMeter, PyCall
using Printf, Dates
using CUDA
using Flux, Metalhead
using BSON: @save, @load
using Statistics, Printf

Зададим пути для тренировочных и тестовых данных: изображения и маски

train_dir_imgs  = raw"data/train/imgs"
train_dir_masks = raw"data/train/masks"
test_dir_imgs   = raw"data/test/imgs"
test_dir_masks  = raw"data/test/masks";

Гиперпараметры обучения

Определяются размеры входа (256,256), размер батча, скорость обучения

res = (256, 256)          
batch_size = 2
learning_rate = 1e-3
0.001

Датасет

Класс ImageDataset - структура для хранения путей к картинкам и маскам.
В методе getindex данные загружаются, масштабируются до нужного размера и преобразуются

function my_collate(samples)
    xs = first.(samples)
    ys = last.(samples)
    return (cat(xs...; dims=4),  
            cat(ys...; dims=4))   
end
my_collate (generic function with 1 method)
struct ImageDataset
    imgs::Vector{String}
    masks::Vector{String}
end

ImageDataset(folder_imgs::String, folder_masks::String) =
    ImageDataset(readdir(folder_imgs; join=true), readdir(folder_masks; join=true))

function Base.getindex(ds::ImageDataset, i::Int)
    img = Gray.(load(ds.imgs[i]))                 
    img = imresize(img, res; method=Linear())   
    x = Float32.(img)                            
    H, W = size(x)
    x = reshape(x, H, W, 1)                      

    msk = Gray.(load(ds.masks[i]))               
    msk = imresize(msk, res) 
    m = Float32.(msk .> 0)                       


    y = cat(1f0 .- m, m; dims=3)

    return x, y
end

Base.length(ds::ImageDataset) = length(ds.imgs)

Создаются train_loader и test_loader, которые по батчам подают данные в модель

train_data = ImageDataset(train_dir_imgs, train_dir_masks)
test_data  = ImageDataset(test_dir_imgs,  test_dir_masks)

train_loader = Flux.DataLoader(train_data; batchsize=batch_size, collate=my_collate, parallel=false)
test_loader  = Flux.DataLoader(test_data;  batchsize=batch_size, collate=my_collate, parallel=false)
50-element DataLoader(::ImageDataset, batchsize=2, collate=my_collate)
  with first element:
  (256×256×1×2 Array{Float32, 4}, 256×256×2×2 Array{Float32, 4},)

Визуализация данных

В ячейке ниже визуализируем данные для оценик того, с чем работаем

img, mask = train_data[2]

to_rgb(x) = Gray.(dropdims(x; dims=3))

rgb = to_rgb(img)

mask_vis = Gray.(mask[:, :, 2])

hcat(rgb, mask_vis)
No description has been provided for this image

Определение модели

Конструируется UNet с бэкбоном ResNet18, настраивается оптимизатор (Adam с weight decay), проверяется наличие GPU.

model = UNet(res, 1, 2, Metalhead.backbone(Metalhead.ResNet(18; inchannels=1)))

device = CUDA.functional() ? gpu : cpu
model  = device(model)
θ      = Flux.params(model)

opt = Flux.Optimiser(WeightDecay(1e-6), Adam(learning_rate))
Flux.Optimise.Optimiser(Any[WeightDecay(1.0e-6), Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())])

Определение функций тренировки и валидации

train_step: делает прямой проход, считает logitcrossentropy, вычисляет градиенты и обновляет веса.

valid_step: оценивает loss на валидационных данных.

function train_step(model, θ, x, y, opt)
    loss_cpu = 0f0
     = Flux.gradient(θ) do
        ŷ = model(x)                              
        l = Flux.logitcrossentropy(ŷ, y; dims=3)
        loss_cpu = cpu(l)
        l
    end
    Flux.Optimise.update!(opt, θ, )
    return loss_cpu
end

function valid_step(model, x, y)
    ŷ = model(x)
    l = Flux.logitcrossentropy(ŷ, y; dims=3)
    return float(l)
end
valid_step (generic function with 1 method)

Обучение модели

Тут описан основной цикл обучения модели. Изначально задаем нужное нам кол-во эпох

mkdir("model");
epochs = 50
for epoch in 1:epochs
    println("epoch: ", epoch)
    trainmode!(model)

    train_loss = 0f0
    for (x, y) in train_loader
        train_loss += train_step(model, θ, device(x), device(y), opt)
    end
    train_loss /= length(train_loader)
    @info "Epoch $epoch | Train Loss $train_loss"

    testmode!(model)
    validation_loss = 0f0
    for (x, y) in test_loader
        validation_loss += valid_step(model, device(x), device(y))
    end
    validation_loss /= length(test_loader)
    @info "Epoch $epoch | Validation Loss $validation_loss"

    # сохранение чекпоинта
    fn = joinpath("model", @sprintf("model_epoch_%03d.bson", epoch))
    @save fn model
    @info "  ↳ модель сохранена в $fn"
end

Сохранение финальной модели

После обучения модель переводится на CPU и сохраняется отдельно в файл model1.bson.

model = cpu(model)
best_path = joinpath("model", "model1.bson")
@info "Сохраняем модель в $best_path"
@save best_path model
[ Info: Сохраняем модель в model/model1.bson

Инференс

Вспомогательные функции для инференса

ship_probs: извлекает вероятность класса «корабль» из выхода сети.

predict_mask: по входному изображению возвращает бинарную маску или вероятностную карту.

save_prediction: сохраняет предсказанную маску как картинку.

scan_thresholds: помогает подобрать порог для бинаризации, выводя статистику по разным значениям.

model_on_gpu(m) = any(x -> x isa CuArray, Flux.params(m))
to_dev(x, m)    = model_on_gpu(m) ? gpu(x) : x

function ship_probs()
    sz = size()
    if sz[3] == 2                
        p = softmax(; dims=3)
        return Array(@view p[:, :, 2, 1])
    else sz[1] == 2          
        p = softmax(; dims=1)
        return Array(@view p[2, :, :, 1]) |> x -> permutedims(x, (2,1))
    end
end


function predict_mask(model, img_path; thr=0.35, return_probs=false)
    img = Gray.(load(img_path))
    img = imresize(img, res; method=Linear())
    x = Float32.(img)
    H, W = size(x)
    x = reshape(x, H, W, 1, 1)

     = model(to_dev(x, model))

    p_ship = ship_probs()

    @info "ship prob stats" min=minimum(p_ship) max=maximum(p_ship) mean=mean(p_ship)

    mask = Float32.(p_ship .>= thr)

    return return_probs ? (mask, p_ship) : mask
end

function save_prediction(model, in_path, out_path; thr=0.35)
    m = predict_mask(model, in_path; thr=thr)
    save(out_path, Gray.(m))
    @info "saved" path=out_path positives=sum(m .> 0)
end

function scan_thresholds(model, img_path; ts=0.10:0.05:0.50)
    _, p = predict_mask(model, img_path; return_probs=true)
    for t in ts
        m = p .>= t
        @printf "thr=%.2f  positives=%6d  max=%.3f  mean=%.3f\n" t count(m) maximum(p) mean(p)
    end
end
scan_thresholds (generic function with 1 method)

Для выбранного SAR-снимка считается карта вероятностей, проводится подбор порога и сохраняется итоговая маска pred_mask.png.

img_path = raw"data/train/imgs/P0003_1200_2000_4200_5000.png"

scan_thresholds(model, img_path)


save_prediction(model, img_path, "pred_mask.png"; thr=0.50)

Выводы

В данной работе была обучения нейронная сеть UNet с бэкбоном в качестве ResNet для сегментации SAR изображений. Задача не самая простая, поскольку корабли на изображениях могут находится в городской среде. Для дальнейшего улучшения качества модели необходимо более тонко подойти к обучению