Skip to content

Training a HyperNetwork on MNIST and FashionMNIST

Package Imports

julia
using Lux,
    ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant

Loading Datasets

julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

`n_train` / `n_eval` cap the number of samples taken from the `:train` / `:test` split;
passing `nothing` takes the whole split. Features are reshaped to `28 × 28 × 1 × N` and
targets are one-hot encoded over the labels `0:9`. The batch size is clamped to the number
of available samples. Only the training loader shuffles; both loaders are built with
`partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Read the first `limit` samples of one split (all of them when `limit` is `nothing`)
    # and convert them into an image array plus one-hot labels.
    function read_split(split::Symbol, limit)
        data = dset(split)
        n_samples = limit === nothing ? length(data) : limit
        (; features, targets) = data[1:n_samples]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # The sample count lives in the 4th array dimension, hence `size(x, 4)`.
    function make_loader(x, y, shuffle::Bool)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle=shuffle, partial=false
        )
    end

    x_train, y_train = read_split(:train, n_train)
    x_test, y_test = read_split(:test, n_eval)

    train_loader = make_loader(x_train, y_train, true)
    test_loader = make_loader(x_test, y_test, false)
    return (train_loader, test_loader)
end

"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST (in that order) and return a tuple with one
`(train_loader, test_loader)` pair per dataset.

When the environment variable `CI` parses to `true`, only 1024 training and 32 evaluation
samples are used per dataset; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train, n_eval = on_ci ? (1024, 32) : (nothing, nothing)
    return map((MNIST, FashionMNIST)) do dataset
        load_dataset(dataset, n_train, n_eval, batchsize)
    end
end

Implement a HyperNet Layer

julia
# Build a hypernetwork layer: `weight_generator` produces the parameters that
# `core_network` is then evaluated with. The resulting layer takes a tuple `(x, y)` as
# input, where `x` is fed to `weight_generator` and `y` is fed to `core_network`.
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # Initialize `core_network` once only to learn the structure (axes) of its parameters;
    # the initialized values themselves are not kept.
    ca_axes = getaxes(
        ComponentArray(Lux.initialparameters(Random.default_rng(), core_network))
    )
    # `dispatch=:HyperNet` tags the layer; elsewhere in this file a method is defined on
    # `CompactLuxLayer{:HyperNet}`, which presumably relies on this tag.
    return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
        # Generate the weights: flatten the generator output and give it the same
        # named structure as the core network's parameters.
        ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
        # Evaluate the core network on `y` using the generated parameters.
        @return core_network(y, ps_new)
    end
end

Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload Lux.initialparameters so that it skips the initialization of the core_network parameters.

julia
# Parameter initialization for the `:HyperNet` compact layer: only the weight generator
# gets trainable parameters, so nothing is initialized for any other sub-layer.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end

Create and Initialize the HyperNet

julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network is a small convolutional classifier with 10 outputs. The weight generator
embeds a dataset index (one of 2) and maps it through two dense layers to a vector with
exactly as many entries as the core network has parameters.
"""
function create_model()
    # All three convolutions share kernel size, activation and stride.
    conv_block(channels) = Conv((3, 3), channels, relu; stride=2)

    core_network = Chain(
        conv_block(1 => 16),
        conv_block(16 => 32),
        conv_block(32 => 64),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end

Define Utility Functions

julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` that `model` classifies correctly.

The model is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its output is compared against the labels via `onecold` after both
have been moved to the CPU.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    to_cpu = cpu_device()
    eval_st = Lux.testmode(st)

    n_correct, n_seen = 0, 0
    for (inputs, labels) in dataloader
        expected = onecold(to_cpu(labels))
        output = model((data_idx, inputs), ps, eval_st)
        predicted = onecold(to_cpu(first(output)))

        n_correct += sum(expected .== predicted)
        n_seen += length(expected)
    end
    return n_correct / n_seen
end

Training

julia
# Accuracy of `model` on `dataloader` as a percentage rounded to two decimals, evaluated
# with the current parameters and states stored in `train_state`.
function _accuracy_percent(model, train_state, dataloader, data_idx)
    acc = accuracy(
        model, train_state.parameters, train_state.states, dataloader, data_idx
    )
    return round(acc * 100; digits=2)
end

# Display name of the dataset selected by the integer index `data_idx`.
_dataset_name(data_idx) = data_idx == 1 ? "MNIST" : "FashionMNIST"

"""
    train(; nepochs=50)

Train the hypernetwork on MNIST and FashionMNIST and return the final test accuracies.

The model from `create_model` and the data from `load_datasets` are moved to the Reactant
device. Each epoch performs one pass over the MNIST training loader followed by one pass
over the FashionMNIST training loader, optimizing a logit cross-entropy loss with
`Adam(0.0003f0)`, and prints the training and test accuracy after every pass.

# Keywords
- `nepochs::Int=50`: number of epochs to train for.

# Returns
A 2-element `Vector{Float64}` with the final test accuracy in percent, rounded to two
decimals: index 1 is MNIST, index 2 is FashionMNIST.
"""
function train(; nepochs::Int=50)
    dev = reactant_device(; force=true)

    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the inference path once, in test mode, using the first training batch of the
    # first dataset as the example input. The compiled model is reused for every accuracy
    # evaluation below.
    x_sample = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, x_sample), ps, Lux.testmode(st))
    end

    ### Let's train the model
    for epoch in 1:nepochs, data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        ### This allows us to trace the data index, else it will be embedded as a constant
        ### in the IR
        concrete_data_idx = ConcreteRNumber(data_idx)

        # Time only the optimization pass, not the accuracy evaluation.
        stime = time()
        for (x, y) in train_dataloader
            (_, _, _, train_state) = Training.single_train_step!(
                AutoEnzyme(),
                CrossEntropyLoss(; logits=Val(true)),
                ((concrete_data_idx, x), y),
                train_state;
                return_gradients=Val(false),
            )
        end
        ttime = time() - stime

        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                 Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
    end

    println()

    # Final evaluation of both datasets with the fully trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

        concrete_data_idx = ConcreteRNumber(data_idx)
        train_acc = _accuracy_percent(
            model_compiled, train_state, train_dataloader, concrete_data_idx
        )
        test_acc = _accuracy_percent(
            model_compiled, train_state, test_dataloader, concrete_data_idx
        )

        data_name = _dataset_name(data_idx)

        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end

# Run the whole training procedure; `train` returns the final test accuracies it
# collected for each dataset.
test_acc_list = train()
2025-08-05 23:39:25.940019: I external/xla/xla/service/service.cc:163] XLA service 0x35c22480 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-05 23:39:25.940092: I external/xla/xla/service/service.cc:171]   StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754437165.941572  443409 se_gpu_pjrt_client.cc:1373] Using BFC allocator.
I0000 00:00:1754437165.941735  443409 gpu_helpers.cc:136] XLA backend allocating 12528893952 bytes on device 0 for BFCAllocator.
I0000 00:00:1754437165.941774  443409 gpu_helpers.cc:177] XLA backend will use up to 4176297984 bytes on device 0 for CollectiveBFCAllocator.
2025-08-05 23:39:25.955495: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
[  1/ 50]	       MNIST	Time 49.67772s	Training Accuracy: 34.57%	Test Accuracy: 37.50%
[  1/ 50]	FashionMNIST	Time 0.08018s	Training Accuracy: 32.62%	Test Accuracy: 43.75%
[  2/ 50]	       MNIST	Time 0.03560s	Training Accuracy: 36.91%	Test Accuracy: 34.38%
[  2/ 50]	FashionMNIST	Time 0.02717s	Training Accuracy: 46.00%	Test Accuracy: 46.88%
[  3/ 50]	       MNIST	Time 0.02361s	Training Accuracy: 41.60%	Test Accuracy: 34.38%
[  3/ 50]	FashionMNIST	Time 0.02477s	Training Accuracy: 53.32%	Test Accuracy: 56.25%
[  4/ 50]	       MNIST	Time 0.04774s	Training Accuracy: 52.15%	Test Accuracy: 43.75%
[  4/ 50]	FashionMNIST	Time 0.02646s	Training Accuracy: 62.89%	Test Accuracy: 59.38%
[  5/ 50]	       MNIST	Time 0.02874s	Training Accuracy: 58.50%	Test Accuracy: 40.62%
[  5/ 50]	FashionMNIST	Time 0.03049s	Training Accuracy: 67.87%	Test Accuracy: 62.50%
[  6/ 50]	       MNIST	Time 0.02891s	Training Accuracy: 65.23%	Test Accuracy: 46.88%
[  6/ 50]	FashionMNIST	Time 0.02992s	Training Accuracy: 74.32%	Test Accuracy: 59.38%
[  7/ 50]	       MNIST	Time 0.04020s	Training Accuracy: 71.29%	Test Accuracy: 43.75%
[  7/ 50]	FashionMNIST	Time 0.02829s	Training Accuracy: 74.90%	Test Accuracy: 56.25%
[  8/ 50]	       MNIST	Time 0.03793s	Training Accuracy: 77.34%	Test Accuracy: 40.62%
[  8/ 50]	FashionMNIST	Time 0.02997s	Training Accuracy: 81.25%	Test Accuracy: 59.38%
[  9/ 50]	       MNIST	Time 0.03427s	Training Accuracy: 81.05%	Test Accuracy: 50.00%
[  9/ 50]	FashionMNIST	Time 0.02909s	Training Accuracy: 83.98%	Test Accuracy: 62.50%
[ 10/ 50]	       MNIST	Time 0.04344s	Training Accuracy: 84.77%	Test Accuracy: 43.75%
[ 10/ 50]	FashionMNIST	Time 0.03245s	Training Accuracy: 87.30%	Test Accuracy: 56.25%
[ 11/ 50]	       MNIST	Time 0.02964s	Training Accuracy: 87.99%	Test Accuracy: 50.00%
[ 11/ 50]	FashionMNIST	Time 0.03926s	Training Accuracy: 89.26%	Test Accuracy: 62.50%
[ 12/ 50]	       MNIST	Time 0.03052s	Training Accuracy: 90.04%	Test Accuracy: 46.88%
[ 12/ 50]	FashionMNIST	Time 0.03800s	Training Accuracy: 91.02%	Test Accuracy: 59.38%
[ 13/ 50]	       MNIST	Time 0.02850s	Training Accuracy: 92.97%	Test Accuracy: 62.50%
[ 13/ 50]	FashionMNIST	Time 0.02789s	Training Accuracy: 92.77%	Test Accuracy: 62.50%
[ 14/ 50]	       MNIST	Time 0.02925s	Training Accuracy: 95.51%	Test Accuracy: 59.38%
[ 14/ 50]	FashionMNIST	Time 0.04921s	Training Accuracy: 94.92%	Test Accuracy: 65.62%
[ 15/ 50]	       MNIST	Time 0.02379s	Training Accuracy: 96.68%	Test Accuracy: 59.38%
[ 15/ 50]	FashionMNIST	Time 0.02415s	Training Accuracy: 95.21%	Test Accuracy: 68.75%
[ 16/ 50]	       MNIST	Time 0.03195s	Training Accuracy: 98.34%	Test Accuracy: 62.50%
[ 16/ 50]	FashionMNIST	Time 0.02649s	Training Accuracy: 95.90%	Test Accuracy: 68.75%
[ 17/ 50]	       MNIST	Time 0.03205s	Training Accuracy: 99.22%	Test Accuracy: 59.38%
[ 17/ 50]	FashionMNIST	Time 0.02390s	Training Accuracy: 97.27%	Test Accuracy: 68.75%
[ 18/ 50]	       MNIST	Time 0.02892s	Training Accuracy: 99.51%	Test Accuracy: 62.50%
[ 18/ 50]	FashionMNIST	Time 0.02793s	Training Accuracy: 96.78%	Test Accuracy: 65.62%
[ 19/ 50]	       MNIST	Time 0.02396s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 19/ 50]	FashionMNIST	Time 0.03471s	Training Accuracy: 99.22%	Test Accuracy: 68.75%
[ 20/ 50]	       MNIST	Time 0.02442s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 20/ 50]	FashionMNIST	Time 0.03641s	Training Accuracy: 99.41%	Test Accuracy: 68.75%
[ 21/ 50]	       MNIST	Time 0.02886s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 21/ 50]	FashionMNIST	Time 0.03564s	Training Accuracy: 99.32%	Test Accuracy: 68.75%
[ 22/ 50]	       MNIST	Time 0.02392s	Training Accuracy: 99.90%	Test Accuracy: 65.62%
[ 22/ 50]	FashionMNIST	Time 0.02372s	Training Accuracy: 99.61%	Test Accuracy: 71.88%
[ 23/ 50]	       MNIST	Time 0.02272s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 23/ 50]	FashionMNIST	Time 0.02335s	Training Accuracy: 99.51%	Test Accuracy: 68.75%
[ 24/ 50]	       MNIST	Time 0.04092s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 24/ 50]	FashionMNIST	Time 0.02622s	Training Accuracy: 99.80%	Test Accuracy: 68.75%
[ 25/ 50]	       MNIST	Time 0.03114s	Training Accuracy: 99.90%	Test Accuracy: 62.50%
[ 25/ 50]	FashionMNIST	Time 0.02498s	Training Accuracy: 100.00%	Test Accuracy: 71.88%
[ 26/ 50]	       MNIST	Time 0.02441s	Training Accuracy: 99.90%	Test Accuracy: 59.38%
[ 26/ 50]	FashionMNIST	Time 0.02416s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 27/ 50]	       MNIST	Time 0.02380s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 27/ 50]	FashionMNIST	Time 0.02394s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 28/ 50]	       MNIST	Time 0.03218s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 28/ 50]	FashionMNIST	Time 0.04132s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 29/ 50]	       MNIST	Time 0.03131s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 29/ 50]	FashionMNIST	Time 0.03665s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 30/ 50]	       MNIST	Time 0.03099s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 30/ 50]	FashionMNIST	Time 0.02853s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 31/ 50]	       MNIST	Time 0.02779s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 31/ 50]	FashionMNIST	Time 0.02826s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 32/ 50]	       MNIST	Time 0.03809s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 32/ 50]	FashionMNIST	Time 0.02910s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 33/ 50]	       MNIST	Time 0.03577s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 33/ 50]	FashionMNIST	Time 0.02746s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 34/ 50]	       MNIST	Time 0.03812s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 34/ 50]	FashionMNIST	Time 0.02887s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 35/ 50]	       MNIST	Time 0.02927s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 35/ 50]	FashionMNIST	Time 0.02843s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 36/ 50]	       MNIST	Time 0.02829s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 36/ 50]	FashionMNIST	Time 0.03794s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 37/ 50]	       MNIST	Time 0.02802s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 37/ 50]	FashionMNIST	Time 0.03656s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 38/ 50]	       MNIST	Time 0.02925s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 38/ 50]	FashionMNIST	Time 0.02876s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 39/ 50]	       MNIST	Time 0.02987s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 39/ 50]	FashionMNIST	Time 0.02738s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 40/ 50]	       MNIST	Time 0.02759s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 40/ 50]	FashionMNIST	Time 0.03024s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 41/ 50]	       MNIST	Time 0.03324s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 41/ 50]	FashionMNIST	Time 0.03030s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 42/ 50]	       MNIST	Time 0.03439s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 42/ 50]	FashionMNIST	Time 0.02455s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 43/ 50]	       MNIST	Time 0.02826s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 43/ 50]	FashionMNIST	Time 0.02517s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 44/ 50]	       MNIST	Time 0.03087s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 44/ 50]	FashionMNIST	Time 0.04696s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 45/ 50]	       MNIST	Time 0.02961s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 45/ 50]	FashionMNIST	Time 0.03151s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 46/ 50]	       MNIST	Time 0.02474s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 46/ 50]	FashionMNIST	Time 0.03064s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 47/ 50]	       MNIST	Time 0.02837s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 47/ 50]	FashionMNIST	Time 0.02881s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 48/ 50]	       MNIST	Time 0.02795s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 48/ 50]	FashionMNIST	Time 0.02819s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 49/ 50]	       MNIST	Time 0.04087s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 49/ 50]	FashionMNIST	Time 0.02202s	Training Accuracy: 100.00%	Test Accuracy: 68.75%
[ 50/ 50]	       MNIST	Time 0.03614s	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[ 50/ 50]	FashionMNIST	Time 0.03683s	Training Accuracy: 100.00%	Test Accuracy: 68.75%

[FINAL]	       MNIST	Training Accuracy: 100.00%	Test Accuracy: 62.50%
[FINAL]	FashionMNIST	Training Accuracy: 100.00%	Test Accuracy: 68.75%

Appendix

julia
# Print the Julia version and platform details, followed by GPU backend details for each
# backend package that is loaded and reported as functional by MLDataDevices.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Short-circuit evaluation ensures a backend module is only touched once it is known to
# be defined.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.