Документация Engee

Использование предобученной нейросети ResNet для классификации изображений

Нейросети – удобный гибкий алгоритм, который можно обучить вычислениям. При успешно выстроенном процессе обучения, одну и ту же нейросеть можно будет использовать во множестве разных задач. Этим, например, знаменито семейство нейросетей ResNet, созданные для классификации изображений. Частично или целиком, подобные нейросети используются в самых разных задачах, где нужно работать с численными репрезентациями изображений, от графических баз данных до переноса стиля.

В этом примере мы запустим нейросеть ResNet небольшой глубины (18 слоёв) для классификации изображений. Покажем всю цепочку подготовки данных и обработки выходной информации, которая для любой картинки выдаст нам более или менее подходящую текстовую метку, характеризующую объект, изображенный на картинке.

Подготовительная работа

Объект Tape библиотеки Umlaut, на данный момент, является основным контейнером для вычислений, куда можно распаковать нейросеть из формата ONNX. Этот механизм запуска, скорее всего, будет изменен в ближайшем будущем, поскольку библиотека ONNX сейчас находится в процессе обновления.

 import Pkg; Pkg.add("Umlaut", io=devnull)
 import Umlaut: Tape, play!

Конечно, нам также понадобятся встроенные в Engee библиотеки для работы с форматом ONNX, в котором часто хранятся предобученные нейросети, и библиотека для работы с изображениями.

using ONNX
using Images
[ Info: Precompiling ONNX [d0dd6a25-fac6-55c0-abf7-829e0c774d20]

Установим рабочую папку

cd( @__DIR__ )

Классы объектов в ImageNet

Названия всех классов объектов, которые умеет распознавать наша предобученная нейросеть, представлены в качестве упорядоченного вектора и загружаются из следующего файла.

include( "imagenet_classes.jl" );

Функции ввода-вывода

Напишем три простые вспомогательные функции для подачи данных в нейросеть и обработки результатов:

  1. Загрузка изображения: мы масштабируем его до размера 244*244 (нормализация и обрезка краев с сохранением пропорций были бы желательным дополнением)

  2. Сортировка предсказаний: нейросеть возвращает вектор чисел, которые означают вероятность того, что на изображении наблюдается тот или иной класс из датасета ImageNet. Отберем k наиболее вероятно представленных классов

  3. Оболочка для этих функций позволяет за одно действие загрузить изображение и выдать k наиболее вероятных предсказаний

# Загрузка изображения из файла
function imread(path::AbstractString; sz=(224,224))
    img = Images.load(path);
    img = imresize(img, sz);
    x = convert(Array{Float32}, channelview(img))
    # Заменим порядок слоев: CHW -> WHC
    x = permutedims(x, (3, 2, 1))
    return x
end

# Выдача индексов первых k предсказаний
function maxk(a, k)
    b = partialsortperm(a, 1:k, rev=true)
    return collect(zip(b, a[b]))
end

# Загрузка изображения и выдача десяти наиболее вероятных классов в убывающем порядке
function test_image(tape::Tape, path::AbstractString)
    x = imread(path)
    x = reshape(x, size(x)..., 1)
    y = play!(tape, x)
    y = reshape(y, size(y, 1))
    top = maxk(y, 10)
    classes = []
    for (i, (idx, val)) in enumerate(top)
        name = IMAGENET_CLASSES[idx - 1]
        classes = [classes; "$i: $name ($val)"]
    end
    return join(classes, "\n")
end
test_image (generic function with 1 method)

Скачаем нейросеть ResNet18

Мы будем пользоваться нейросетью, которая лежит в файле с расширением *.onnx. Есть библиотеки, которые позволяют создать и загрузить эту нейросеть при помощи еще более высокоуровневых команд (например, библиотека Metalhead.jl из коллекции FluxML), но пока мы это сделаем без дополнительных библиотек.

Предобученная нейросеть уже находится в указанном каталоге, поэтому команда выполнится без повторного скачивания.

path = "resnet18.onnx"

if !isfile(path)
    download("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx", path)
end
# Создадим пустую матрицу на месте которой будет входное изображение
img = rand( Float32, 224, 224, 3, 1 )

# Загружаем модель в виде объекта Umlaut.Tape
resnet = ONNX.load( path, img );

Загрузим несколько изображений

Если они уже загружены, аналогично, повторно они скачаны не будут

path = "data/"

goose_path = download( "https://upload.wikimedia.org/wikipedia/commons/3/3f/Snow_goose_2.jpg", path*"goose.jpg");
dog_path = download( "https://farm4.staticflickr.com/1301/4694470234_6f27a4f602_o.jpg", path*"dog.jpg");
plane_path = download( "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c9/Rossiya%2C_RA-89043%2C_Sukhoi_Superjet_100-95B_%2851271265892%29.jpg/1024px-Rossiya%2C_RA-89043%2C_Sukhoi_Superjet_100-95B_%2851271265892%29.jpg", path*"plane.jpg");

Классифицируем изображения

display( load(plane_path)[1:5:end, 1:5:end] )
print( test_image( resnet, plane_path ))

interactive-scripts/images/image_processing_resnet_classification/1af985c01ee23aee6350d571424a7585247b6d2a

1: airliner (16.3932)
2: wing (12.387247)
3: warplane, military plane (11.546208)
4: airship, dirigible (10.379374)
5: space shuttle (9.501989)
6: missile (9.399774)
7: projectile, missile (8.82921)
8: tiger shark, Galeocerdo cuvieri (7.278539)
9: aircraft carrier, carrier, flattop, attack aircraft carrier (6.265907)
10: can opener, tin opener (6.131121)
display( load(goose_path)[1:5:end, 1:5:end] )
print( test_image( resnet, goose_path ))

interactive-scripts/images/image_processing_resnet_classification/7b1a0c15c9aec428056ffc23425b6c30dd663b06

1: goose (14.927246)
2: crane (11.862924)
3: flamingo (11.377807)
4: spoonbill (11.055638)
5: white stork, Ciconia ciconia (10.838624)
6: American egret, great white heron, Egretta albus (10.136589)
7: pelican (10.013963)
8: bustard (9.504973)
9: peacock (9.44741)
10: albatross, mollymawk (8.912545)
display( load(dog_path)[1:5:end, 1:5:end] )
print( test_image( resnet, dog_path ))

interactive-scripts/images/image_processing_resnet_classification/3f6d2e08611dc69b8ce228d7e97ca2f2577534ef

1: Pembroke, Pembroke Welsh corgi (16.254753)
2: Cardigan, Cardigan Welsh corgi (14.028244)
3: collie (11.081062)
4: golden retriever (10.7889805)
5: dingo, warrigal, warragal, Canis dingo (10.606851)
6: basenji (10.365379)
7: Shetland sheepdog, Shetland sheep dog, Shetland (9.810743)
8: beagle (9.483083)
9: Labrador retriever (9.437691)
10: Eskimo dog, husky (9.249486)

Заключение

Мы показали, что в Engee несложно скачать нейросеть и выполнить с ее помощью вычисления.

Этот механизм позволяет организовать сложный конвейер обработки информации, состоящий из высокоуровневых компонентов. В частности, доверить некотореы этапы обработки информации предобученным нейросетям.

Библиография: https://github.com/FluxML/ONNX.jl