Update clone_replace strict keyword name #5849


Merged: 3 commits into pymc-devs:main, Jun 17, 2022

Conversation

brandonwillard (Contributor)

This PR updates calls to clone_replace that use the strict keyword to use the rebuild_strict keyword instead.
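The keyword rename can be sketched as a small compatibility shim. This is illustrative only, not the PR's actual diff: `clone_replace` below is a dummy stand-in for `aesara.graph.basic.clone_replace`, and `call_clone_replace` is a hypothetical helper; only the old/new keyword names (`strict` / `rebuild_strict`) come from the PR description.

```python
import inspect

def call_clone_replace(clone_replace, outputs, replace):
    """Call clone_replace with whichever strictness keyword its signature accepts.

    Newer Aesara releases renamed ``strict`` to ``rebuild_strict``; inspecting
    the signature lets calling code work with either version.
    """
    params = inspect.signature(clone_replace).parameters
    kw = "rebuild_strict" if "rebuild_strict" in params else "strict"
    return clone_replace(outputs, replace=replace, **{kw: True})

# Dummy stand-in using the *new* keyword name, so the shim picks rebuild_strict:
def clone_replace(outputs, replace=None, rebuild_strict=True):
    return outputs

print(call_clone_replace(clone_replace, [1, 2], {}))
```

In the PR itself no shim is needed, since the minimum Aesara version is pinned; the call sites simply switch to the new keyword.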

ricardoV94 (Member)

You might need to update this line for the pre-commit:

- aesara==2.6.6

@brandonwillard brandonwillard force-pushed the fix-clone_replace-calls branch from b9fb0b2 to 1145527 on June 3, 2022 15:43
@brandonwillard brandonwillard removed the bug label Jun 3, 2022
codecov bot commented Jun 3, 2022

Codecov Report

Merging #5849 (58ce5ce) into main (938604c) will increase coverage by 3.39%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main    #5849      +/-   ##
==========================================
+ Coverage   86.10%   89.49%   +3.39%     
==========================================
  Files          73       73              
  Lines       13225    13267      +42     
==========================================
+ Hits        11387    11873     +486     
+ Misses       1838     1394     -444     
Impacted Files Coverage Δ
pymc/aesaraf.py 91.95% <100.00%> (ø)
pymc/data.py 81.63% <100.00%> (ø)
pymc/smc/smc.py 96.45% <100.00%> (ø)
pymc/step_methods/compound.py 86.66% <0.00%> (-6.67%) ⬇️
pymc/step_methods/hmc/base_hmc.py 89.68% <0.00%> (-0.80%) ⬇️
pymc/parallel_sampling.py 85.80% <0.00%> (-0.67%) ⬇️
pymc/backends/ndarray.py 79.46% <0.00%> (-0.19%) ⬇️
pymc/distributions/__init__.py 100.00% <0.00%> (ø)
pymc/distributions/logprob.py 97.72% <0.00%> (+0.07%) ⬆️
pymc/variational/approximations.py 86.69% <0.00%> (+0.18%) ⬆️
... and 8 more

brandonwillard (Contributor, Author)

The current failures are fixed by aesara-devs/aesara#976.

@brandonwillard brandonwillard force-pushed the fix-clone_replace-calls branch from 1145527 to fec2ba8 on June 4, 2022 12:13
twiecki (Member) commented Jun 6, 2022


pymc/tests/test_distributions_timeseries.py:460: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
pymc/tests/test_distributions_moments.py:157: in assert_moment_is_expected
    random_draw = model["x"].eval()
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/graph/basic.py:590: in eval
    self._fn_cache[inputs] = function(inputs, self)
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/compile/function/__init__.py:330: in function
    output_keys=output_keys,
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/compile/function/pfunc.py:383: in pfunc
    fgraph=fgraph,
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/compile/function/types.py:1760: in orig_function
    fgraph=fgraph,
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/compile/function/types.py:1522: in __init__
    inputs, outputs, found_updates, fgraph, optimizer, linker, profile
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/compile/function/types.py:1411: in prepare_fgraph
    optimizer_profile = optimizer(fgraph)
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/graph/opt.py:111: in __call__
    return self.optimize(fgraph)
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/graph/opt.py:102: in optimize
    ret = self.apply(fgraph, *args, **kwargs)
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/graph/opt.py:290: in apply
    self.failure_callback(e, self, optimizer)
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/graph/opt.py:225: in warn
    raise exc
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/graph/opt.py:279: in apply
    sub_prof = optimizer.apply(fgraph)
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/scan/opt.py:1083: in apply
    fgraph, scan_nodes[scan_idx], out_indices, alloc_ops
/usr/share/miniconda3/envs/pymc-test-py37/lib/python3.7/site-packages/aesara/scan/opt.py:1009: in attempt_scan_inplace
    allow_gc=op.allow_gc,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = <aesara.scan.op.Scan object at 0x7f58a7767950>
inputs = [NominalTensorVariable(0, TensorType(float64, (None,))), NominalTensorVariable(1, TensorType(float64, (None,))), Nomin...rType), NominalTensorVariable(3, TensorType(float64, (None, None))), NominalTensorVariable(4, TensorType(float64, ()))]
outputs = [normal_rv{0, (0, 0), floatX, False}.out, normal_rv{0, (0, 0), floatX, False}.0]
info = ScanInfo(n_seqs=0, mit_mot_in_slices=(), mit_mot_out_slices=(), mit_sot_in_slices=((-2, -1),), sit_sot_in_slices=(), n_nit_sot=0, n_shared_outs=1, n_non_seqs=2, as_while=False)
mode = <aesara.compile.mode.Mode object at 0x7f58b1eb27d0>
typeConstructor = <function Scan.__init__.<locals>.tensorConstructor at 0x7f58a7d64c20>
truncate_gradient = -1, name = 'scan_fn', as_while = False, profile = False
allow_gc = False, strict = True
    def __init__(
        self,
        inputs: List[Variable],
        outputs: List[Variable],
        info: ScanInfo,
        mode: Optional[Mode] = None,
        typeConstructor: Optional[TensorConstructorType] = None,
        truncate_gradient: int = -1,
        name: Optional[str] = None,
        as_while: bool = False,
        profile: Optional[Union[str, bool]] = None,
        allow_gc: bool = True,
        strict: bool = True,
    ):
        r"""
        Parameters
        ----------
        inputs
            Inputs of the inner function of `Scan`.
            These take the following general form:
                sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences
            where each term is a list of `Variable`\s.
        outputs
            Outputs of the inner function of `Scan`.
            These take the following general form:
                mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs [+ while-condition]
            where each term is a list of `Variable`\s.
        info
            A collection of information about the sequences and taps.
        mode
            The mode used to compile the inner-graph.
            If you prefer the computations of one step of `scan` to be done
            differently then the entire function, you can use this parameter to
            describe how the computations in this loop are done (see
            `aesara.function` for details about possible values and their meaning).
        typeConstructor
            Function that constructs a `TensorType` for the outputs.
        truncate_gradient
            `truncate_gradient` is the number of steps to use in truncated
            back-propagation through time (BPTT).  If you compute gradients through
            a `Scan` `Op`, they are computed using BPTT. By providing a different
            value then ``-1``, you choose to use truncated BPTT instead of classical
            BPTT, where you go for only `truncate_gradient` number of steps back in
            time.
        name
            When profiling `scan`, it is helpful to provide a name for any
            instance of `scan`.
            For example, the profiler will produce an overall profile of your code
            as well as profiles for the computation of one step of each instance of
            `Scan`. The `name` of the instance appears in those profiles and can
            greatly help to disambiguate information.
        profile
            If ``True`` or a non-empty string, a profile object will be created and
            attached to the inner graph of `Scan`. When `profile` is ``True``, the
            profiler results will use the name of the `Scan` instance, otherwise it
            will use the passed string.  The profiler only collects and prints
            information when running the inner graph with the `CVM` `Linker`.
        allow_gc
            Set the value of `allow_gc` for the internal graph of the `Scan`.  If
            set to ``None``, this will use the value of
            `aesara.config.scan__allow_gc`.
            The full `Scan` behavior related to allocation is determined by this
            value and the flag `aesara.config.allow_gc`. If the flag
            `allow_gc` is ``True`` (default) and this `allow_gc` is ``False``
            (default), then we let `Scan` allocate all intermediate memory
            on the first iteration, and they are not garbage collected
            after that first iteration; this is determined by `allow_gc`. This can
            speed up allocation of the subsequent iterations. All those temporary
            allocations are freed at the end of all iterations; this is what the
            flag `aesara.config.allow_gc` means.
        strict
            If ``True``, all the shared variables used in the inner-graph must be provided.
        """
        inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
        input_replacements = []
        for n, v in enumerate(inputs):
            if not isinstance(v, (SharedVariable, Constant)):
                input_replacements.append((v, NominalVariable(n, v.type)))
            assert not isinstance(v, NominalVariable)
        outputs = clone_replace(outputs, replace=input_replacements)
        if input_replacements:
            _, inputs_ = zip(*input_replacements)
            inputs = list(inputs_)
        else:
            inputs = []
        self.info = info
        self.truncate_gradient = truncate_gradient
        self.name = name
        self.profile = profile
        self.allow_gc = allow_gc
        self.strict = strict
        # Clone mode_instance, altering "allow_gc" for the linker,
        # and adding a message if we profile
        if self.name:
            message = self.name + " sub profile"
        else:
            message = "Scan sub profile"
        self.mode = get_default_mode() if mode is None else mode
        self.mode_instance = get_mode(self.mode).clone(
            link_kwargs=dict(allow_gc=self.allow_gc), message=message
        )
        # build a list of output types for any Apply node using this op.
        self.output_types = []
        def tensorConstructor(shape, dtype):
            return TensorType(dtype=dtype, shape=shape)
        if typeConstructor is None:
            typeConstructor = tensorConstructor
        idx = 0
        jdx = 0
        while idx < info.n_mit_mot_outs:
            # Not that for mit_mot there are several output slices per
            # output sequence
            o = outputs[idx]
            self.output_types.append(
                typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
            )
            idx += len(info.mit_mot_out_slices[jdx])
            jdx += 1
        # mit_sot / sit_sot / nit_sot
        end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
        for o in outputs[idx:end]:
            self.output_types.append(
                typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
            )
        # shared outputs + possibly the ending condition
        for o in outputs[end:]:
            self.output_types.append(o.type)
        if info.as_while:
            self.output_types = self.output_types[:-1]
        if not hasattr(self, "name") or self.name is None:
            self.name = "scan_fn"
        # Pre-computing some values to speed up perform
        self.mintaps = [
            min(x)
            for x in chain(
                info.mit_mot_in_slices, info.mit_sot_in_slices, info.sit_sot_in_slices
            )
        ]
        self.mintaps += [0 for x in range(info.n_nit_sot)]
        self.seqs_arg_offset = 1 + info.n_seqs
        self.shared_arg_offset = (
            self.seqs_arg_offset + info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
        )
        self.nit_sot_arg_offset = self.shared_arg_offset + info.n_shared_outs
        # XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
        # of the number of outputs generated by taps with inputs
        self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
        self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
        # TODO: These can be moved to thunk/function compilation
        (
            _,
            self.mitmots_preallocated,
        ) = self._mitmot_preallocations()
        self.n_outer_inputs = info.n_outer_inputs
        self.n_outer_outputs = info.n_outer_outputs
        self.fgraph = FunctionGraph(inputs, outputs, clone=False)
        _ = self.prepare_fgraph(self.fgraph)
        if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
            raise InconsistencyError(
>               "Inner-graphs must not contain in-place operations."
            )
E           aesara.graph.utils.InconsistencyError: Inner-graphs must not contain in-place operations.

Any idea what this is about?

brandonwillard (Contributor, Author)

Any idea what this is about?

I'm looking into it now; so far, it appears to be another compiled inner-graph issue involving OpFromGraph.

brandonwillard (Contributor, Author)

OK, this error is a special one. The graph being compiled is an OpFromGraph that contains a Scan (i.e. an AutoRegressiveRV). That Scan's inner-graph is optimized, compiled, and cached by the make_initial_point_fn call, which introduces some in-place Ops into the Scan node's inner-graph. When Variable.eval is called later, the cached, already-optimized graph is optimized again, and that second pass is what raises the error.
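A rough, Aesara-free analogy of this failure mode (all names here, such as compile_for_initial_point, are hypothetical; this only illustrates why mutating a cached inner graph in place breaks a later compilation pass that forbids in-place ops):

```python
class InnerGraph:
    """Toy stand-in for a Scan node's inner-graph."""
    def __init__(self):
        self.ops = ["normal_rv"]  # no in-place ops yet

_cache = {}

def compile_for_initial_point(key, graph):
    # First compilation: an optimizer rewrites ops *in place* and the
    # (now mutated) graph is cached for reuse.
    graph.ops = [op + "_inplace" for op in graph.ops]
    _cache[key] = graph
    return graph

def compile_for_eval(graph):
    # Second compilation (analogous to Variable.eval): this pass validates
    # the inner graph and rejects in-place operations outright.
    if any(op.endswith("_inplace") for op in graph.ops):
        raise RuntimeError("Inner-graphs must not contain in-place operations.")

g = InnerGraph()
compile_for_initial_point("x", g)
try:
    compile_for_eval(_cache["x"])  # fails: the cached graph was mutated
except RuntimeError as e:
    print(e)
```

The fix directions discussed below (removing inner-graph statefulness, or the aesara-devs/aesara#993 change) amount to ensuring the second pass never sees a graph mutated by the first.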

twiecki (Member) commented Jun 8, 2022

@brandonwillard I understood about 20% of that. Is it difficult to fix? Is the fix on the Aesara or PyMC side?

brandonwillard (Contributor, Author)

@brandonwillard I understood about 20% of that. Is it difficult to fix? Is the fix on the Aesara or PyMC side?

I don't think a fix would be all that involved. For instance, if we completely remove the "statefulness" of the Ops that contain inner-graphs (i.e. OpFromGraph and/or Scan), then we can easily avoid such issues.

I almost have a good MWE (minimal working example) from which to work.

ricardoV94 (Member)

Is the error related to this merged PR aesara-devs/aesara#993?

brandonwillard (Contributor, Author) commented Jun 14, 2022

Is the error related to this merged PR aesara-devs/aesara#993?

It is technically fixed by that PR, but the underlying problem is not completely addressed by it. It should be fine to move forward with that as a fix in the meantime, though.

@brandonwillard brandonwillard force-pushed the fix-clone_replace-calls branch from fec2ba8 to 860b713 on June 14, 2022 21:16
This is to avoid issues in upcoming Aesara releases that require the argument.
ricardoV94 (Member)

Read the Docs is failing with: "Could not import extension jupyter_sphinx (exception: No module named 'jupyter_sphinx')"

CC @OriolAbril

OriolAbril (Member)

see #5845

@ricardoV94 ricardoV94 merged commit 3bf95ce into pymc-devs:main Jun 17, 2022
@brandonwillard brandonwillard deleted the fix-clone_replace-calls branch June 20, 2022 00:51