
Initval refactoring #4924

Closed
@ricardoV94

Description


Proposal

Create a model.compile_initial_point_fn method that returns a compiled Aesara function which computes a (transformed) initial point for all RVs in the model simultaneously.

This method takes as input a dictionary of optional user choices for the initvals. Each value can be a numerical value, a symbolic expression (which may only depend on upstream RVs, Deterministics, or shared variables), or one of the strings "random" or "moment", meaning respectively that a random draw should be taken from the RV or that a fixed moment should be extracted from it (using something like #4912).

These choices would be stored in a model dictionary, such as model._user_initval_choices, as new variables are defined.
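To make the resolution rules concrete, here is a minimal, framework-free sketch in plain Python. Everything in it is hypothetical and is not the PyMC3 or Aesara API: `ToyModel` describes each RV by a sampling function and a fixed moment, a Python callable that receives the already computed upstream values stands in for a symbolic expression, and `compile_initial_point_fn` returns a plain closure instead of a compiled Aesara function. Transforms are ignored, and a variable missing from the choices dictionary is assumed to fall back to "moment".

```python
import random


class ToyModel:
    """Hypothetical stand-in for a model; not the PyMC3 or Aesara API."""

    def __init__(self, seed=None):
        self.rng = random.Random(seed)
        self.rvs = {}  # name -> (draw function, moment), in definition order
        self._user_initval_choices = {}

    def add_rv(self, name, draw, moment, initval="moment"):
        # Record the user's initval choice at the moment the RV is defined
        self.rvs[name] = (draw, moment)
        self._user_initval_choices[name] = initval

    def compile_initial_point_fn(self, initval_choices):
        # A plain closure plays the role of the compiled Aesara function
        def initial_point_fn():
            point = {}
            # One pass over all RVs; upstream RVs are resolved first
            for name, (draw, moment) in self.rvs.items():
                choice = initval_choices.get(name, "moment")
                if isinstance(choice, str):
                    if choice == "random":
                        point[name] = draw(self.rng)
                    elif choice == "moment":
                        point[name] = moment
                    else:
                        raise ValueError(f"Unknown initval strategy: {choice!r}")
                elif callable(choice):
                    # Stand-in for a symbolic expression of upstream values
                    point[name] = choice(point)
                else:
                    point[name] = choice  # fixed numerical value
            return point

        return initial_point_fn


model = ToyModel(seed=0)
model.add_rv("mu", lambda rng: rng.gauss(0.0, 1.0), moment=0.0)
model.add_rv("sigma", lambda rng: rng.expovariate(1.0), moment=1.0, initval=2.5)
model.add_rv("x", lambda rng: rng.gauss(0.0, 1.0), moment=0.0,
             initval=lambda point: point["mu"] + point["sigma"])
initial_point_fn = model.compile_initial_point_fn(model._user_initval_choices)
print(initial_point_fn())  # {'mu': 0.0, 'sigma': 2.5, 'x': 2.5}
```

Because all variables are resolved in one pass and in definition order, an expression for `x` can read the values already chosen for `mu` and `sigma`, which mirrors the restriction to upstream dependencies.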

Finally, model.initial_point would look something like this:

def initial_point(self, recompute=False):
    # Only compile and evaluate the function if asked to or if nothing is cached yet
    if recompute or self.last_computed_initval is None:
        initial_point = self.compile_initial_point_fn(self._user_initval_choices)()
        initial_point_dict = ...  # convert to dict format
        self.last_computed_initval = initial_point_dict
    return self.last_computed_initval
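The caching behaviour of the recompute flag can be checked with a small runnable sketch. `CachedModel`, its `n_compiles` counter and the placeholder `compile_initial_point_fn` (which just echoes the choices) are hypothetical; only the caching logic follows the method described above.

```python
class CachedModel:
    """Hypothetical model that only illustrates the caching logic."""

    def __init__(self):
        self._user_initval_choices = {"mu": 0.0}
        self.last_computed_initval = None
        self.n_compiles = 0  # counts how often the expensive path runs

    def compile_initial_point_fn(self, initval_choices):
        # Placeholder for compiling an Aesara function
        self.n_compiles += 1
        return lambda: dict(initval_choices)

    def initial_point(self, recompute=False):
        if recompute or self.last_computed_initval is None:
            fn = self.compile_initial_point_fn(self._user_initval_choices)
            self.last_computed_initval = fn()
        return self.last_computed_initval


model = CachedModel()
model.initial_point()
model.initial_point()
print(model.n_compiles)  # 1
model.initial_point(recompute=True)
print(model.n_compiles)  # 2
```

Repeated calls without `recompute=True` reuse the cached point, so the compilation cost is paid only once.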

The main goal is to decouple the model definition from the sampling phases, addressing issues like #4918.

The recompute flag ensures that the initial point is not needlessly recomputed, since many places in the codebase access it frequently. Alternatively, we could keep the existing property and add a recompute_initial_point() method that updates self.last_computed_initval.
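The alternative design could look roughly like the following sketch: a read-only `initial_point` property that returns the cached value, plus an explicit `recompute_initial_point()` method that refreshes it. The class name, the lazy computation on first access and the placeholder compile step are assumptions made for illustration.

```python
class PropertyModel:
    """Hypothetical sketch of the property-plus-method alternative."""

    def __init__(self, initval_choices):
        self._user_initval_choices = dict(initval_choices)
        self.last_computed_initval = None

    def compile_initial_point_fn(self, initval_choices):
        # Placeholder for compiling an Aesara function
        return lambda: dict(initval_choices)

    @property
    def initial_point(self):
        # Compute lazily on first access, then return the cached point
        if self.last_computed_initval is None:
            self.recompute_initial_point()
        return self.last_computed_initval

    def recompute_initial_point(self):
        fn = self.compile_initial_point_fn(self._user_initval_choices)
        self.last_computed_initval = fn()
        return self.last_computed_initval


model = PropertyModel({"mu": 0.0})
print(model.initial_point)  # {'mu': 0.0}
model._user_initval_choices["mu"] = 1.0
print(model.initial_point)  # {'mu': 0.0} (cached, not recomputed)
model.recompute_initial_point()
print(model.initial_point)  # {'mu': 1.0}
```

Compared with the `recompute` flag, this keeps attribute-style access at every existing call site and makes the (expensive) refresh an explicit, separate step.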

Other benefits

This would allow us to simplify the slightly redundant and potentially defective model.update_start_vals (mentioned in #4484 (comment)): https://github.com/pymc-devs/pymc3/blob/6a75744b31f3ec015856ac6ea374fe12be8cc156/pymc3/model.py#L1548-L1556

to something like:

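def update_start_vals(self, new_initval_choices):
    # Copy so that the stored user choices are not modified; a shallow copy
    # suffices because only top-level entries are replaced
    initval_choices = self._user_initval_choices.copy()
    initval_choices.update(new_initval_choices)
    start_vals = self.compile_initial_point_fn(initval_choices)()
    start_vals_dict = ...  # convert to dict format
    return start_vals_dict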
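A self-contained sketch of the intended `update_start_vals` semantics: the new choices override the stored ones for this call only, and the model's own `_user_initval_choices` is left untouched. `StartValsModel` is hypothetical, treats every choice as a plain number, and uses a placeholder in place of a compiled Aesara function.

```python
class StartValsModel:
    """Hypothetical model; every initval choice is a plain number here."""

    def __init__(self, initval_choices):
        self._user_initval_choices = dict(initval_choices)

    def compile_initial_point_fn(self, initval_choices):
        # Placeholder for compiling an Aesara function
        return lambda: dict(initval_choices)

    def update_start_vals(self, new_initval_choices):
        # Shallow copy: only top-level entries are replaced
        initval_choices = self._user_initval_choices.copy()
        initval_choices.update(new_initval_choices)
        return self.compile_initial_point_fn(initval_choices)()


model = StartValsModel({"mu": 0.0, "sigma": 1.0})
print(model.update_start_vals({"sigma": 3.0}))  # {'mu': 0.0, 'sigma': 3.0}
print(model._user_initval_choices)  # {'mu': 0.0, 'sigma': 1.0}
```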

We could also revert some ugly changes to prior_predictive_sampling introduced by yours truly in 687f044:

https://github.com/pymc-devs/pymc3/blob/6d2aa5ddebed01d81c2ab66b9d4bd02194f82508/pymc3/sampling.py#L1983-L1999

This hack resulted from a different difficulty: obtaining multiple transformed initial points (to kick-start the SMC sampler). After the proposed changes, all that would be needed is something like the following:

random_initvals = {var: "random" for var in model._user_initval_choices}
initial_point_fn = model.compile_initial_point_fn(random_initvals)
# Draw `samples` initial points and regroup the draws per variable
initvals = zip(*(initial_point_fn() for _ in range(samples)))
...  # convert to dict format

And the changes in 687f044 could be removed altogether.
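The `zip(*...)` idiom transposes a sequence of per-draw outputs into per-variable tuples. A runnable sketch of that step and of the final dict conversion, assuming (hypothetically) that the compiled function returns one value per variable in a fixed order given by `var_names`; a seeded `random.Random` closure stands in for the compiled function.

```python
import random

var_names = ["mu", "sigma"]  # hypothetical order of the function outputs
rng = random.Random(42)


def initial_point_fn():
    # Stand-in for a compiled function returning one random draw per variable
    return (rng.gauss(0.0, 1.0), rng.expovariate(1.0))


samples = 5
# zip(*...) turns [(mu_1, sigma_1), (mu_2, sigma_2), ...]
# into [(mu_1, mu_2, ...), (sigma_1, sigma_2, ...)]
initvals = zip(*(initial_point_fn() for _ in range(samples)))
initvals_dict = {name: list(draws) for name, draws in zip(var_names, initvals)}
print({name: len(draws) for name, draws in initvals_dict.items()})  # {'mu': 5, 'sigma': 5}
```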
