Skip to content

Improve blackjax sampling integration #6963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 23, 2023
Merged

Improve blackjax sampling integration #6963

merged 5 commits into from
Oct 23, 2023

Conversation

junpenglao
Copy link
Member

@junpenglao junpenglao commented Oct 20, 2023

Closes #6550

What is this PR about?
Adding more fine tune control in blackjax sampling.

After this PR:

# Use full rank mass matrix in window adaptation 
with model:
    idata2 = pm.sample(
        nuts_sampler="blackjax", 
        nuts_sampler_kwargs={
            "chain_method": "vectorized",
            "adaptation_kwargs": {
                "is_mass_matrix_diagonal": False
                }
            }
        )

# Warm up and sample with HMC (instead of NUTS) with full rank mass matrix adaptation
with model:
    idata2 = pm.sample(
        nuts_sampler="blackjax", 
        nuts_sampler_kwargs={
            "chain_method": "vectorized",
            "adaptation_kwargs": {
                "is_mass_matrix_diagonal": False,
                "algorithm": "hmc",
                "num_integration_steps": 20,
                }
            }
        )

# Enable progress bar (only works for `chain_method =  "vectorized"`)
with model:
    idata2 = pm.sample(
        nuts_sampler="blackjax", 
        progress_bar=True,
        nuts_sampler_kwargs={
            "chain_method": "vectorized",
            # "chain_method": "parallel",
            "adaptation_kwargs": {
                "is_mass_matrix_diagonal": False,
                }
            }
        )

Checklist

Major / Breaking Changes

None

New features

More option in Blackjax sampling

Bugfixes

None

Documentation

None

Maintenance

None


📚 Documentation preview 📚: https://pymc--6963.org.readthedocs.build/en/6963/

@codecov
Copy link

codecov bot commented Oct 20, 2023

Codecov Report

Merging #6963 (826cbe8) into main (5f29b25) will decrease coverage by 0.03%.
Report is 2 commits behind head on main.
The diff coverage is 80.43%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6963      +/-   ##
==========================================
- Coverage   92.12%   92.09%   -0.03%     
==========================================
  Files         100      100              
  Lines       16859    16875      +16     
==========================================
+ Hits        15531    15541      +10     
- Misses       1328     1334       +6     
Files Coverage Δ
pymc/sampling/mcmc.py 87.69% <ø> (ø)
pymc/sampling/jax.py 93.07% <80.43%> (-3.24%) ⬇️

... and 1 file with indirect coverage changes

@junpenglao
Copy link
Member Author

@jessegrabowski FYI the progress bar is partially working for both warm up and sampling.

@ricardoV94 ricardoV94 changed the title More fine tune control in blackjax sampling. Improve blackjax sampling integration Oct 23, 2023
@junpenglao junpenglao merged commit c3f93ba into main Oct 23, 2023
@junpenglao junpenglao deleted the blackjax branch October 23, 2023 13:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: sample_blackjax_nuts incorrectly reports sampling time
3 participants