
PyMC3 variable is not replaced if provided in more_replacements (VI) #2890

Closed
@ferrine

Description of your problem

PyMC3 variable is not replaced if provided in more_replacements (VI)

Please provide a minimal, self-contained, and reproducible example.

import numpy as np
import pymc3 as pm

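def test_var_replacement():
    X_mean = np.linspace(0, 10, 10, dtype='float32')
    y = (np.random.randn(*X_mean.shape) * .05 + X_mean) * 4.
    with pm.Model() as model:
        # `inp` is a latent (free) variable of shape (10,)
        inp = pm.Normal('X', X_mean, shape=X_mean.shape)
        coef = pm.Normal('b', 4.)
        mean = inp * coef
        pm.Normal('y', mean, .1, observed=y)
        advi = pm.fit(100)
        # without replacements, a posterior sample of `mean` has shape (10,)
        assert advi.sample_node(mean).eval().shape == (10, )
        # replacing `inp` with an 11-element array should give shape (11,),
        # but the replacement is ignored and the shape stays (10,)
        x_new = np.linspace(0, 10, 11, dtype='float32')
        assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, )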
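One possible explanation, which is an assumption and is not stated in this issue, is that substitution order matters. If the free variable `X` is swapped for posterior draws before `more_replacements` is applied, `inp` no longer appears in the graph, and the user replacement is silently skipped. The toy sketch below is plain Python, not PyMC3 or Theano code. The names `substitute`, `posterior_draws` and `user_replacements` are hypothetical and only show how the result depends on the order of the two passes.

```python
# Toy expression graph: a node is either a leaf name (str) or a tuple (op, *children).
# Hypothetical illustration only; this is NOT how PyMC3/Theano represent graphs.

def substitute(node, replacements):
    """Return a copy of `node` with every leaf found in `replacements` swapped out."""
    if isinstance(node, str):
        return replacements.get(node, node)
    op, *children = node
    return (op, *(substitute(c, replacements) for c in children))

# mean = X * b, where X and b are latent variables
mean = ("mul", "X", "b")

# hypothetical: the approximation swaps each latent variable for a posterior draw
posterior_draws = {"X": "X_draw", "b": "b_draw"}
# what the user passes as more_replacements
user_replacements = {"X": "x_new"}

# Order A: posterior draws first, then user replacements -> "X" is already gone
order_a = substitute(substitute(mean, posterior_draws), user_replacements)

# Order B: user replacements first, then posterior draws for what remains
order_b = substitute(substitute(mean, user_replacements), posterior_draws)

print(order_a)  # ('mul', 'X_draw', 'b_draw')
print(order_b)  # ('mul', 'x_new', 'b_draw')
```

Order A reproduces the symptom in the traceback (the replacement has no effect), while order B yields the behaviour the test expects.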

Please provide the full traceback.

=================================== FAILURES ===================================
_____________________________ test_var_replacement _____________________________

    def test_var_replacement():
        X_mean = np.linspace(0, 10, 10, dtype='float32')
        y = (np.random.randn(*X_mean.shape) * .05 + X_mean) * 4.
        with pm.Model() as model:
            inp = pm.Normal('X', X_mean, shape=X_mean.shape)
            coef = pm.Normal('b', 4.)
            mean = inp * coef
            pm.Normal('y', mean, .1, observed=y)
            advi = pm.fit(100)
            assert advi.sample_node(mean).eval().shape == (10, )
            x_new = np.linspace(0, 10, 11, dtype='float32')
>           assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, )
E           assert (10,) == (11,)
E             At index 0 diff: 10 != 11
E             Use -v to get the full diff

tests/test_variational_inference.py:849: AssertionError
----------------------------- Captured stderr call -----------------------------
Average Loss = 11,613: 100%|██████████| 100/100 [00:00<00:00, 2959.34it/s]
Finished [100%]: Average Loss = 11,134
------------------------------ Captured log call -------------------------------
inference.py               211 INFO     Finished [100%]: Average Loss = 11,134
========================== 1 failed in 13.22 seconds ===========================
Process finished with exit code 0

Please provide any additional information below.
see #2496 (comment)

Versions and main components

  • PyMC3 Version: 3.3
  • Theano Version: 1.0.1
  • Python Version: 3.6
  • Operating system: MacOS
  • How did you install PyMC3: git, master branch
