Description of your problem
PyMC3 variable is not replaced if provided in more_replacements (VI)
Please provide a minimal, self-contained, and reproducible example.
```python
import numpy as np
import pymc3 as pm


def test_var_replacement():
    X_mean = np.linspace(0, 10, 10, dtype='float32')
    y = (np.random.randn(*X_mean.shape) * .05 + X_mean) * 4.
    with pm.Model() as model:
        inp = pm.Normal('X', X_mean, shape=X_mean.shape)
        coef = pm.Normal('b', 4.)
        mean = inp * coef
        pm.Normal('y', mean, .1, observed=y)
        advi = pm.fit(100)
        assert advi.sample_node(mean).eval().shape == (10, )
        x_new = np.linspace(0, 10, 11, dtype='float32')
        # Expected: replacing `inp` with an 11-element array yields an 11-element result
        assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, )
```
Please provide the full traceback.
```
=================================== FAILURES ===================================
_____________________________ test_var_replacement _____________________________

    def test_var_replacement():
        X_mean = np.linspace(0, 10, 10, dtype='float32')
        y = (np.random.randn(*X_mean.shape) * .05 + X_mean) * 4.
        with pm.Model() as model:
            inp = pm.Normal('X', X_mean, shape=X_mean.shape)
            coef = pm.Normal('b', 4.)
            mean = inp * coef
            pm.Normal('y', mean, .1, observed=y)
            advi = pm.fit(100)
            assert advi.sample_node(mean).eval().shape == (10, )
            x_new = np.linspace(0, 10, 11, dtype='float32')
>           assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, )
E           assert (10,) == (11,)
E             At index 0 diff: 10 != 11
E             Use -v to get the full diff

tests/test_variational_inference.py:849: AssertionError
----------------------------- Captured stderr call -----------------------------
Average Loss = 11,613: 100%|██████████| 100/100 [00:00<00:00, 2959.34it/s]
Finished [100%]: Average Loss = 11,134
------------------------------ Captured log call -------------------------------
inference.py 211 INFO Finished [100%]: Average Loss = 11,134
========================== 1 failed in 13.22 seconds ===========================

Process finished with exit code 0
```
Please provide any additional information below.
see #2496 (comment)
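One possible mechanism, offered only as a hypothesis and not confirmed by anything in this report, is the order in which replacements are applied: if the approximation first swaps each free random variable (here `X`) for its posterior draw, the `inp` key in `more_replacements` no longer matches any node in the graph and is silently ignored, so the result keeps the original length of 10. The toy sketch below models that with hand-rolled expression tuples; `substitute`, `evaluate`, `sample_node_posterior_first` and `sample_node_user_first` are hypothetical helpers, not PyMC3 or Theano APIs.

```python
# Hypothetical toy model of graph substitution; NOT PyMC3/Theano code.
# Nodes are hashable tuples: ("var", name), ("const", values), ("mul", a, b).

def substitute(node, replacements):
    """Return a copy of `node` with any node that is a key in `replacements` swapped out."""
    if node in replacements:
        return replacements[node]
    if node[0] == "mul":
        return ("mul", substitute(node[1], replacements), substitute(node[2], replacements))
    return node


def evaluate(node):
    """Evaluate a fully substituted graph; a length-1 operand is broadcast."""
    kind = node[0]
    if kind == "const":
        return list(node[1])
    if kind == "mul":
        a, b = evaluate(node[1]), evaluate(node[2])
        if len(a) == 1:
            a = a * len(b)
        if len(b) == 1:
            b = b * len(a)
        return [x * y for x, y in zip(a, b)]
    raise ValueError(f"unbound variable: {node}")


inp = ("var", "X")
coef = ("var", "b")
mean = ("mul", inp, coef)

# Stand-ins for the approximation's posterior draws (X has 10 elements).
posterior = {inp: ("const", (1.0,) * 10), coef: ("const", (4.0,))}
# The user's replacement, analogous to more_replacements={inp: x_new}.
user = {inp: ("const", (1.0,) * 11)}


def sample_node_posterior_first(node, more):
    node = substitute(node, posterior)  # X is already gone after this step
    node = substitute(node, more)       # so the user's key for X matches nothing
    return evaluate(node)


def sample_node_user_first(node, more):
    node = substitute(node, more)       # X -> 11-element constant
    node = substitute(node, posterior)  # b -> posterior draw; X no longer present
    return evaluate(node)


print(len(sample_node_posterior_first(mean, user)))  # 10, mirrors the failing assert
print(len(sample_node_user_first(mean, user)))       # 11, the expected shape
```

In this toy, applying the user's replacements before the posterior ones is enough to reproduce the expected `(11,)` behaviour; whether the same ordering change is the right fix inside `sample_node` is not established here.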
Versions and main components
- PyMC3 Version: 3.3
- Theano Version: 1.0.1
- Python Version: 3.6
- Operating system: macOS
- How did you install PyMC3: git, master branch