Description
Proposal
Create a `model.compile_initial_point_fn` method that builds a compiled Aesara function computing a (transformed) initial point for every RV in the model simultaneously.
This method takes as input a dictionary of optional user choices for initvals. Each value may be a numerical value, a symbolic expression (which can depend only on upstream RVs, Deterministics, or shared variables), or one of the strings "random" or "moment", meaning, respectively, that a random draw should be taken from the RV or that a fixed moment should be extracted from it (using something like #4912).
These choices would be saved in a model dictionary such as `model._user_initval_choices` as new variables are defined.
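To make the proposed input format concrete, here is a minimal pure-Python sketch that stands in for the compiled Aesara function. It is not PyMC3 code: it assumes a hypothetical `ToyRV` with a `draw` callable and a fixed `moment`, models symbolic expressions as plain callables of the already-computed upstream values, ignores transforms, and assumes (for illustration only) that variables without a user choice default to "moment".

```python
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Union


@dataclass
class ToyRV:
    """Hypothetical stand-in for a model RV (not PyMC3 API)."""
    name: str
    draw: Callable[[random.Random], float]  # random draw from the prior
    moment: float  # a fixed moment, e.g. the mean


# A user choice: a number, a callable of upstream values, or "random"/"moment"
Choice = Union[float, Callable[[Dict[str, float]], float], str]


def compile_initial_point_fn(rvs: List[ToyRV], choices: Dict[str, Choice], seed=None):
    """Return a function that computes an initial value for every RV at once.

    `rvs` must be in topological order, so expressions only ever see upstream
    values. Variables without a user choice fall back to "moment" (an assumption).
    """
    rng = random.Random(seed)

    def initial_point_fn() -> Dict[str, float]:
        point: Dict[str, float] = {}
        for rv in rvs:
            choice = choices.get(rv.name, "moment")
            if isinstance(choice, str):
                if choice == "random":
                    point[rv.name] = rv.draw(rng)
                elif choice == "moment":
                    point[rv.name] = rv.moment
                else:
                    raise ValueError(f"unknown initval strategy {choice!r}")
            elif callable(choice):
                # "Symbolic" expression: evaluated on the upstream values only
                point[rv.name] = choice(dict(point))
            else:
                point[rv.name] = float(choice)
        return point

    return initial_point_fn
```

For example, with RVs `mu`, `sigma`, `x` (in that order) and choices `{"mu": 2.5, "x": lambda up: up["mu"] + up["sigma"]}`, calling the returned function gives `{"mu": 2.5, "sigma": 1.0, "x": 3.5}` when `sigma` has moment `1.0`.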
Finally, `model.initial_point` would look something like this:
```python
def initial_point(self, recompute=False):
    # Only recompute when explicitly asked, or when nothing is cached yet
    if recompute or self.last_computed_initval is None:
        initial_point = self.compile_initial_point_fn(self._user_initval_choices)()
        initial_point_dict = ...  # convert to dict format
        self.last_computed_initval = initial_point_dict
    return self.last_computed_initval
```
The main goal is to decouple the model definition from the sampling phases, addressing issues such as #4918.
A `recompute` flag ensures that a new initial point is not recomputed needlessly, since many places in the codebase access this property frequently. Alternatively, we could keep the old property and add a `recompute_initial_point()` method that updates `self.last_computed_initval`.
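The two caching options can be compared with a small toy class. This is a hypothetical illustration, not PyMC3 code: `_compute_initial_point` stands in for compiling and calling the initial point function, and the property is named `cached_initial_point` only so that both designs fit in one class. A counter shows that repeated reads do not trigger recomputation.

```python
import random


class ToyModel:
    """Hypothetical model illustrating both caching designs."""

    def __init__(self, seed=None):
        self._rng = random.Random(seed)
        self.last_computed_initval = None
        self.n_computations = 0  # how often the (expensive) computation ran

    def _compute_initial_point(self):
        # Stand-in for compile_initial_point_fn(...)() plus dict conversion
        self.n_computations += 1
        return {"mu": self._rng.gauss(0.0, 1.0)}

    # Design 1: a method with a `recompute` flag
    def initial_point(self, recompute=False):
        if recompute or self.last_computed_initval is None:
            self.last_computed_initval = self._compute_initial_point()
        return self.last_computed_initval

    # Design 2: an argument-free property plus an explicit refresh method
    @property
    def cached_initial_point(self):
        if self.last_computed_initval is None:
            self.recompute_initial_point()
        return self.last_computed_initval

    def recompute_initial_point(self):
        self.last_computed_initval = self._compute_initial_point()
```

With the flag, every caller has to know when to pass `recompute=True`; with the property, the frequent read path stays argument-free and invalidation becomes a separate, explicit call.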
Other benefits
This would also let us simplify `model.update_start_vals`, which is slightly redundant and potentially defective (as mentioned in #4484 (comment)): https://github.com/pymc-devs/pymc3/blob/6a75744b31f3ec015856ac6ea374fe12be8cc156/pymc3/model.py#L1548-L1556
It could become something like:
```python
def update_start_vals(self, new_initval_choices):
    # Copy the stored user choices so they are not mutated, then override them
    initval_choices = dict(self._user_initval_choices)
    initval_choices.update(new_initval_choices)
    start_vals = self.compile_initial_point_fn(initval_choices)()
    start_vals_dict = ...  # convert to dict
    return start_vals_dict
```
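The key property of this version is that the model's stored choices are never modified; new choices only override them for a single call. The sketch below checks that with a free-function variant so it runs standalone; `toy_compile` is a hypothetical stand-in for `model.compile_initial_point_fn` that only understands numbers and "moment". A shallow copy is assumed to be enough, since only top-level keys are overridden and the symbolic expressions themselves need not be duplicated.

```python
def update_start_vals(user_choices, new_choices, compile_fn):
    """Hypothetical free-function version of the proposed `update_start_vals`."""
    initval_choices = dict(user_choices)  # shallow copy: only top-level keys change
    initval_choices.update(new_choices)
    return compile_fn(initval_choices)()


def toy_compile(choices):
    """Toy stand-in: numeric choices are kept, "moment" maps to 0.0."""
    def fn():
        return {
            name: 0.0 if choice == "moment" else float(choice)
            for name, choice in choices.items()
        }
    return fn
```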
We could also revert some ugly changes to `prior_predictive_sampling` introduced by yours truly in 687f044. That hack resulted from a separate difficulty: obtaining multiple transformed initial points (to kick-start the SMC sampler). After the proposed changes, all that would be needed is something like the following:
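```python
# Request a random draw for every variable
random_initvals = {var: "random" for var in model._user_initval_choices.keys()}
initial_point_fn = model.compile_initial_point_fn(random_initvals)
# `samples` is the number of initial points needed (e.g. one per SMC particle);
# zip(*...) regroups the draws per variable
initvals = zip(*(initial_point_fn() for _ in range(samples)))
...  # convert to dict format
```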
The changes in 687f044 could then be removed altogether.
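The `zip(*...)` step regroups repeated calls into one collection of draws per variable. The issue does not specify the compiled function's return type, so this sketch assumes it returns one value per variable in a fixed order (a list), as a compiled function with several outputs would; `draw_initial_points` and `to_dict` are hypothetical helpers.

```python
import random


def draw_initial_points(initial_point_fn, samples):
    """Call `initial_point_fn` `samples` times and group the draws by variable."""
    # zip(*...) turns [(a1, b1), (a2, b2), ...] into [(a1, a2, ...), (b1, b2, ...)]
    return list(zip(*(initial_point_fn() for _ in range(samples))))


def to_dict(var_names, per_var_draws):
    """The "convert to dict format" step: one list of draws per variable name."""
    return {name: list(draws) for name, draws in zip(var_names, per_var_draws)}


# Usage: a toy random initial point function for two variables
rng = random.Random(42)
toy_fn = lambda: [rng.gauss(0.0, 1.0), rng.expovariate(1.0)]
initvals = to_dict(["mu", "sigma"], draw_initial_points(toy_fn, samples=5))
```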