Engee 文档
Notebook

使用预训练的 ResNet 神经网络进行图像分类

神经网络是一种方便灵活的算法,可以通过计算进行训练。通过成功的训练过程,同一个神经网络可用于多种不同的任务。例如,著名的神经网络系列ResNet 就是为图像分类而创建的。从图形数据库到风格转换,此类神经网络可部分或全部用于需要处理图像数字表示的各种任务中。

在本示例中,我们将运行一个深度较浅(18 层)的神经网络ResNet 进行图像分类。我们将展示数据准备和输出处理的整个链条,对于任何图像,它都将为我们提供一个或多或少合适的文本标签,以描述图像中描绘的对象。

准备工作

目前,Umlaut 库的Tape 对象是主要的计算容器,在这里可以解压ONNX 格式的神经网络。由于ONNX 库正在更新过程中,这种启动机制可能会在不久的将来改变。

In [ ]:
Pkg.add(["Umlaut", "ONNX"])
In [ ]:
 import Pkg; Pkg.add("Umlaut", io=devnull)
 import Umlaut: Tape, play!

当然,我们还需要 Engee 内置的用于处理 ONNX 格式(通常存储预训练神经网络)的库,以及用于处理图像的库。

In [ ]:
using ONNX
using Images
[ Info: Precompiling ONNX [d0dd6a25-fac6-55c0-abf7-829e0c774d20]

让我们安装工作文件夹

In [ ]:
cd( @__DIR__ )

中的对象类ImageNet

我们预先训练好的神经网络可以识别的所有对象类别的名称以有序向量的形式表示,并从以下文件中加载。

In [ ]:
include( "imagenet_classes.jl" );

输入输出函数

让我们编写三个简单的辅助函数,将数据输入神经网络并处理结果:

1.加载图像:我们将图像缩放至 244*244 的大小(在保持长宽比的前提下进行归一化和边缘裁剪将是一个受欢迎的附加功能) 2.排序预测:神经网络会返回一个数字向量,表示图像中出现 ImageNet 数据集中特定类别的概率。让我们选择最有可能出现的类别k 。 3.这些函数的外壳允许您加载图像,并在一次操作中输出k 最有可能的预测结果

In [ ]:
# Загрузка изображения из файла
function imread(path::AbstractString; sz=(224,224))
    img = Images.load(path);
    img = imresize(img, sz);
    x = convert(Array{Float32}, channelview(img))
    # Заменим порядок слоев: CHW -> WHC
    x = permutedims(x, (3, 2, 1))
    return x
end

# Выдача индексов первых k предсказаний
function maxk(a, k)
    b = partialsortperm(a, 1:k, rev=true)
    return collect(zip(b, a[b]))
end

# Загрузка изображения и выдача десяти наиболее вероятных классов в убывающем порядке
function test_image(tape::Tape, path::AbstractString)
    x = imread(path)
    x = reshape(x, size(x)..., 1)
    y = play!(tape, x)
    y = reshape(y, size(y, 1))
    top = maxk(y, 10)
    classes = []
    for (i, (idx, val)) in enumerate(top)
        name = IMAGENET_CLASSES[idx - 1]
        classes = [classes; "$i: $name ($val)"]
    end
    return join(classes, "\n")
end
Out[0]:
test_image (generic function with 1 method)

下载神经网络ResNet18

我们将使用一个神经网络,它位于扩展名为*.onnx 的文件中。有一些库可以让你使用更高级的命令来创建和加载这个神经网络(例如FluxML 系列中的 Metalhead.jl 库),但现在我们不需要额外的库。

预训练的神经网络已经在指定目录中,因此命令无需重新下载即可执行。

In [ ]:
path = "resnet18.onnx"

if !isfile(path)
    download("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx", path)
end
In [ ]:
# Создадим пустую матрицу на месте которой будет входное изображение
img = rand( Float32, 224, 224, 3, 1 )

# Загружаем модель в виде объекта Umlaut.Tape
resnet = ONNX.load( path, img );

让我们下载一些图像

如果已经下载过,同样不会再下载

In [ ]:
path = "data/"

goose_path = download( "https://upload.wikimedia.org/wikipedia/commons/3/3f/Snow_goose_2.jpg", path*"goose.jpg");
dog_path = download( "https://farm4.staticflickr.com/1301/4694470234_6f27a4f602_o.jpg", path*"dog.jpg");
plane_path = download( "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c9/Rossiya%2C_RA-89043%2C_Sukhoi_Superjet_100-95B_%2851271265892%29.jpg/1024px-Rossiya%2C_RA-89043%2C_Sukhoi_Superjet_100-95B_%2851271265892%29.jpg", path*"plane.jpg");

图像分类

In [ ]:
display( load(plane_path)[1:5:end, 1:5:end] )
print( test_image( resnet, plane_path ))
No description has been provided for this image
1: airliner (16.3932)
2: wing (12.387247)
3: warplane, military plane (11.546208)
4: airship, dirigible (10.379374)
5: space shuttle (9.501989)
6: missile (9.399774)
7: projectile, missile (8.82921)
8: tiger shark, Galeocerdo cuvieri (7.278539)
9: aircraft carrier, carrier, flattop, attack aircraft carrier (6.265907)
10: can opener, tin opener (6.131121)
In [ ]:
display( load(goose_path)[1:5:end, 1:5:end] )
print( test_image( resnet, goose_path ))
No description has been provided for this image
1: goose (14.927246)
2: crane (11.862924)
3: flamingo (11.377807)
4: spoonbill (11.055638)
5: white stork, Ciconia ciconia (10.838624)
6: American egret, great white heron, Egretta albus (10.136589)
7: pelican (10.013963)
8: bustard (9.504973)
9: peacock (9.44741)
10: albatross, mollymawk (8.912545)
In [ ]:
display( load(dog_path)[1:5:end, 1:5:end] )
print( test_image( resnet, dog_path ))
No description has been provided for this image
1: Pembroke, Pembroke Welsh corgi (16.254753)
2: Cardigan, Cardigan Welsh corgi (14.028244)
3: collie (11.081062)
4: golden retriever (10.7889805)
5: dingo, warrigal, warragal, Canis dingo (10.606851)
6: basenji (10.365379)
7: Shetland sheepdog, Shetland sheep dog, Shetland (9.810743)
8: beagle (9.483083)
9: Labrador retriever (9.437691)
10: Eskimo dog, husky (9.249486)

结论

我们已经证明,在 Engee 中下载一个神经网络并用它进行计算并不困难。

通过这种机制,我们可以组织由高级组件组成的复杂信息处理流水线。特别是,我们可以将信息处理的某些阶段委托给预先训练好的神经网络。