Skip to content

MarginalApprox doesn't allow non-constant covariance parameters or inducing point locations in v4 #5922

Closed
@quantheory

Description

@quantheory

Description of your problem

The following code, which worked in PyMC3 (and still works in v4 if Marginal is used instead of MarginalApprox), no longer runs in v4:

x = np.linspace(0., 1., 10)
xu = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    sigma_gp = pm.HalfNormal('sigma_gp', sigma=1.)
    l = pm.HalfNormal('l', sigma=0.1)
    cov = sigma_gp**2 * pm.gp.cov.Matern32(1, ls=[l])
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu[:,None])
    maxpost = pm.find_MAP()
    print(maxpost)
Complete error traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_33029/329055196.py in <module>
     11     gp.marginal_likelihood('like', X=x[:,None], y=y,
     12                            noise=sigma_noise, Xu=xu[:,None])
---> 13     maxpost = pm.find_MAP()
     14     print(maxpost)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/tuning/starting.py in find_MAP(start, vars, method, return_raw, include_transformed, progressbar, maxeval, model, seed, *args, **kwargs)
    109     )
    110     start = ipfn(seed)
--> 111     model.check_start_vals(start)
    112 
    113     var_names = {var.name for var in vars}

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in check_start_vals(self, start)
   1785                 )
   1786 
-> 1787             initial_eval = self.point_logps(point=elem)
   1788 
   1789             if not all(np.isfinite(v) for v in initial_eval.values()):

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in point_logps(self, point, round_vals)
   1821 
   1822         factors = self.basic_RVs + self.potentials
-> 1823         factor_logps_fn = [at.sum(factor) for factor in self.logp(factors, sum=False)]
   1824         return {
   1825             factor.name: np.round(np.asarray(factor_logp), round_vals)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in logp(self, vars, jacobian, sum)
    751         rv_logps: List[TensorVariable] = []
    752         if rv_values:
--> 753             rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    754             assert isinstance(rv_logps, list)
    755 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    255     ]
    256     if unexpected_rv_nodes:
--> 257         raise ValueError(
    258             f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
    259             "This can happen when DensityDist logp or Interval transform functions "

ValueError: Random variables detected in the logp graph: [l, sigma_gp].
This can happen when DensityDist logp or Interval transform functions reference nonlocal variables.

This is probably related, at least in part, to the DensityDist changes and to #5024. Note that the following case, in which sigma_noise is the only random variable, works fine:

x = np.linspace(0., 1., 10)
xu = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    cov = pm.gp.cov.Matern32(1, ls=[1.])
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu[:,None])
    maxpost = pm.find_MAP()
    print(maxpost)

yielding

{'sigma_noise_log__': array(-2.8405184), 'sigma_noise': array(0.05839539)}

I suspect this is why the error was not caught earlier: with constant covariance parameters, no random variables appear inside the covariance function. MarginalApprox needs a test that uses non-constant length scales.

There's a completely different error if the inducing point locations are non-constant. Example code:

x = np.linspace(0., 1., 10)
xu_init = np.linspace(0., 1., 5)
y = np.sin(x)

with pm.Model():
    cov = pm.gp.cov.Matern32(1, ls=[1.])
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='VFE')
    sigma_noise = pm.HalfNormal('sigma_noise', sigma=1.)
    xu = pm.Flat("xu", shape=(5, 1), initval=xu_init[:,None])
    gp.marginal_likelihood('like', X=x[:,None], y=y,
                           noise=sigma_noise, Xu=xu)
    maxpost = pm.find_MAP()
    print(maxpost)
Complete error traceback
---------------------------------------------------------------------------
MissingInputError                         Traceback (most recent call last)
/tmp/ipykernel_33029/1554036267.py in <module>
     10     gp.marginal_likelihood('like', X=x[:,None], y=y,
     11                            noise=sigma_noise, Xu=xu)
---> 12     maxpost = pm.find_MAP()
     13     print(maxpost)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/tuning/starting.py in find_MAP(start, vars, method, return_raw, include_transformed, progressbar, maxeval, model, seed, *args, **kwargs)
    109     )
    110     start = ipfn(seed)
--> 111     model.check_start_vals(start)
    112 
    113     var_names = {var.name for var in vars}

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in check_start_vals(self, start)
   1785                 )
   1786 
-> 1787             initial_eval = self.point_logps(point=elem)
   1788 
   1789             if not all(np.isfinite(v) for v in initial_eval.values()):

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in point_logps(self, point, round_vals)
   1821 
   1822         factors = self.basic_RVs + self.potentials
-> 1823         factor_logps_fn = [at.sum(factor) for factor in self.logp(factors, sum=False)]
   1824         return {
   1825             factor.name: np.round(np.asarray(factor_logp), round_vals)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py in logp(self, vars, jacobian, sum)
    751         rv_logps: List[TensorVariable] = []
    752         if rv_values:
--> 753             rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    754             assert isinstance(rv_logps, list)
    755 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    233 
    234     transform_opt = TransformValuesOpt(transform_map)
--> 235     temp_logp_var_dict = factorized_joint_logprob(
    236         tmp_rvs_to_values,
    237         extra_rewrites=transform_opt,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aeppl/joint_logprob.py in factorized_joint_logprob(rv_values, warn_missing_rvs, extra_rewrites, **kwargs)
    145         q_rv_inputs = remapped_vars[len(q_value_vars) :]
    146 
--> 147         q_logprob_vars = _logprob(
    148             node.op,
    149             q_value_vars,

~/anaconda3/envs/pymc3/lib/python3.9/functools.py in wrapper(*args, **kw)
    875                             '1 positional argument')
    876 
--> 877         return dispatch(args[0].__class__)(*args, **kw)
    878 
    879     funcname = getattr(func, '__name__', 'singledispatch function')

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/distribution.py in density_dist_logp(op, value_var_list, *dist_params, **kwargs)
    795             _dist_params = dist_params[3:]
    796             value_var = value_var_list[0]
--> 797             return logp(value_var, *_dist_params)
    798 
    799         @_logcdf.register(rv_type)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/gp.py in _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter)
    685     def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
    686         sigma2 = at.square(sigma)
--> 687         Kuu = self.cov_func(Xu)
    688         Kuf = self.cov_func(Xu, X)
    689         Luu = cholesky(stabilize(Kuu, jitter))

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/cov.py in __call__(self, X, Xs, diag)
     84             return self.diag(X)
     85         else:
---> 86             return self.full(X, Xs)
     87 
     88     def diag(self, X):

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/cov.py in full(self, X, Xs)
    507 
    508     def full(self, X, Xs=None):
--> 509         X, Xs = self._slice(X, Xs)
    510         r = self.euclidean_dist(X, Xs)
    511         return (1.0 + np.sqrt(3.0) * r) * at.exp(-np.sqrt(3.0) * r)

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/gp/cov.py in _slice(self, X, Xs)
     95         xdims = X.shape[-1]
     96         if isinstance(xdims, Variable):
---> 97             xdims = xdims.eval()
     98         if self.input_dim != xdims:
     99             warnings.warn(

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/basic.py in eval(self, inputs_to_values)
    597         inputs = tuple(sorted(inputs_to_values.keys(), key=id))
    598         if inputs not in self._fn_cache:
--> 599             self._fn_cache[inputs] = function(inputs, self)
    600         args = [inputs_to_values[param] for param in inputs]
    601 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/__init__.py in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    315         # note: pfunc will also call orig_function -- orig_function is
    316         #      a choke point that all compilation must pass through
--> 317         fn = pfunc(
    318             params=inputs,
    319             outputs=outputs,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/pfunc.py in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    372     )
    373 
--> 374     return orig_function(
    375         inputs,
    376         cloned_outputs,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/types.py in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1749     try:
   1750         Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751         m = Maker(
   1752             inputs,
   1753             outputs,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/types.py in __init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
   1507         indices = [[input, None, [input]] for input in inputs]
   1508 
-> 1509         fgraph, found_updates = std_fgraph(
   1510             inputs, outputs, accept_inplace, fgraph=fgraph
   1511         )

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/compile/function/types.py in std_fgraph(input_specs, output_specs, accept_inplace, fgraph, features, force_clone)
    228             clone |= spec.variable.owner is not None
    229 
--> 230         fgraph = FunctionGraph(
    231             input_vars,
    232             [spec.variable for spec in output_specs] + updates,

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in __init__(self, inputs, outputs, features, clone, update_mapping, **clone_kwds)
    151 
    152         for output in outputs:
--> 153             self.add_output(output, reason="init")
    154 
    155         self.profile = None

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in add_output(self, var, reason, import_missing)
    161         """Add a new variable as an output to this `FunctionGraph`."""
    162         self.outputs.append(var)
--> 163         self.import_var(var, reason=reason, import_missing=import_missing)
    164         self.clients[var].append(("output", len(self.outputs) - 1))
    165 

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in import_var(self, var, reason, import_missing)
    302         # Imports the owners of the variables
    303         if var.owner and var.owner not in self.apply_nodes:
--> 304             self.import_node(var.owner, reason=reason, import_missing=import_missing)
    305         elif (
    306             var.owner is None

~/anaconda3/envs/pymc3/lib/python3.9/site-packages/aesara/graph/fg.py in import_node(self, apply_node, check, reason, import_missing)
    367                                 "for more information on this error."
    368                             )
--> 369                             raise MissingInputError(error_msg, variable=var)
    370 
    371         for node in new_nodes:

MissingInputError: Input 0 (xu) of the graph (indices start from 0), used to compute Shape(xu), was not provided and not given a value. Use the Aesara flag exception_verbosity='high', for more information on this error.
 
Backtrace when that variable is created:

  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3169, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3361, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_33029/1554036267.py", line 9, in <module>
    xu = pm.Flat("xu", shape=(5, 1), initval=xu_init[:,None])
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/continuous.py", line 364, in __new__
    return super().__new__(cls, *args, **kwargs)
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/distributions/distribution.py", line 271, in __new__
    rv_out = model.register_rv(
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py", line 1359, in register_rv
    self.create_value_var(rv_var, transform)
  File "/home/spsantos/anaconda3/envs/pymc3/lib/python3.9/site-packages/pymc/model.py", line 1509, in create_value_var
    value_var = rv_var.type()

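I don't fully understand this error, because it isn't clear to me when or how shape information is propagated through the Aesara variables or the PyMC wrappers. The traceback shows that it is raised when `Covariance._slice` calls `xdims.eval()` on the symbolic shape of `Xu`, which is a free random variable with no value supplied. I'm therefore not sure whether this error has the same underlying cause as the first one.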

This regression is rather disappointing: being unable to tune either the hyperparameters or the inducing points makes MarginalApprox much less useful.

Versions and main components

  • PyMC/PyMC3 Version: 4.0.1
  • Aesara/Theano Version: 2.7.3
  • Python Version: 3.9.5
  • Operating system: Linux
  • How did you install PyMC/PyMC3: conda

Metadata

Metadata

Assignees

No one assigned

    Labels

    GP (Gaussian Process), bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions