
Consistent API for user-friendly distribution customization #4530

Closed
@ricardoV94


Several common "customized distributions" (shifted/scaled, truncated, censored, and sorted variables, as in the examples below) are missing, faulty, or only implemented for unobserved variables. It would be nice to have a standard API to follow when all of these get implemented for both unobserved and observed distributions.

This issue is intended to discuss a possible API going forward for implementing these custom distributions. I will illustrate with the example of a user who wants to create an observed shifted exponential distribution, as in #4507. I will call this helper method pm.Shift, but something like pm.Affine, which allows both shifting and scaling, would probably be better.
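To make the shifted example concrete: if y = loc + scale * x with scale > 0, the change-of-variables formula gives logp_y(y) = logp_x((y - loc) / scale) - log(scale). A shifted exponential therefore only has support above the shift, and its likelihood has to be evaluated on the shifted scale. Whatever form pm.Shift or pm.Affine takes, its log-probability would presumably reduce to something equivalent. The plain-Python sketch below illustrates this; `exponential_logp`, `affine_logp`, and `shifted_exponential_logp` are hypothetical helpers for illustration, not part of the PyMC API.

```python
import math


def exponential_logp(x: float, lam: float) -> float:
    """Log-density of Exponential(lam), rate parameterization."""
    return math.log(lam) - lam * x if x >= 0 else -math.inf


def affine_logp(base_logp, y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log-density of y = loc + scale * x, where x has log-density base_logp.

    Change of variables: logp_y(y) = logp_x((y - loc) / scale) - log(scale).
    Only positive scales are handled in this sketch.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    return base_logp((y - loc) / scale) - math.log(scale)


def shifted_exponential_logp(y: float, lam: float, shift: float) -> float:
    """Shifted exponential: y = x + shift with x ~ Exponential(lam).

    The value y is on the shifted scale, like the observed data in the proposals.
    """
    return affine_logp(lambda x: exponential_logp(x, lam), y, loc=shift)


print(shifted_exponential_logp(2.5, 1.0, 2.0))  # log(1) - 0.5 = -0.5
print(shifted_exponential_logp(1.5, 1.0, 2.0))  # below the shift, outside the support: -inf
```

A negative scale would need the absolute value of the Jacobian term and is left out here for brevity.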

```python
# We want to shift by a random normal variable
a = pm.Normal('a', 0, 1)

# Adapt the pm.Deterministic syntax. Requires naming the raw variable and is probably
# incompatible with the current RV registration logic
x = pm.Shift('x', pm.Exponential('x_raw', 1), shift=a, observed=data)

# Adapt the syntax used in pm.Mixture. Requires the non-intuitive .dist() call
x = pm.Shift('x', pm.Exponential.dist(1), shift=a, observed=data)

# Adapt the syntax from pm.Bound. Not very intuitive: the distribution and transformation
# parameters are separated from the class names, and it is not obvious whether the data
# is expected on the original or the shifted scale (it should be on the shifted scale)
x = pm.Shift(pm.Exponential, shift=a)('x', 1, observed=data)

# Add a generic "modify" argument
x = pm.Exponential('x', 1, modify=pm.Shift(shift=a), observed=data)

# Add all the necessary arguments to all distributions
# sort = bool
# lower / upper = float (for truncation / censoring)
# shift / scale = float
x = pm.Exponential('x', 1, shift=a, observed=data)

# Add new methods to RVs and (optionally) separate initialization from conditioning
# to remove the ambiguity about whether the data should be on the shifted scale (it should be)
x = pm.Exponential('x', 1).shift(a)
x.observe(data)

# Same logic as above but with operator overloading
# Conflicts with the normal Aesara operators
x = pm.Exponential('x', 1) + a
x.observe(data)

# Other examples of (impractical) operator overloading
x = pm.Exponential('x', 1, size=2)
new_x = x + a  # Shifting
new_x = x < a  # Truncating
new_x = x[x > a] = a  # Censoring
new_x = x[(None, a)]  # Truncating alternative
new_x = x[[None, a]]  # Censoring alternative
new_x = x[0] < x[1]  # Sorting, more generally x[:-1] < x[1:]
new_x.observe(data)

# Something else?
```
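Whichever syntax is chosen, the lower / upper arguments correspond to two different likelihoods. A truncated variable has zero density outside [lower, upper] and is renormalized by the probability mass inside the bounds. A censored variable keeps the plain density inside the bounds and assigns the tail mass P(X <= lower) or P(X >= upper) to an observation sitting exactly on a bound. The minimal plain-Python sketch below shows both for an Exponential(lam) base; the helper names are hypothetical and not PyMC API.

```python
import math


def exponential_logp(x: float, lam: float) -> float:
    """Log-density of Exponential(lam), rate parameterization."""
    return math.log(lam) - lam * x if x >= 0 else -math.inf


def exponential_cdf(x: float, lam: float) -> float:
    """CDF of Exponential(lam); exponential_cdf(math.inf, lam) == 1.0."""
    return 1.0 - math.exp(-lam * x) if x >= 0 else 0.0


def truncated_logp(x: float, lam: float, lower: float = 0.0, upper: float = math.inf) -> float:
    """Truncation: zero density outside [lower, upper], renormalized inside."""
    if x < lower or x > upper:
        return -math.inf
    mass = exponential_cdf(upper, lam) - exponential_cdf(lower, lam)
    return exponential_logp(x, lam) - math.log(mass)


def censored_logp(x: float, lam: float, lower: float = 0.0, upper: float = math.inf) -> float:
    """Censoring: values beyond a bound are recorded as the bound itself."""
    if x < lower or x > upper:
        return -math.inf
    below = exponential_cdf(lower, lam)  # P(X <= lower); 0 when lower is at the support edge
    if x == lower and below > 0:
        return math.log(below)
    if x == upper and math.isfinite(upper):
        return math.log(1.0 - exponential_cdf(upper, lam))  # P(X >= upper)
    return exponential_logp(x, lam)


print(truncated_logp(3.0, 1.0, upper=2.0))  # outside the bounds: -inf
print(censored_logp(1.0, 1.0, upper=2.0))   # interior point, plain density: -1.0
print(censored_logp(2.0, 1.0, upper=2.0))   # on the bound: log P(X >= 2), about -2
```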
