How to combine PyMC with a neural network

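You cannot directly use external packages such as PyTorch or TensorFlow inside a PyMC model, any more than you could define a PyTorch model using TensorFlow layers. PyMC models are built as PyTensor graphs, and PyTensor cannot trace through or differentiate operations from another framework.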

Your options are:

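  1. Wrap the NN in a PyTensor Op, i.e. a custom graph node that calls the external framework for the forward pass and supplies gradients. This blog post shows how to do it with JAX/Flax.
  2. Write the NN yourself directly in PyTensor. This example might be helpful.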

If the architecture is simple, option (2) is probably much easier.
