Training a HyperNetwork on MNIST and FashionMNIST
Package Imports
julia
using Lux,
ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Reactant
Loading Datasets
julia
"""
    load_dataset(dset, n_train, n_eval, batchsize)

Build a `(train_loader, test_loader)` pair of `DataLoader`s for the dataset type `dset`.

`dset` is called as `dset(:train)` and `dset(:test)`; the result is indexed with a range
and must yield `features` and `targets`.

# Arguments
- `n_train`, `n_eval`: number of leading samples to keep from the train / test split, or
  `nothing` to keep the whole split.
- `batchsize`: requested batch size; it is clamped to the number of samples kept.

# Returns
A tuple `(train_loader, test_loader)`. Features are reshaped to `28 × 28 × 1 × N` and
targets are one-hot encoded over the classes `0:9`. Only the training loader shuffles;
both loaders are created with `partial=false`.
"""
function load_dataset(
    ::Type{dset}, n_train::Union{Nothing,Int}, n_eval::Union{Nothing,Int}, batchsize::Int
) where {dset}
    # Load `split`, keep the first `n` samples (all of them when `n === nothing`) and
    # convert them to the image / one-hot layout described in the docstring.
    function fetch_split(split::Symbol, n)
        data = dset(split)
        last_idx = n === nothing ? length(data) : n
        (; features, targets) = data[1:last_idx]
        return reshape(features, 28, 28, 1, :), onehotbatch(targets, 0:9)
    end

    # Wrap one split in a DataLoader; the batch size never exceeds the sample count.
    function make_loader(x, y, shuffle::Bool)
        return DataLoader(
            (x, y); batchsize=min(batchsize, size(x, 4)), shuffle=shuffle, partial=false
        )
    end

    x_train, y_train = fetch_split(:train, n_train)
    x_test, y_test = fetch_split(:test, n_eval)

    return (make_loader(x_train, y_train, true), make_loader(x_test, y_test, false))
end
"""
    load_datasets(batchsize=32)

Load MNIST and FashionMNIST through `load_dataset` and return one
`(train_loader, test_loader)` pair per dataset, in that order.

When the environment variable `CI` parses to `true`, only the first 1024 training and 32
evaluation samples of each dataset are used; otherwise the full splits are loaded.
"""
function load_datasets(batchsize=32)
    on_ci = parse(Bool, get(ENV, "CI", "false"))
    n_train = on_ci ? 1024 : nothing
    n_eval = on_ci ? 32 : nothing
    datasets = (MNIST, FashionMNIST)
    return load_dataset.(datasets, n_train, n_eval, batchsize)
end
Implement a HyperNet Layer
julia
"""
    HyperNet(weight_generator, core_network)

Build a compact Lux layer in which `weight_generator` produces the parameters that
`core_network` is evaluated with.

The layer takes a tuple as input. Its first element is fed to `weight_generator`; the
flattened output is given the parameter layout of `core_network` and used as the
parameters when `core_network` is applied to the second element.
"""
function HyperNet(weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
    # A throwaway parameter set is initialized only to read off its axes, i.e. the
    # layout used to re-wrap the flat generated vector.
    template_ps = Lux.initialparameters(Random.default_rng(), core_network)
    ca_axes = getaxes(ComponentArray(template_ps))
    return @compact(;
        ca_axes, weight_generator, core_network, dispatch=:HyperNet
    ) do (idx, input)
        # Generate the weights
        flat_ps = vec(weight_generator(idx))
        generated_ps = ComponentArray(flat_ps, ca_axes)
        @return core_network(input, generated_ps)
    end
end
Defining methods on a `CompactLuxLayer` requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply overload `Lux.initialparameters` so that it skips initializing the `core_network`
parameters, since those are produced by the weight generator.
julia
# Parameter initialization for layers tagged with `dispatch=:HyperNet`: only the weight
# generator gets parameters. No entry is created for `core_network`.
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    generator = hn.layers.weight_generator
    generator_ps = Lux.initialparameters(rng, generator)
    return (; weight_generator=generator_ps)
end
Create and Initialize the HyperNet
julia
"""
    create_model()

Construct the `HyperNet` used in this tutorial.

The core network is three stride-2 `3 × 3` convolutions (1 → 16 → 32 → 64 channels, `relu`),
global mean pooling, flattening and a `Dense(64, 10)` head. The weight generator embeds
one of 2 indices into 32 dimensions and maps it through two dense layers to a vector
with one entry per core-network parameter.
"""
function create_model()
    core_network = Chain(
        Conv((3, 3), 1 => 16, relu; stride=2),
        Conv((3, 3), 16 => 32, relu; stride=2),
        Conv((3, 3), 32 => 64, relu; stride=2),
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(64, 10),
    )

    # The generator's output width must equal the number of core-network parameters.
    n_core_params = Lux.parameterlength(core_network)
    weight_generator = Chain(
        Embedding(2 => 32), Dense(32, 64, relu), Dense(64, n_core_params)
    )

    return HyperNet(weight_generator, core_network)
end
Define Utility Functions
julia
"""
    accuracy(model, ps, st, dataloader, data_idx)

Return the fraction of samples in `dataloader` whose predicted class matches the target.

`model` is called as `model((data_idx, x), ps, st)` with `st` switched to test mode, and
the first element of its result is used as the prediction scores. Targets and scores are
moved to the CPU before `onecold` is applied. The result is `NaN` if `dataloader` yields
no batches.
"""
function accuracy(model, ps, st, dataloader, data_idx)
    cdev = cpu_device()
    st_eval = Lux.testmode(st)
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        truth = onecold(cdev(y))
        scores = first(model((data_idx, x), ps, st_eval))
        predicted = onecold(cdev(scores))
        n_correct += count(truth .== predicted)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
Training
julia
"""
    train()

Train the HyperNet on MNIST and FashionMNIST and return the final test accuracies.

The model, data loaders, parameters and states are moved to the Reactant device. The
forward pass is compiled once for evaluation. Training runs for 50 epochs; in every
epoch the model is trained on MNIST (index 1) and then on FashionMNIST (index 2) with
`Adam(0.0003f0)` and a cross-entropy loss on logits, and the elapsed time plus the
training / test accuracies are printed.

# Returns
A 2-element `Vector{Float64}` holding the final test accuracy in percent (rounded to two
decimals) for MNIST and FashionMNIST, in that order.
"""
function train()
    dev = reactant_device(; force=true)
    model = create_model()
    dataloaders = dev(load_datasets())

    Random.seed!(1234)
    ps, st = dev(Lux.setup(Random.default_rng(), model))
    train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

    # Compile the test-mode forward pass once, using one MNIST training batch and a
    # concrete dataset index as example inputs.
    sample_x = first(first(dataloaders[1][1]))
    sample_idx = ConcreteRNumber(1)
    model_compiled = Reactant.with_config(;
        dot_general_precision=PrecisionConfig.HIGH,
        convolution_precision=PrecisionConfig.HIGH,
    ) do
        @compile model((sample_idx, sample_x), ps, Lux.testmode(st))
    end

    # Accuracy of the compiled model on `loader`, in percent rounded to two decimals.
    # The train state is passed in explicitly because it is reassigned during training.
    function percent_accuracy(state, loader, idx)
        acc = accuracy(model_compiled, state.parameters, state.states, loader, idx)
        return round(acc * 100; digits=2)
    end

    dataset_names = ("MNIST", "FashionMNIST")

    # Let's train the model
    nepochs = 50
    for epoch in 1:nepochs
        for data_idx in 1:2
            train_dataloader, test_dataloader = dev.(dataloaders[data_idx])

            # This allows us to trace the data index, else it will be embedded as a
            # constant in the IR
            concrete_data_idx = ConcreteRNumber(data_idx)

            stime = time()
            for (x, y) in train_dataloader
                (_, _, _, train_state) = Training.single_train_step!(
                    AutoEnzyme(),
                    CrossEntropyLoss(; logits=Val(true)),
                    ((concrete_data_idx, x), y),
                    train_state;
                    return_gradients=Val(false),
                )
            end
            ttime = time() - stime

            train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
            test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

            data_name = dataset_names[data_idx]
            @printf "[%3d/%3d]\t%12s\tTime %3.5fs\tTraining Accuracy: %3.2f%%\tTest \
                     Accuracy: %3.2f%%\n" epoch nepochs data_name ttime train_acc test_acc
        end
    end

    println()

    # Final evaluation pass over both datasets with the trained parameters.
    test_acc_list = [0.0, 0.0]
    for data_idx in 1:2
        train_dataloader, test_dataloader = dev.(dataloaders[data_idx])
        concrete_data_idx = ConcreteRNumber(data_idx)

        train_acc = percent_accuracy(train_state, train_dataloader, concrete_data_idx)
        test_acc = percent_accuracy(train_state, test_dataloader, concrete_data_idx)

        data_name = dataset_names[data_idx]
        @printf "[FINAL]\t%12s\tTraining Accuracy: %3.2f%%\tTest Accuracy: \
                 %3.2f%%\n" data_name train_acc test_acc
        test_acc_list[data_idx] = test_acc
    end
    return test_acc_list
end
# Run the training loop and keep the final test accuracies that `train` returns.
test_acc_list = train()
2025-08-05 23:39:25.940019: I external/xla/xla/service/service.cc:163] XLA service 0x35c22480 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-08-05 23:39:25.940092: I external/xla/xla/service/service.cc:171] StreamExecutor device (0): Quadro RTX 5000, Compute Capability 7.5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1754437165.941572 443409 se_gpu_pjrt_client.cc:1373] Using BFC allocator.
I0000 00:00:1754437165.941735 443409 gpu_helpers.cc:136] XLA backend allocating 12528893952 bytes on device 0 for BFCAllocator.
I0000 00:00:1754437165.941774 443409 gpu_helpers.cc:177] XLA backend will use up to 4176297984 bytes on device 0 for CollectiveBFCAllocator.
2025-08-05 23:39:25.955495: I external/xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 90800
[ 1/ 50] MNIST Time 49.67772s Training Accuracy: 34.57% Test Accuracy: 37.50%
[ 1/ 50] FashionMNIST Time 0.08018s Training Accuracy: 32.62% Test Accuracy: 43.75%
[ 2/ 50] MNIST Time 0.03560s Training Accuracy: 36.91% Test Accuracy: 34.38%
[ 2/ 50] FashionMNIST Time 0.02717s Training Accuracy: 46.00% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.02361s Training Accuracy: 41.60% Test Accuracy: 34.38%
[ 3/ 50] FashionMNIST Time 0.02477s Training Accuracy: 53.32% Test Accuracy: 56.25%
[ 4/ 50] MNIST Time 0.04774s Training Accuracy: 52.15% Test Accuracy: 43.75%
[ 4/ 50] FashionMNIST Time 0.02646s Training Accuracy: 62.89% Test Accuracy: 59.38%
[ 5/ 50] MNIST Time 0.02874s Training Accuracy: 58.50% Test Accuracy: 40.62%
[ 5/ 50] FashionMNIST Time 0.03049s Training Accuracy: 67.87% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.02891s Training Accuracy: 65.23% Test Accuracy: 46.88%
[ 6/ 50] FashionMNIST Time 0.02992s Training Accuracy: 74.32% Test Accuracy: 59.38%
[ 7/ 50] MNIST Time 0.04020s Training Accuracy: 71.29% Test Accuracy: 43.75%
[ 7/ 50] FashionMNIST Time 0.02829s Training Accuracy: 74.90% Test Accuracy: 56.25%
[ 8/ 50] MNIST Time 0.03793s Training Accuracy: 77.34% Test Accuracy: 40.62%
[ 8/ 50] FashionMNIST Time 0.02997s Training Accuracy: 81.25% Test Accuracy: 59.38%
[ 9/ 50] MNIST Time 0.03427s Training Accuracy: 81.05% Test Accuracy: 50.00%
[ 9/ 50] FashionMNIST Time 0.02909s Training Accuracy: 83.98% Test Accuracy: 62.50%
[ 10/ 50] MNIST Time 0.04344s Training Accuracy: 84.77% Test Accuracy: 43.75%
[ 10/ 50] FashionMNIST Time 0.03245s Training Accuracy: 87.30% Test Accuracy: 56.25%
[ 11/ 50] MNIST Time 0.02964s Training Accuracy: 87.99% Test Accuracy: 50.00%
[ 11/ 50] FashionMNIST Time 0.03926s Training Accuracy: 89.26% Test Accuracy: 62.50%
[ 12/ 50] MNIST Time 0.03052s Training Accuracy: 90.04% Test Accuracy: 46.88%
[ 12/ 50] FashionMNIST Time 0.03800s Training Accuracy: 91.02% Test Accuracy: 59.38%
[ 13/ 50] MNIST Time 0.02850s Training Accuracy: 92.97% Test Accuracy: 62.50%
[ 13/ 50] FashionMNIST Time 0.02789s Training Accuracy: 92.77% Test Accuracy: 62.50%
[ 14/ 50] MNIST Time 0.02925s Training Accuracy: 95.51% Test Accuracy: 59.38%
[ 14/ 50] FashionMNIST Time 0.04921s Training Accuracy: 94.92% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.02379s Training Accuracy: 96.68% Test Accuracy: 59.38%
[ 15/ 50] FashionMNIST Time 0.02415s Training Accuracy: 95.21% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.03195s Training Accuracy: 98.34% Test Accuracy: 62.50%
[ 16/ 50] FashionMNIST Time 0.02649s Training Accuracy: 95.90% Test Accuracy: 68.75%
[ 17/ 50] MNIST Time 0.03205s Training Accuracy: 99.22% Test Accuracy: 59.38%
[ 17/ 50] FashionMNIST Time 0.02390s Training Accuracy: 97.27% Test Accuracy: 68.75%
[ 18/ 50] MNIST Time 0.02892s Training Accuracy: 99.51% Test Accuracy: 62.50%
[ 18/ 50] FashionMNIST Time 0.02793s Training Accuracy: 96.78% Test Accuracy: 65.62%
[ 19/ 50] MNIST Time 0.02396s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 19/ 50] FashionMNIST Time 0.03471s Training Accuracy: 99.22% Test Accuracy: 68.75%
[ 20/ 50] MNIST Time 0.02442s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 20/ 50] FashionMNIST Time 0.03641s Training Accuracy: 99.41% Test Accuracy: 68.75%
[ 21/ 50] MNIST Time 0.02886s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 21/ 50] FashionMNIST Time 0.03564s Training Accuracy: 99.32% Test Accuracy: 68.75%
[ 22/ 50] MNIST Time 0.02392s Training Accuracy: 99.90% Test Accuracy: 65.62%
[ 22/ 50] FashionMNIST Time 0.02372s Training Accuracy: 99.61% Test Accuracy: 71.88%
[ 23/ 50] MNIST Time 0.02272s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 23/ 50] FashionMNIST Time 0.02335s Training Accuracy: 99.51% Test Accuracy: 68.75%
[ 24/ 50] MNIST Time 0.04092s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 24/ 50] FashionMNIST Time 0.02622s Training Accuracy: 99.80% Test Accuracy: 68.75%
[ 25/ 50] MNIST Time 0.03114s Training Accuracy: 99.90% Test Accuracy: 62.50%
[ 25/ 50] FashionMNIST Time 0.02498s Training Accuracy: 100.00% Test Accuracy: 71.88%
[ 26/ 50] MNIST Time 0.02441s Training Accuracy: 99.90% Test Accuracy: 59.38%
[ 26/ 50] FashionMNIST Time 0.02416s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 27/ 50] MNIST Time 0.02380s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 27/ 50] FashionMNIST Time 0.02394s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 28/ 50] MNIST Time 0.03218s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 28/ 50] FashionMNIST Time 0.04132s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.03131s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 29/ 50] FashionMNIST Time 0.03665s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 30/ 50] MNIST Time 0.03099s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 30/ 50] FashionMNIST Time 0.02853s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 31/ 50] MNIST Time 0.02779s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 31/ 50] FashionMNIST Time 0.02826s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 32/ 50] MNIST Time 0.03809s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 32/ 50] FashionMNIST Time 0.02910s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 33/ 50] MNIST Time 0.03577s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 33/ 50] FashionMNIST Time 0.02746s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 34/ 50] MNIST Time 0.03812s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 34/ 50] FashionMNIST Time 0.02887s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.02927s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 35/ 50] FashionMNIST Time 0.02843s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 36/ 50] MNIST Time 0.02829s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 36/ 50] FashionMNIST Time 0.03794s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 37/ 50] MNIST Time 0.02802s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 37/ 50] FashionMNIST Time 0.03656s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 38/ 50] MNIST Time 0.02925s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 38/ 50] FashionMNIST Time 0.02876s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 39/ 50] MNIST Time 0.02987s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 39/ 50] FashionMNIST Time 0.02738s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 40/ 50] MNIST Time 0.02759s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 40/ 50] FashionMNIST Time 0.03024s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 41/ 50] MNIST Time 0.03324s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 41/ 50] FashionMNIST Time 0.03030s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 42/ 50] MNIST Time 0.03439s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 42/ 50] FashionMNIST Time 0.02455s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 43/ 50] MNIST Time 0.02826s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 43/ 50] FashionMNIST Time 0.02517s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 44/ 50] MNIST Time 0.03087s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 44/ 50] FashionMNIST Time 0.04696s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 45/ 50] MNIST Time 0.02961s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 45/ 50] FashionMNIST Time 0.03151s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 46/ 50] MNIST Time 0.02474s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 46/ 50] FashionMNIST Time 0.03064s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 47/ 50] MNIST Time 0.02837s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 47/ 50] FashionMNIST Time 0.02881s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 48/ 50] MNIST Time 0.02795s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 48/ 50] FashionMNIST Time 0.02819s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.04087s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 49/ 50] FashionMNIST Time 0.02202s Training Accuracy: 100.00% Test Accuracy: 68.75%
[ 50/ 50] MNIST Time 0.03614s Training Accuracy: 100.00% Test Accuracy: 62.50%
[ 50/ 50] FashionMNIST Time 0.03683s Training Accuracy: 100.00% Test Accuracy: 68.75%
[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 62.50%
[FINAL] FashionMNIST Training Accuracy: 100.00% Test Accuracy: 68.75%
Appendix
julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# Additionally print GPU backend details, but only for backends whose module is defined
# in this session and which MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.