You cannot directly use external packages like PyTorch or TensorFlow inside a PyMC model, any more than you could define a PyTorch model using TensorFlow layers.
Your options are:

1. Wrap the NN in a PyTensor `Op`. This blog post shows how to do it with JAX/Flax.
2. Write the NN yourself directly in PyTensor. This example might be helpful.
If the architecture is simple, option (2) is probably much easier.