The following common "customized distributions" are missing, faulty, or implemented only for unobserved variables. It would be nice to agree on a standard API to use when all of them get implemented for both unobserved and observed variables.
- Mixture (not yet in V4, but otherwise working; a similar API via `.dist` is also present in `LKJCholeskyCov`)
- Sort (behaves a bit like the ordered transform, but would also work in ancestral sampling? See "transforms.ordered doesn't apply to samples found with .random()" #4213 and @junpenglao's thread https://twitter.com/junpenglao/status/1263927315194576898)
- Truncate (`Bound` works only for unobserved variables because it does not account for the change in the normalization constant; see "Design for bound, truncated and censored distributions" #1864)
- Censor (similar to truncation, but values beyond a bound "stick" at that bound instead of being excluded; see "Design for bound, truncated and censored distributions" #1864)
- Shift, Scale (is there a verb for affine?) (Can be done with verbose Deterministics for unobserved variables, but is not available for observed ones; see "Feature request: Add offset parameter to Exponential class (akin to mu in Laplace class)" #4507 and possibly "Choose parametrization in sampling statement" #1924)
- Other?
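To pin down what each item above would have to compute for *observed* data, here is a minimal sketch of the corresponding log-densities, written with NumPy and SciPy rather than Aesara/PyMC. The helper names (`truncated_logpdf`, `censored_logpdf`, `affine_logpdf`, `sorted_logpdf`, `mixture_logpdf`) are hypothetical, and `base` is assumed to be any frozen `scipy.stats` continuous distribution.

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln, logsumexp

# Hypothetical helpers, NOT PyMC APIs. `base` is a frozen scipy.stats distribution.


def truncated_logpdf(base, x, lower=-np.inf, upper=np.inf):
    """Truncate: renormalize by the probability mass inside [lower, upper]."""
    x = np.asarray(x, dtype=float)
    log_norm = np.log(base.cdf(upper) - base.cdf(lower))
    inside = (x >= lower) & (x <= upper)
    return np.where(inside, base.logpdf(x) - log_norm, -np.inf)


def censored_logpdf(base, x, lower=-np.inf, upper=np.inf):
    """Censor: a value sitting at a bound carries all the mass beyond that bound."""
    x = np.asarray(x, dtype=float)
    out = base.logpdf(x)
    out = np.where(x == lower, base.logcdf(lower), out)  # P(X <= lower)
    out = np.where(x == upper, base.logsf(upper), out)   # P(X >= upper)
    return np.where((x < lower) | (x > upper), -np.inf, out)


def affine_logpdf(base, y, shift=0.0, scale=1.0):
    """Shift/Scale: y = shift + scale * x, plus the Jacobian term -log|scale|."""
    y = np.asarray(y, dtype=float)
    return base.logpdf((y - shift) / scale) - np.log(np.abs(scale))


def sorted_logpdf(base, y):
    """Sort: n sorted iid draws have density n! * prod f(y_i) on y_1 <= ... <= y_n."""
    y = np.asarray(y, dtype=float)
    if np.any(np.diff(y) < 0):
        return -np.inf
    return gammaln(y.size + 1) + base.logpdf(y).sum()


def mixture_logpdf(weights, components, x):
    """Mixture: log sum_k w_k p_k(x), computed stably with logsumexp."""
    x = np.asarray(x, dtype=float)
    terms = [np.log(w) + c.logpdf(x) for w, c in zip(weights, components)]
    return logsumexp(terms, axis=0)
```

The point of the sketch is that each modification changes the normalization or support of the base density, which is exactly what a wrapper that only works for unobserved variables can skip.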
This issue is intended to discuss a possible API for implementing these custom distributions going forward. I will illustrate it with a user who wants to create an observable shifted exponential distribution, as in #4507. I will call the helper `pm.Shift`, although something like `pm.Affine` that allows both shifting and scaling would probably be better.
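For the shifted exponential of #4507, "observable" means the likelihood must be evaluated on the shifted scale. Assuming the shift is applied as $x = z + a$ with $z \sim \mathrm{Exponential}(\lambda)$, the log-likelihood of observations $x_1, \dots, x_n$ would be:

$$
\log p(x_1, \dots, x_n \mid \lambda, a) =
\begin{cases}
n \log \lambda - \lambda \sum_{i=1}^{n} (x_i - a) & \text{if } \min_i x_i \ge a, \\
-\infty & \text{otherwise,}
\end{cases}
$$

so whichever API is chosen has to pass `data - a` (not `data`) to the Exponential log-density, while forward sampling draws $z$ and returns $z + a$.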
```python
# We want to shift by a random normal variable
a = pm.Normal('a', 0, 1)

# Adapt Deterministic syntax. Requires naming the raw variable and is probably
# incompatible with the current RV registration logic
x = pm.Shift('x', pm.Exponential('x_raw', 1), shift=a, observed=data)

# Adapt the syntax used in pm.Mixture. Requires the non-intuitive .dist() call
x = pm.Shift('x', pm.Exponential.dist(1), shift=a, observed=data)

# Adapt the syntax from pm.Bound. Not very intuitive. Also, the distribution and
# transformation parameters are separated from the class names, and it is not
# obvious whether data is expected on the original or the shifted scale
# (it should be on the shifted scale)
x = pm.Shift(pm.Exponential, shift=a)('x', 1, observed=data)

# Add a generic "modify" argument
x = pm.Exponential('x', 1, modify=pm.Shift(shift=a), observed=data)

# Add all necessary arguments to all distributions:
#   sort = bool
#   lower / upper = float (for truncation / censoring)
#   shift / scale = float
x = pm.Exponential('x', 1, shift=a, observed=data)

# Add new methods to RVs and (optionally) separate initialization from conditioning,
# which removes the ambiguity about whether data should be on the shifted scale (it should)
x = pm.Exponential('x', 1).shift(a)
x.observe(data)

# Same logic as above, but with operator overloading.
# Conflicts with the normal Aesara operators
x = pm.Exponential('x', 1) + a
x.observe(data)

# Other examples of (impractical) operator overloading
x = pm.Exponential('x', 1, size=2)
new_x = x + a  # Shifting
new_x = x < a  # Truncating
new_x = x[x > a] = a  # Censoring
new_x = x[(None, a)]  # Truncating alternative
new_x = x[[None, a]]  # Censoring alternative
new_x = x[0] < x[1]  # Sorting, more generally x[:-1] < x[1:]
new_x.observe(data)

# Something else?
```
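The method-chaining proposal (`.shift(a)` followed by `.observe(data)`) can be prototyped outside PyMC to check that it removes the scale ambiguity. The sketch below assumes a hypothetical `ToyRV` class (not part of PyMC, and not built on Aesara) whose modifiers wrap both the sampler and the log-density, so forward draws and observed data always refer to the modified variable.

```python
import numpy as np
from scipy import stats


class ToyRV:
    """Hypothetical chainable RV, NOT a PyMC class.

    Each modifier wraps both the sampler and the log-density, so that
    forward draws and observed data always refer to the modified variable.
    """

    def __init__(self, logpdf, draw):
        self._logpdf = logpdf  # values -> elementwise log-density
        self._draw = draw      # (size, rng) -> samples
        self.data = None

    @classmethod
    def from_scipy(cls, dist):
        # Wrap a frozen scipy.stats distribution
        return cls(dist.logpdf,
                   lambda size, rng: dist.rvs(size=size, random_state=rng))

    def shift(self, a):
        # y = x + a: sample then add a; evaluate the inner density at y - a
        return ToyRV(lambda y: self._logpdf(np.asarray(y) - a),
                     lambda size, rng: self._draw(size, rng) + a)

    def scale(self, s):
        # y = s * x: evaluate the inner density at y / s, plus Jacobian -log|s|
        return ToyRV(lambda y: self._logpdf(np.asarray(y) / s) - np.log(abs(s)),
                     lambda size, rng: self._draw(size, rng) * s)

    def observe(self, data):
        # Data is interpreted on the scale of the modified variable
        self.data = np.asarray(data, dtype=float)
        return self

    def logp(self):
        return float(np.sum(self._logpdf(self.data)))

    def random(self, size, rng):
        return self._draw(size, rng)


# Usage mirroring `x = pm.Exponential('x', 1).shift(a); x.observe(data)` with a = 3
x = ToyRV.from_scipy(stats.expon()).shift(3.0)
x.observe([3.5, 4.0, 10.0])
print(x.logp())  # -8.5
```

One design consequence this makes visible: modifiers do not commute. `.scale(2.0).shift(3.0)` gives y = 2x + 3, whereas `.shift(3.0).scale(2.0)` gives y = 2x + 6, so a chained API would need to state that modifiers apply in call order.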