Документация Engee

Правила участия в разработке

Самый простой способ принять участие в разработке этого пакета — добавить новый оператор ONNX. Для этого сделайте следующее.

  1. Добавьте новый метод в load_node().

  2. Добавьте новый метод в save_node!().

  3. Напишите тесты.

Добавление нового метода в load_node()

Функция load_node() загружает один оператор ONNX из графа в Umlaut.Tape. Она имеет следующую сигнатуру:

load_node!(tape::Tape, ::OpConfig{BE, Op}, args::VarVec, attrs::AttrDict)

Где:

  • Umlaut.Tape представляет вычислительный граф в Julia;

  • OpConfig{BE, Op} служит для диспетчеризации в бэкенд BE и оператор Op;

  • VarVec (псевдоним Vector{Umlaut.Variable}) — это список входных переменных оператора;

  • AttrDict (псевдоним Dict{Symbol, Any}) — это словарь атрибутов оператора ONNX.

Рассмотрим пример:

function load_node!(tape::Tape, ::OpConfig{:ONNX, :Relu}, args::VarVec, attrs::AttrDict)
    return push_call!(tape, NNlib.relu, args[1])
end

Здесь оператор Relu преобразуется в отдельный вызов NNlib.relu. Обе реализации — в ONNX и в NNlib — принимают один аргумент, который передается в вызов. Обратите внимание, что args[1] ссылается на переменную на ленте в формате Julia (по столбцам).

Более сложный пример — оператор Gemm:

function load_node!(tape::Tape, ::OpConfig{:ONNX, :Gemm}, args::VarVec, attrs::AttrDict)
    if (length(args) == 2 && get(attrs, :alpha, 1) == 1 &&
        get(attrs, :transA, 0) == 0 && get(attrs, :transB, 0) == 0)
        # упрощенная версия: только умножение матриц
        # примечание: аргументы меняются местами для учета построчного порядка массивов
        return push_call!(tape, *, args[2], args[1])
    else
        # полная версия GEMM
        kw = rename_keys(attrs, Dict(
            :transA => :tA,
            :transB => :tB,
            :alpha => :α,
            :beta => :β
        ))
        return push_call!(tape, onnx_gemm, args...; kw...)
    end
end

Здесь есть ряд сложностей. Во-первых, логика разделяется на два пути: простые случаи преобразуются просто в *, а для более сложных случаев реализуется собственная функция onnx_gemm. Во-вторых, в случае onnx_gemm атрибуты ONNX также преобразуются в именованные аргументы функции. В-третьих, в случае * аргументы меняются местами. Это немного необычно, однако необходимо для учета различия между массивами в ONNX с построчным порядком и массивами в Julia с порядком по столбцам: ONNX.jl автоматически обращает измерения массивов параметров при считывании данных из файлов .onnx и следует привычному для Julia порядку во время загрузки, однако в операторах все же могут требоваться некоторые поправки. Естественно, такие случаи следует тщательно продумывать и тестировать.

Другие примеры можно найти в save.jl.

Добавление нового метода в save_node!()

save_node!() — это функция, обратная load_node(). save_node принимает Umlaut.Call и добавляет соответствующие операторы в граф ONNX. Ее сигнатура выглядит следующим образом:

save_node!(g::GraphProto, ::OpConfig{BE, Fn}, op::Umlaut.Call)

Где:

  • GraphProto — это структура данных ONNX, представляющая фактический вычислительный граф;

  • OpConfig{BE, Fn} служит для диспетчеризации в бэкенд BE и тип функции Julia Fn;

  • Umlaut.Call представляет отдельный вызов f::Fn для Tape.

Пример:

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(relu)}, op::Umlaut.Call)
    nd = NodeProto("Relu", op)
    push!(g.node, nd)
end

NodeProto(op_type::String, op::Umlaut.Call, attrs::Dict=Dict()) — это удобный конструктор, который создает узел ONNX указанного типа и сопоставляет аргументы функции Julia (типа Umlaut.Variable) с именами соответствующих аргументов в уже построенном графе ONNX. Для большей части операторов этого достаточно.

Рассмотрим еще один пример:

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
    nd = NodeProto(
        input=[onnx_name(v) for v in reverse(op.args)],
        output=[onnx_name(op)],
        name=onnx_name(op),
        attribute=AttributeProto[],
        op_type="Gemm"
    )
    push!(g.node, nd)
end

В приведенной выше функции load_node() порядок следования аргументов был обращен. При сохранении узла нужно сделать то же самое. Таким образом, необходимо создать NodeProto вручную. onnx_name(v) генерирует допустимое имя ONNX на основе переменной. Остальной код должен быть понятен без пояснений.

Вот также функция save_node!() для версии onnx_gemm:

function save_node!(g::GraphProto, ::@opconfig_kw(:ONNX, onnx_gather), op::Umlaut.Call)
    data = iskwfunc(op.fn) ? op.args[3]._op.val : op.args[1]._op.val
    kw_dict = kwargs2dict(op)
    dim = get(kw_dict, :dim, ndims(data))
    axis = ndims(data) - dim
    nd = NodeProto("Gather", op, Dict(:axis => axis))
    push!(g.node, nd)
end

Обратите внимание, что в этом фрагменте кода вместо типа OpConfig{...} используется макрос @opconfig_kw(...). Этот макрос развертывается в определение, которое перехватывает как обычную версию функции, так и версию kw:

OpConfig{:ONNX, <:Union{typeof(onnx_gemm), typeof(Core.kwfunc(onnx_gemm))}}

Другие примеры можно найти в save.jl.

Тестирование

ort_test() принимает функцию Julia, создает объект Tape и сохраняет его в виде файла .onnx, после чего использует ONNXRunTime.jl для его выполнения и обратной загрузки. Использовать эту функцию очень просто:

x = rand(3, 4)
ort_test(ONNX.relu, x)