Правила участия в разработке
Самый простой способ принять участие в разработке этого пакета — добавить новый оператор ONNX. Для этого сделайте следующее.
-
Добавьте новый метод в
load_node()
. -
Добавьте новый метод в
save_node!()
. -
Напишите тесты.
Добавление нового метода в load_node()
Функция load_node()
загружает один оператор ONNX из графа в Umlaut.Tape
. Она имеет следующую сигнатуру:
load_node!(tape::Tape, ::OpConfig{BE, Op}, args::VarVec, attrs::AttrDict)
Где:
-
Umlaut.Tape
представляет вычислительный граф в Julia; -
OpConfig{BE, Op}
служит для диспетчеризации в бэкендBE
и операторOp
; -
VarVec
(псевдонимVector{Umlaut.Variable}
) — это список входных переменных оператора; -
AttrDict
(псевдонимDict{Symbol, Any}
) — это словарь атрибутов оператора ONNX.
Рассмотрим пример:
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Relu}, args::VarVec, attrs::AttrDict)
return push_call!(tape, NNlib.relu, args[1])
end
Здесь оператор Relu преобразуется в отдельный вызов NNlib.relu
. Обе реализации — в ONNX и в NNlib — принимают один аргумент, который передается в вызов. Обратите внимание, что args[1]
ссылается на переменную на ленте в формате Julia (по столбцам).
Более сложный пример — оператор Gemm:
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Gemm}, args::VarVec, attrs::AttrDict)
if (length(args) == 2 && get(attrs, :alpha, 1) == 1 &&
get(attrs, :transA, 0) == 0 && get(attrs, :transB, 0) == 0)
# упрощенная версия: только умножение матриц
# примечание: аргументы меняются местами для учета построчного порядка массивов
return push_call!(tape, *, args[2], args[1])
else
# полная версия GEMM
kw = rename_keys(attrs, Dict(
:transA => :tA,
:transB => :tB,
:alpha => :α,
:beta => :β
))
return push_call!(tape, onnx_gemm, args...; kw...)
end
end
Здесь есть ряд сложностей. Во-первых, логика разделяется на два пути: простые случаи преобразуются просто в *
, а для более сложных случаев реализуется собственная функция onnx_gemm
. Во-вторых, в случае onnx_gemm
атрибуты ONNX также преобразуются в именованные аргументы функции. В-третьих, в случае *
аргументы меняются местами. Это немного необычно, однако необходимо для учета различия между массивами в ONNX с построчным порядком и массивами в Julia с порядком по столбцам: ONNX.jl автоматически обращает измерения массивов параметров при считывании данных из файлов .onnx
и следует привычному для Julia порядку во время загрузки, однако в операторах все же могут требоваться некоторые поправки. Естественно, такие случаи следует тщательно продумывать и тестировать.
Другие примеры можно найти в save.jl.
Добавление нового метода в save_node!()
save_node!()
— это функция, обратная load_node()
. save_node
принимает Umlaut.Call
и добавляет соответствующие операторы в граф ONNX. Ее сигнатура выглядит следующим образом:
save_node!(g::GraphProto, ::OpConfig{BE, Fn}, op::Umlaut.Call)
Где:
-
GraphProto
— это структура данных ONNX, представляющая фактический вычислительный граф; -
OpConfig{BE, Fn}
служит для диспетчеризации в бэкендBE
и тип функции JuliaFn
; -
Umlaut.Call
представляет отдельный вызовf::Fn
дляTape
.
Пример:
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(relu)}, op::Umlaut.Call)
nd = NodeProto("Relu", op)
push!(g.node, nd)
end
NodeProto(op_type::String, op::Umlaut.Call, attrs::Dict=Dict())
— это удобный конструктор, который создает узел ONNX указанного типа и сопоставляет аргументы функции Julia (типа Umlaut.Variable
) с именами соответствующих аргументов в уже построенном графе ONNX. Для большей части операторов этого достаточно.
Рассмотрим еще один пример:
function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
nd = NodeProto(
input=[onnx_name(v) for v in reverse(op.args)],
output=[onnx_name(op)],
name=onnx_name(op),
attribute=AttributeProto[],
op_type="Gemm"
)
push!(g.node, nd)
end
В приведенной выше функции load_node()
порядок следования аргументов был обращен. При сохранении узла нужно сделать то же самое. Таким образом, необходимо создать NodeProto
вручную. onnx_name(v)
генерирует допустимое имя ONNX на основе переменной. Остальной код должен быть понятен без пояснений.
Вот также функция save_node!()
для версии onnx_gemm
:
function save_node!(g::GraphProto, ::@opconfig_kw(:ONNX, onnx_gather), op::Umlaut.Call)
data = iskwfunc(op.fn) ? op.args[3]._op.val : op.args[1]._op.val
kw_dict = kwargs2dict(op)
dim = get(kw_dict, :dim, ndims(data))
axis = ndims(data) - dim
nd = NodeProto("Gather", op, Dict(:axis => axis))
push!(g.node, nd)
end
Обратите внимание, что в этом фрагменте кода вместо типа OpConfig{...}
используется макрос @opconfig_kw(...)
. Этот макрос развертывается в определение, которое перехватывает как обычную версию функции, так и версию kw:
OpConfig{:ONNX, <:Union{typeof(onnx_gemm), typeof(Core.kwfunc(onnx_gemm))}}
Другие примеры можно найти в save.jl.