Hi there, I have set up a Hierarchical Bayes model for choice data (on AWS Sagemaker) and am able to draw samples with the NUTS sampler in PyMC4. Now I’m trying to run the sampling on a GPU, and my Sagemaker instance has one available. I tried to get it to work through a .aesara.rc file by setting device=cuda, device=cuda0, or device=gpu, but none of these work; only device=cpu does. I hear that GPU acceleration is supposed to be straightforward in PyMC4. Could someone please point me in the right direction? More details of my model and sampling can be found here:
I can also provide more information if need be.
Thanks in advance!
Great, and you’re calling pymc.sampling_jax.sample_numpyro_nuts() to do your sampling? Sorry if these seem like basic questions, just trying to make sure we’re on the same page.
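For reference, here is a minimal sketch of what that call might look like. The wrapper name `sample_with_numpyro` and the default draw counts are made up for illustration; `model` is assumed to be an already-built PyMC model, and `jax` and `numpyro` are assumed to be installed alongside PyMC.

```python
def sample_with_numpyro(model, draws=1000, tune=1000, chains=4):
    """Sample `model` with the NumPyro NUTS sampler that PyMC exposes.

    Hypothetical helper for illustration only. The import is kept inside
    the function so that defining it does not require PyMC to be installed.
    """
    # Assumes PyMC v4 with its JAX extras (jax, jaxlib, numpyro) installed.
    import pymc.sampling_jax as sampling_jax

    # JAX decides where the computation runs; no Aesara device flag is set here.
    return sampling_jax.sample_numpyro_nuts(
        draws=draws,
        tune=tune,
        chains=chains,
        model=model,
    )
```

You would call it as `idata = sample_with_numpyro(my_model)`, where `my_model` stands in for whatever model object you built.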
I have realized that the GPU “is” in fact being utilized. Please see my earlier post in this thread. Apparently, once JAX has GPU access, you’re done: PyMC4 uses the GPU via JAX.
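A quick way to confirm that JAX has GPU access is to ask it which backend and devices it sees. This is only a sketch: the helper name `describe_jax_backend` is made up for illustration, and it simply reports that `jax` is missing if the import fails.

```python
import importlib


def describe_jax_backend():
    """Report which backend and devices JAX sees (hypothetical helper)."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        # jax is not installed in this environment.
        return {"jax_installed": False, "backend": None, "devices": []}

    return {
        "jax_installed": True,
        # Name of the default platform JAX will run on, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        # One entry per device visible to JAX.
        "devices": [str(device) for device in jax.devices()],
    }


if __name__ == "__main__":
    print(describe_jax_backend())
```

If `backend` comes back as `cpu` on a GPU instance, one common cause is a CPU-only `jaxlib` build, so that is worth checking first.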
Great! Glad you got it working. And yeah, when I was setting up to use the GPU, I was preparing for a fight with the computer, too. It was far easier than I expected.
Hope you’re doing well!
I’m trying to achieve what you did, but JAX won’t detect the GPU on AWS Sagemaker. Could you please tell me which instance/image you used?