Skip to content

Commit 230a808

Browse files
committed
Add functional vectorize helper to pytensor.tensor module
1 parent 2778160 commit 230a808

File tree

6 files changed

+252
-42
lines changed

6 files changed

+252
-42
lines changed

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
148148
from pytensor.tensor.type import * # noqa
149149
from pytensor.tensor.type_other import * # noqa
150150
from pytensor.tensor.variable import TensorConstant, TensorVariable # noqa
151+
from pytensor.tensor.functional import vectorize # noqa
151152

152153
# Allow accessing numpy constants from pytensor.tensor
153154
from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi # noqa

pytensor/tensor/blockwise.py

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
from collections.abc import Sequence
32
from typing import Any, Optional, cast
43

@@ -13,49 +12,14 @@
1312
from pytensor.tensor import as_tensor_variable
1413
from pytensor.tensor.shape import shape_padleft
1514
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
16-
from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
15+
from pytensor.tensor.utils import (
16+
_parse_gufunc_signature,
17+
broadcast_static_dim_lengths,
18+
import_func_from_string,
19+
)
1720
from pytensor.tensor.variable import TensorVariable
1821

1922

20-
# TODO: Implement vectorize helper to batch whole graphs (similar to what Blockwise does for the grad)
21-
22-
# Copied verbatim from numpy.lib.function_base
23-
# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029
24-
_DIMENSION_NAME = r"\w+"
25-
_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME)
26-
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
27-
_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT)
28-
_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST)
29-
30-
31-
def _parse_gufunc_signature(signature):
32-
"""
33-
Parse string signatures for a generalized universal function.
34-
35-
Arguments
36-
---------
37-
signature : string
38-
Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
39-
for ``np.matmul``.
40-
41-
Returns
42-
-------
43-
Tuple of input and output core dimensions parsed from the signature, each
44-
of the form List[Tuple[str, ...]].
45-
"""
46-
signature = re.sub(r"\s+", "", signature)
47-
48-
if not re.match(_SIGNATURE, signature):
49-
raise ValueError(f"not a valid gufunc signature: {signature}")
50-
return tuple(
51-
[
52-
tuple(re.findall(_DIMENSION_NAME, arg))
53-
for arg in re.findall(_ARGUMENT, arg_list)
54-
]
55-
for arg_list in signature.split("->")
56-
)
57-
58-
5923
def safe_signature(
6024
core_inputs: Sequence[Variable],
6125
core_outputs: Sequence[Variable],

pytensor/tensor/functional.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from typing import Callable, Optional
2+
3+
from pytensor.graph import vectorize_graph
4+
from pytensor.tensor import TensorVariable
5+
from pytensor.tensor.utils import _parse_gufunc_signature
6+
7+
8+
def vectorize(func: Callable, signature: Optional[str] = None) -> Callable:
    """Wrap a TensorVariable-based python function so that it broadcasts over batch dimensions.

    The behavior mirrors ``numpy.vectorize``: ``func`` is written for inputs that
    only have their *core* dimensions, and the returned callable accepts inputs
    with any number of extra (batch) dimensions on the left.

    Parameters
    ----------
    func: Callable
        Function that builds the desired outputs from TensorVariable inputs that
        carry only the core dimensions.
    signature: str, optional
        Generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized
        matrix-vector multiplication. When omitted, every input is treated as having
        scalar core dimensions. Unlike numpy, the outputs may have arbitrary shapes
        when no signature is given.

    Returns
    -------
    vectorized_func: Callable
        Callable that takes TensorVariables with arbitrarily batched dimensions on the
        left and returns variables whose graphs correspond to the vectorized
        expressions of ``func``.

    Raises
    ------
    The returned callable (not ``vectorize`` itself) raises:

    ValueError
        If the signature is invalid, if the number of inputs or outputs does not
        match the signature, if an input has fewer dimensions than its core
        signature, or if ``func`` returns ``None``.
    TypeError
        If any input is not a TensorVariable.

    Notes
    -----
    Unlike numpy.vectorize, the equality of core dimensions implied by the signature
    is not explicitly asserted.

    To vectorize an existing graph, use `pytensor.graph.replace.vectorize_graph` instead.

    Examples
    --------
    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        def func(x):
            return pt.exp(x) / pt.sum(pt.exp(x))

        vec_func = pt.vectorize(func, signature="(a)->(a)")

        x = pt.matrix("x")
        y = vec_func(x)

        fn = pytensor.function([x], y)
        fn([[0, 1, 2], [2, 1, 0]])
        # array([[0.09003057, 0.24472847, 0.66524096],
        #        [0.66524096, 0.24472847, 0.09003057]])

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        def func(x):
            return x[0], x[-1]

        vec_func = pt.vectorize(func, signature="(a)->(),()")

        x = pt.matrix("x")
        y1, y2 = vec_func(x)

        fn = pytensor.function([x], [y1, y2])
        fn([[-10, 0, 10], [-11, 0, 11]])
        # [array([-10., -11.]), array([10., 11.])]

    """

    def inner(*inputs):
        # The signature is parsed at call time, so an invalid signature only
        # surfaces once the vectorized function is actually invoked.
        if signature is not None:
            inputs_sig, outputs_sig = _parse_gufunc_signature(signature)
            if len(inputs) != len(inputs_sig):
                raise ValueError(
                    f"Number of inputs does not match signature: {signature}"
                )
        else:
            # No signature: every input is assumed to have a scalar core
            inputs_sig = [()] * len(inputs)

        # Build one dummy "core" variable per input by dropping its batch dimensions
        core_inputs = []
        for batched, core_dims in zip(inputs, inputs_sig):
            if not isinstance(batched, TensorVariable):
                raise TypeError(
                    f"Inputs to vectorize function must be TensorVariable, got {type(batched)}"
                )

            n_core_dims = len(core_dims)
            if batched.ndim < n_core_dims:
                raise ValueError(
                    f"Input {batched} has less dimensions than signature {core_dims}"
                )

            # A plain ``[-0:]`` slice would keep the whole shape, hence the explicit
            # empty tuple for scalar cores.
            static_shape = batched.type.shape
            core_shape = static_shape[-n_core_dims:] if n_core_dims else ()

            core_type = batched.type.clone(shape=core_shape)
            core_inputs.append(core_type(name=batched.name))

        # Trace the user function once on the dummy core inputs
        core_outputs = func(*core_inputs)
        if core_outputs is None:
            raise ValueError("vectorize function returned no outputs")

        # The output count can only be validated when a signature was provided
        if signature is not None:
            n_core_outputs = (
                len(core_outputs) if isinstance(core_outputs, (list, tuple)) else 1
            )
            if n_core_outputs != len(outputs_sig):
                raise ValueError(
                    f"Number of outputs does not match signature: {signature}"
                )

        # Swap each dummy core input for the original batched input
        replacements = dict(zip(core_inputs, inputs))
        return vectorize_graph(core_outputs, replace=replacements)

    return inner

pytensor/tensor/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from collections.abc import Sequence
23
from typing import Union
34

@@ -161,3 +162,40 @@ def broadcast_static_dim_lengths(
161162
if len(dim_lengths_set) > 1:
162163
raise ValueError
163164
return tuple(dim_lengths_set)[0]
165+
166+
167+
# Copied verbatim from numpy.lib.function_base
168+
# https://github.com/numpy/numpy/blob/f2db090eb95b87d48a3318c9a3f9d38b67b0543c/numpy/lib/function_base.py#L1999-L2029
169+
_DIMENSION_NAME = r"\w+"
170+
_CORE_DIMENSION_LIST = "(?:{0:}(?:,{0:})*)?".format(_DIMENSION_NAME)
171+
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
172+
_ARGUMENT_LIST = "{0:}(?:,{0:})*".format(_ARGUMENT)
173+
_SIGNATURE = "^{0:}->{0:}$".format(_ARGUMENT_LIST)
174+
175+
176+
def _parse_gufunc_signature(signature):
177+
"""
178+
Parse string signatures for a generalized universal function.
179+
180+
Arguments
181+
---------
182+
signature : string
183+
Generalized universal function signature, e.g., ``(m,n),(n,p)->(m,p)``
184+
for ``np.matmul``.
185+
186+
Returns
187+
-------
188+
Tuple of input and output core dimensions parsed from the signature, each
189+
of the form List[Tuple[str, ...]].
190+
"""
191+
signature = re.sub(r"\s+", "", signature)
192+
193+
if not re.match(_SIGNATURE, signature):
194+
raise ValueError(f"not a valid gufunc signature: {signature}")
195+
return tuple(
196+
[
197+
tuple(re.findall(_DIMENSION_NAME, arg))
198+
for arg in re.findall(_ARGUMENT, arg_list)
199+
]
200+
for arg_list in signature.split("->")
201+
)

tests/tensor/test_blockwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from pytensor.graph import Apply, Op
1111
from pytensor.graph.replace import vectorize_node
1212
from pytensor.tensor import diagonal, log, tensor
13-
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
13+
from pytensor.tensor.blockwise import Blockwise
1414
from pytensor.tensor.nlinalg import MatrixInverse
1515
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
16+
from pytensor.tensor.utils import _parse_gufunc_signature
1617

1718

1819
def test_vectorize_blockwise():

tests/tensor/test_functional.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.graph.basic import equal_computations
5+
from pytensor.tensor import full, tensor
6+
from pytensor.tensor.functional import vectorize
7+
from pytensor.tensor.random.type import RandomGeneratorType
8+
9+
10+
class TestVectorize:
    """Tests for the functional ``vectorize`` helper."""

    def test_vectorize_no_signature(self):
        """Unlike numpy we don't assume outputs of vectorize without signature are scalar."""

        def fill_5_by_3(value):
            return full((5, 3), value)

        x = tensor("x", shape=(4,), dtype="float64")
        out = vectorize(fill_5_by_3)(x)

        # One batch dimension from x plus the two dimensions created by `full`
        assert out.type.ndim == 3

        test_x = np.array([1, 2, 3, 4])
        expected = np.full((len(test_x), 5, 3), test_x[:, None, None])
        np.testing.assert_allclose(out.eval({x: test_x}), expected)

    def test_vectorize_outer_product(self):
        def outer_product(a, b):
            return a[:, None] * b[None, :]

        x = tensor("x", shape=(2, 3, 5))
        y = tensor("y", shape=(2, 3, 7))
        out = vectorize(outer_product, signature="(a),(b)->(a,b)")(x, y)

        assert out.type.shape == (2, 3, 5, 7)

        # The vectorized graph must match the manually batched expression
        expected = x[..., :, None] * y[..., None, :]
        assert equal_computations([out], [expected])

    def test_vectorize_outer_inner_product(self):
        def outer_and_inner_product(a, b):
            return a[:, None] * b[None, :], (a * b).sum()

        x = tensor("x", shape=(2, 3, 5))
        y = tensor("y", shape=(2, 3, 5))
        vec_func = vectorize(outer_and_inner_product, signature="(a),(b)->(a,b),()")
        outer, inner = vec_func(x, y)

        assert outer.type.shape == (2, 3, 5, 5)
        assert inner.type.shape == (2, 3)

        expected_outer = x[..., :, None] * y[..., None, :]
        expected_inner = (x * y).sum(axis=-1)
        assert equal_computations([outer], [expected_outer])
        assert equal_computations([inner], [expected_inner])

    def test_errors(self):
        def add_and_sub(a, b):
            return a + b, a - b

        x = tensor("x", shape=(5,))
        y = tensor("y", shape=())

        # Signature declares two inputs, only one is passed
        with pytest.raises(ValueError, match="Number of inputs"):
            vectorize(add_and_sub, signature="(),()->()")(x)

        # Function returns two outputs, signature declares one
        with pytest.raises(ValueError, match="Number of outputs"):
            vectorize(add_and_sub, signature="(),()->()")(x, y)

        # y is a scalar but the signature asks for one core dimension
        with pytest.raises(ValueError, match="Input y has less dimensions"):
            vectorize(add_and_sub, signature="(a),(a)->(a),(a)")(x, y)

        # Anything that is not a TensorVariable must be rejected
        not_a_tensor = RandomGeneratorType()
        with pytest.raises(TypeError, match="must be TensorVariable"):
            vectorize(add_and_sub)(not_a_tensor, x)

        def returns_nothing(a, b):
            a + b

        with pytest.raises(ValueError, match="no outputs"):
            vectorize(returns_nothing)(x, y)

0 commit comments

Comments
 (0)