Recursive transformations from Functors.jl
Flux models are deeply nested structures, and Functors.jl provides the tools needed to explore such objects, apply functions to the parameters they contain, and rebuild them.
New layers should be annotated using the Functors.@functor macro. This will enable params to see the parameters inside, and gpu to move them to the GPU.
Functors.jl has its own notes on basic usage. Additionally, the Advanced Model Building and Customisation page covers the use cases of Functors in greater detail.
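As a minimal sketch of the annotation described above, the following defines a small layer, marks it with @functor, and applies fmap to it. The type Affine and its fields W and b are hypothetical names chosen for illustration, and only Functors.jl is assumed to be loaded.

```julia
using Functors

# Hypothetical layer type, used only for illustration.
struct Affine
    W
    b
end

# Make the fields of Affine visible to fmap and related functions.
Functors.@functor Affine

# Calling the layer computes W * x .+ b.
(a::Affine)(x) = a.W * x .+ a.b

layer = Affine([1.0 2.0; 3.0 4.0], [0.5, 0.5])

# Apply a function to every array inside the layer, then rebuild the layer.
layer32 = fmap(x -> Float32.(x), layer)

typeof(layer32)     # still an Affine
eltype(layer32.W)   # Float32
```

Because the reconstruction step calls the constructor of Affine, the result is a working layer rather than a bare collection of arrays.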
Functors.@functor — Macro
@functor T
@functor T (x,)
Adds methods to functor allowing recursion into objects of type T, and reconstruction. Assumes that T has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not.
By default all fields of T are considered children; this can be restricted by providing a tuple of field names.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> Functors.children(Foo(1,2))
(x = 1, y = 2)
julia> _, re = Functors.functor(Foo(1,2));
julia> re((10, 20))
Foo(10, 20)
julia> struct TwoThirds a; b; c; end
julia> @functor TwoThirds (a, c)
julia> ch2, re3 = Functors.functor(TwoThirds(10,20,30));
julia> ch2
(a = 10, c = 30)
julia> re3(("ten", "thirty"))
TwoThirds("ten", 20, "thirty")
julia> fmap(x -> 10x, TwoThirds(Foo(1,2), Foo(3,4), 56))
TwoThirds(Foo(10, 20), Foo(3, 4), 560)
Functors.fmap — Function
fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk()[, prune])
A structure and type preserving map.
By default it transforms every leaf node (identified by exclude, default isleaf) by applying f, and otherwise traverses x recursively using functor. Optionally, further objects ys with the same tree structure may be passed. In that case, f is applied to the corresponding leaf nodes in x and ys.
Examples
julia> fmap(string, (x=1, y=(2, 3)))
(x = "1", y = ("2", "3"))
julia> nt = (a = [1,2], b = [23, (45,), (x=6//7, y=())], c = [8,9]);
julia> fmap(println, nt)
[1, 2]
23
45
6//7
()
[8, 9]
(a = nothing, b = Any[nothing, (nothing,), (x = nothing, y = nothing)], c = nothing)
julia> fmap(println, nt; exclude = x -> x isa Array)
[1, 2]
Any[23, (45,), (x = 6//7, y = ())]
[8, 9]
(a = nothing, b = nothing, c = nothing)
julia> twice = [1, 2]; # println only acts once on this
julia> fmap(println, (i = twice, ii = 34, iii = [5, 6], iv = (twice, 34), v = 34.0))
[1, 2]
34
[5, 6]
34
34.0
(i = nothing, ii = nothing, iii = nothing, iv = (nothing, nothing), v = nothing)
julia> d1 = Dict("x" => [1,2], "y" => 3);
julia> d2 = Dict("x" => [4,5], "y" => 6, "z" => "an_extra_value");
julia> fmap(+, d1, d2) == Dict("x" => [5, 7], "y" => 9) # Note that "z" is ignored
true
Mutable objects which appear more than once are only handled once (by caching f(x) in an IdDict). Thus the relationship x.i === x.iv[1] will be preserved. An immutable object which appears twice is not stored in the cache, thus f(34) will be called twice, and the results will agree only if f is pure.
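The caching rule above can be checked with a small sketch. The helper counted_float and the counter calls are hypothetical names introduced here, and the call count assumes the behaviour described in the preceding paragraph (mutable objects cached, immutable ones not).

```julia
using Functors

shared = [1, 2]
x = (i = shared, ii = 34, iv = (shared, 34))

# Count how many times the mapped function is actually called.
calls = Ref(0)
counted_float(v) = (calls[] += 1; float(v))

y = fmap(counted_float, x)

# The shared array is transformed once, so the tie between the two
# positions survives in the output.
y.i === y.iv[1]

# One call for the shared array plus two for the immutable 34.
calls[]
```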
By default, Tuples, NamedTuples, and some other container-like types in Base have children to recurse into. Arrays of numbers do not. To enable recursion into new types, you must provide a method of functor, which can be done using the macro @functor:
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));
julia> fmap(x -> 10x, m)
Foo(Bar([10, 20, 30]), (40, 50, Bar(Foo(60, 70))))
julia> fmap(string, m)
Foo(Bar("[1, 2, 3]"), ("4", "5", Bar(Foo("6", "7"))))
julia> fmap(string, m, exclude = v -> v isa Bar)
Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))"))
To recurse into custom types without reconstructing them afterwards, use fmapstructure.
For advanced customization of the traversal behaviour, pass a custom walk function that subtypes Functors.AbstractWalk. The call fmap(f, x, ys...; walk = mywalk) will wrap mywalk in ExcludeWalk then CachedWalk. Here, ExcludeWalk is responsible for applying f at excluded nodes. For a low-level interface for executing a user-constructed walk, see execute.
julia> struct MyWalk <: Functors.AbstractWalk end
julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" :
Functors.DefaultWalk()(recurse, x)
julia> fmap(x -> 10x, m; walk = MyWalk())
Foo("hello", (40, 50, "hello"))
The behaviour when the same node appears twice can be altered by giving a value to the prune keyword, which is then used in place of all but the first:
julia> twice = [1, 2];
julia> fmap(float, (x = twice, y = [1,2], z = twice); prune = missing)
(x = [1.0, 2.0], y = [1.0, 2.0], z = missing)
Functors.isleaf — Function
Functors.isleaf(x)
Return true if x has no children according to functor.
Examples
julia> Functors.isleaf(1)
true
julia> Functors.isleaf([2, 3, 4])
true
julia> Functors.isleaf(["five", [6, 7]])
false
julia> Functors.isleaf([])
false
julia> Functors.isleaf((8, 9))
false
julia> Functors.isleaf(())
true
Functors.children — Function
Functors.children(x)
Return the children of x as defined by functor. Equivalent to functor(x)[1].
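A short sketch of this equivalence, reusing the Foo type from the @functor examples (redefined here so the snippet stands alone):

```julia
using Functors

struct Foo; x; y; end
@functor Foo

foo = Foo(1, (2, 3))

# For a struct marked with @functor, the children are a NamedTuple of its fields.
Functors.children(foo)                                # (x = 1, y = (2, 3))

# children is the first element of what functor returns.
Functors.children(foo) === Functors.functor(foo)[1]   # true

# A Tuple is its own collection of children.
Functors.children((2, 3))                             # (2, 3)
```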
Functors.fcollect — Function
fcollect(x; exclude = v -> false)
Traverse x by recursing into each child of x as defined by functor, collecting the results into a flat array. The array is ordered by a breadth-first traversal of x, respecting the iteration order of children calls.
Doesn’t recurse inside branches rooted at nodes v for which exclude(v) == true. In such cases, the root v is also excluded from the result. By default, exclude always yields false.
See also children.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> struct TypeWithNoChildren; x; y; end
julia> m = Foo(Bar([1,2,3]), TypeWithNoChildren(:a, :b))
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
TypeWithNoChildren(:a, :b)
julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
TypeWithNoChildren(:a, :b)
julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b))
Bar([1, 2, 3])
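Since fcollect returns every node it visits, one possible use is to filter the result down to the nodes of interest. The sketch below gathers all arrays in a nested object; the type Layer and the variable names are hypothetical and used only for illustration.

```julia
using Functors

# Hypothetical type for illustration.
struct Layer; W; b; end
@functor Layer

model = (first = Layer([1.0 2.0], [0.0]), second = Layer([3.0 4.0], [1.0]))

# Keep only the array nodes from the flat list of all visited nodes.
arrays = filter(x -> x isa AbstractArray, fcollect(model))

length(arrays)        # one entry per distinct array in the model
sum(length, arrays)   # total number of array elements
```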
Functors.functor — Function
Functors.functor(x) = functor(typeof(x), x)
Returns a tuple containing, first, a NamedTuple of the children of x (typically its fields), and second, a reconstruction function. This controls the behaviour of fmap.
Methods should be added to functor(::Type{T}, x) for custom types, usually using the macro @functor.
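When the macro does not fit, for instance because the type has no constructor accepting all of its fields, a method can be written by hand. The following is a sketch under that assumption; the type Scaled, its fields, and the helper rebuild are hypothetical names.

```julia
using Functors

# Hypothetical type whose only constructor takes keyword arguments,
# so a positional reconstruction would not work.
struct Scaled
    weight
    factor
    Scaled(; weight, factor) = new(weight, factor)
end

function Functors.functor(::Type{<:Scaled}, s)
    # Only `weight` is exposed as a child; `factor` is carried over unchanged.
    children = (weight = s.weight,)
    rebuild(ch) = Scaled(weight = ch.weight, factor = s.factor)
    return children, rebuild
end

s2 = fmap(x -> 10 .* x, Scaled(weight = [1, 2], factor = 3))

s2.weight   # transformed by the mapped function
s2.factor   # left as it was
```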
Functors.fmapstructure — Function
fmapstructure(f, x; exclude = isleaf)
Like fmap, but doesn’t preserve the type of custom structs. Instead, it returns a NamedTuple (or a Tuple, or an array), or a nested set of these.
Useful for when the output must not contain custom structs.
Examples
julia> struct Foo; x; y; end
julia> @functor Foo
julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]);
julia> fmapstructure(x -> 2x, m)
(x = [2, 4, 6], y = Any[8, (10, 12), (x = 14, y = 16)])
julia> fmapstructure(println, m)
[1, 2, 3]
4
5
6
7
8
(x = nothing, y = Any[nothing, (nothing, nothing), (x = nothing, y = nothing)])
Moving models, or data, to the GPU
Flux provides some convenience functions based on fmap. Some (f16, f32, f64) change the precision of all arrays in a model. Others are used for moving a model to or from GPU memory:
Flux.cpu — Function
cpu(m)
Copies m onto the CPU, the opposite of gpu. Recurses into structs marked with @functor.
Example
julia> m_gpu = Dense(CUDA.randn(2, 5))
Dense(5 => 2) # 12 parameters
julia> m_gpu.bias # matches the given weight matrix
2-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
0.0
0.0
julia> m = m_gpu |> cpu
Dense(5 => 2) # 12 parameters
julia> m.bias
2-element Vector{Float32}:
0.0
0.0
Flux.gpu — Method
gpu(m)
Copies m to the current GPU device (using current GPU backend), if one is available. If no GPU is available, it does nothing (but prints a warning the first time).
On arrays, this calls CUDA’s cu, which also changes arrays with Float64 elements to Float32 while copying them to the device (same for AMDGPU). To act on arrays within a struct, the struct type must be marked with @functor.
See the CUDA.jl docs to help identify the current device.
Example
julia> m = Dense(rand(2, 3)) # constructed with Float64 weight matrix
Dense(3 => 2) # 8 parameters
julia> typeof(m.weight)
Matrix{Float64} (alias for Array{Float64, 2})
julia> m_gpu = gpu(m) # can equivalently be written m_gpu = m |> gpu
Dense(3 => 2) # 8 parameters
julia> typeof(m_gpu.weight)
CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
Flux.gpu — Method
gpu(data::DataLoader)
Transforms a given DataLoader to apply gpu to each batch of data, when iterated over. (If no GPU is available, this does nothing.)
Example
julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)
with first element:
(; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})
julia> first(dl)
(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')
julia> c_dl = gpu(dl)
4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)
with first element:
(; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})
julia> first(c_dl).x
2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
1.0 1.0 1.0
1.0 1.0 1.0
For large datasets, this is preferred over moving all the data to the GPU before creating the DataLoader, like this:
julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)
with first element:
(; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})
|
This only works if |