Scan RVs cannot be composed with other measurable operations #6351

Open
ricardoV94 opened this issue Nov 29, 2022 · 1 comment
ricardoV94 commented Nov 29, 2022

Description

```python
import aesara
import aesara.tensor as at
from pymc.logprob import factorized_joint_logprob

x, x_updates = aesara.scan(
    fn=lambda: at.random.normal(0, 1),
    n_steps=10,
)
x.name = "x"
x_vv = x.clone()

factorized_joint_logprob({x: x_vv})  # Fine

sx = at.log(x)
sx.name = "sx"
sx_vv = sx.clone()
factorized_joint_logprob({sx: sx_vv})
# RuntimeError: The logprob terms of the following value variables could not be derived: {sx}
```
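For contrast, what the logprob machinery derives for a measurable transform of a plain (non-Scan) RandomVariable is the standard change-of-variables rule. A minimal NumPy/SciPy sketch of that rule for `exp` of a normal (`logp_exp_of_normal` is a hypothetical helper name, independent of Aesara/PyMC):

```python
import numpy as np
from scipy import stats

def logp_exp_of_normal(v, mu=0.0, sigma=1.0):
    # If x ~ Normal(mu, sigma) and s = exp(x), change of variables gives:
    # logp_s(v) = logp_x(log v) + log|d(log v)/dv| = logp_x(log v) - log(v)
    return stats.norm.logpdf(np.log(v), mu, sigma) - np.log(v)

# Sanity check against the equivalent closed-form LogNormal density
v = 2.5
print(np.isclose(logp_exp_of_normal(v), stats.lognorm.logpdf(v, s=1.0)))
```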

This is probably related to how Scan acts on value variables directly.

pymc/pymc/logprob/scan.py

Lines 420 to 469 in e0d25c8

```python
# We're going to replace the user's random variable/value variable mappings
# with ones that map directly to outputs of this `Scan`.
for rv_var, val_var, out_idx in indirect_rv_vars:
    # The full/un-`Subtensor`ed `Scan` output that we need to use
    full_out = node.outputs[out_idx]
    assert rv_var.owner.inputs[0] == full_out
    # A new value variable that spans the full output.
    # We don't want the old graph to appear in the new log-probability
    # graph, so we use the shape feature to (hopefully) get the shape
    # without the entire `Scan` itself.
    full_out_shape = tuple(
        fgraph.shape_feature.get_shape(full_out, i) for i in range(full_out.ndim)
    )
    new_val_var = at.empty(full_out_shape, dtype=full_out.dtype)
    # Set the parts of this new value variable that applied to the
    # user-specified value variable to the user's value variable
    subtensor_indices = indices_from_subtensor(
        rv_var.owner.inputs[1:], rv_var.owner.op.idx_list
    )
    # E.g. for a single `-1` TAPS, `s_0T[1:] = s_1T` where `s_0T` is
    # `new_val_var` and `s_1T` is the user-specified value variable
    # that only spans times `t=1` to `t=T`.
    new_val_var = at.set_subtensor(new_val_var[subtensor_indices], val_var)
    # This is the outer-input that sets `s_0T[i] = taps[i]` where `i`
    # is a TAP index (e.g. a TAP of `-1` maps to index `0` in a vector
    # of the entire series).
    var_info = curr_scanargs.find_among_fields(full_out)
    alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
    outer_input_var = getattr(curr_scanargs, f"outer_in_{alt_type}")[var_info.index]
    # These outer-inputs are used by `aesara.scan.utils.expand_empty`, and
    # are expected to consist of only a single `set_subtensor` call.
    # That's why we can simply replace the first argument of the node.
    assert isinstance(outer_input_var.owner.op, inc_subtensor_ops)
    # We're going to set those values on our `new_val_var` so that it can
    # serve as a complete replacement for the old input `outer_input_var`.
    new_val_var = outer_input_var.owner.clone_with_new_inputs(
        [new_val_var] + outer_input_var.owner.inputs[1:]
    ).default_output()
    # Replace the mapping
    rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out)
```
ricardoV94 commented:
Supporting this would be nice for time series models, in order to concatenate the init and innovation components.

For example, the (non-user API) AR2 example in https://gist.github.com/ricardoV94/a49b2cc1cf0f32a5f6dc31d6856ccb63#file-pymc_timeseries-ipynb could instead be written as:

```python
import numpy as np
import pymc as pm
import pytensor

lags = 2
trials = 100

coords = {"lags": range(lags), "trials": range(trials)}
with pm.Model(coords=coords, check_bounds=False) as m:
    rho = pm.Normal("rho", 0, 0.2, dims=("lags",))
    sigma = pm.HalfNormal("sigma", 0.2)

    ar_init = pm.Normal.dist(shape=(lags,))

    def ar_step(x_tm2, x_tm1, rho, sigma):
        x = pm.Normal.dist(mu=x_tm1 * rho[0] + x_tm2 * rho[1], sigma=sigma)
        return x

    ar_innov, ar_innov_updates = pytensor.scan(
        fn=ar_step,
        outputs_info=[{"initial": ar_init, "taps": range(-lags, 0)}],
        non_sequences=[rho, sigma],
        n_steps=trials - lags,
    )

    ar = pm.math.concatenate([ar_init, ar_innov], axis=-1)
    ar = m.register_rv(ar, name="ar", observed=np.random.normal(size=(trials,)), dims=("trials",))
```

But because the `concatenate` is applied to a Scan output, PyMC cannot infer the logp:

```python
m.logp()  # RuntimeError: ... could not be derived
```
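For reference, the generative process this model encodes (initial standard-normal draws concatenated with the recursively generated innovations) can be sketched in plain NumPy; `simulate_ar2` is a hypothetical helper for illustration, not part of PyMC:

```python
import numpy as np

def simulate_ar2(rho, sigma, n_steps, rng):
    # NumPy analogue of the scan above: draw `lags` initial values from a
    # standard normal, then recurse x_t ~ Normal(rho[0]*x_{t-1} + rho[1]*x_{t-2}, sigma)
    lags = 2
    x = list(rng.standard_normal(lags))
    for _ in range(n_steps - lags):
        mu = rho[0] * x[-1] + rho[1] * x[-2]
        x.append(rng.normal(mu, sigma))
    # Init and innovations concatenated, as in the model's `ar`
    return np.array(x)

draws = simulate_ar2(rho=(0.5, 0.2), sigma=0.2, n_steps=100, rng=np.random.default_rng(0))
print(draws.shape)  # (100,)
```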
