inferencedata.log_likelihood is summing observations #5236

Closed
@ricardoV94

Description

When talking with @lucianopaz I realized we completely broke log_likelihood computation in V4.

import pymc as pm
with pm.Model() as m:
    y = pm.Normal("y")
    x = pm.Normal("x", y, 1, observed=[5, 2])    
    idata = pm.sample(tune=5, draws=5, chains=2)
print(idata.log_likelihood['x'].values.shape)
# (2, 5, 1)

Whereas in V3:

import pymc3 as pm
with pm.Model() as m:
    y = pm.Normal("y")
    x = pm.Normal("x", y, 1, observed=[5, 2])    
    idata = pm.sample(tune=5, draws=5, chains=2, return_inferencedata=True)
print(idata.log_likelihood['x'].values.shape)
# (2, 5, 2)

This happened because model.logpt now returns the summed logp by default, whereas before it returned the vectorized (elementwise) logp. The change was made in 0a172c8
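A minimal NumPy sketch (illustrative only, not PyMC code) of the difference between the elementwise and summed logp for the two observations above; this is what produces the trailing dimension of 2 in V3 versus 1 in V4:

```python
import numpy as np

def normal_logp(value, mu=0.0, sigma=1.0):
    # elementwise log-density of a Normal distribution
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((value - mu) / sigma) ** 2

obs = np.array([5.0, 2.0])

elemwise = normal_logp(obs)            # shape (2,): one logp term per observation
summed = elemwise.sum(keepdims=True)   # shape (1,): the summed logp per draw
```

Stacked over chains and draws, the elementwise terms give the (2, 5, 2) shape arviz expects, while the summed version collapses the observation dimension.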

Although that is a saner default, we need to reintroduce an easy helper, logp_elemwiset (which I think is also pretty much broken right now), that calls logpt with sum=False.

Also, in this case we might want to just return the logprob terms as the dictionary items returned by aeppl.factorized_joint_logprob and let the end user decide how they want to combine them. These keys contain {value variable: logp term}. The default of calling at.add on all variables when sum=False is seldom useful (that's why we switched the default), due to potential unwanted broadcasting across variables with different dimensions.
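A small NumPy sketch of that broadcasting pitfall, with made-up shapes for two hypothetical variables:

```python
import numpy as np

# Hypothetical elementwise logp terms for two variables of different shapes
logp_x = np.zeros(3)       # e.g. a variable with 3 observations
logp_y = np.zeros((2, 1))  # e.g. a (2, 1) latent variable

# Naively adding them (what at.add would do) silently broadcasts
combined = logp_x + logp_y  # shape (2, 3): almost never what the user wants
```

Returning the per-variable terms in a dictionary sidesteps this entirely.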

One extra advantage of returning the dictionary items is that we don't need to create nearly duplicated graphs for each observed variable when computing the log-likelihood here:

cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs]

We can request it for any number of observed variables at the same time, and then compile a single function that has each variable's logp term as an output but otherwise shares the common nodes, saving on compilation, computation, and memory footprint when a model has more than one observed variable.
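A rough sketch of the idea in plain NumPy, with hypothetical names (normal_logp and pointwise_log_likelihood are illustrations, not PyMC API): one call evaluates the shared nodes once and returns a dict with one elementwise logp entry per observed variable, analogous to a compiled function with multiple outputs.

```python
import numpy as np

def normal_logp(value, mu, sigma):
    # elementwise log-density of a Normal distribution
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((value - mu) / sigma) ** 2

def pointwise_log_likelihood(point, observed):
    # Returns {var_name: elemwise logp}, computing shared quantities once
    # rather than rebuilding a near-duplicate graph per observed variable.
    mu = point["y"]  # shared node: evaluated a single time for all outputs
    return {name: normal_logp(vals, mu, 1.0) for name, vals in observed.items()}

log_like = pointwise_log_likelihood({"y": 0.0}, {"x": np.array([5.0, 2.0])})
```

The caller can then stack these per-draw dicts into the (chain, draw, obs) arrays that arviz expects.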

For instance, this nested loop would no longer be needed (pymc/pymc/backends/arviz.py, lines 276 to 282 at fe2d101):

for var, log_like_fun in cached:
    for k, chain in enumerate(trace.chains):
        log_like_chain = [
            self.log_likelihood_vals_point(point, var, log_like_fun)
            for point in trace.points([chain])
        ]
        log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)

CC @OriolAbril
