Skip to content

Commit 2dc7591

Browse files
RuneDominikbrandonwillard
authored andcommitted
Add scipy's owens_t function as op
1 parent fd50f36 commit 2dc7591

File tree

4 files changed

+90
-0
lines changed

4 files changed

+90
-0
lines changed

aesara/scalar/math.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,35 @@ def c_code(self, node, name, inp, out, sub):
250250
erfcinv = Erfcinv(upgrade_to_float_no_complex, name="erfcinv")
251251

252252

253+
class Owens_t(BinaryScalarOp):
254+
nfunc_spec = ("scipy.special.owens_t", 2, 1)
255+
256+
@staticmethod
257+
def st_impl(h, a):
258+
return scipy.special.owens_t(h, a)
259+
260+
def impl(self, h, a):
261+
return Owens_t.st_impl(h, a)
262+
263+
def grad(self, inputs, grads):
264+
(h, a) = inputs
265+
(gz,) = grads
266+
return [
267+
gz
268+
* (-1)
269+
* exp(-(h**2) / 2)
270+
* erf(a * h / np.sqrt(2))
271+
/ (2 * np.sqrt(2 * np.pi)),
272+
gz * exp(-0.5 * (a**2 + 1) * h**2) / (2 * np.pi * (a**2 + 1)),
273+
]
274+
275+
def c_code(self, *args, **kwargs):
276+
raise NotImplementedError()
277+
278+
279+
owens_t = Owens_t(upgrade_to_float, name="owens_t")
280+
281+
253282
class Gamma(UnaryScalarOp):
254283
nfunc_spec = ("scipy.special.gamma", 1, 1)
255284

aesara/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ def erfcx_inplace(a):
233233
"""scaled complementary error function"""
234234

235235

236+
@scalar_elemwise
237+
def owens_t_inplace(h, a):
238+
"""owens t function"""
239+
240+
236241
@scalar_elemwise
237242
def gamma_inplace(a):
238243
"""gamma function"""

aesara/tensor/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,11 @@ def erfcinv(a):
13391339
"""inverse complementary error function"""
13401340

13411341

1342+
@scalar_elemwise
1343+
def owens_t(h, a):
1344+
"""owens t function"""
1345+
1346+
13421347
@scalar_elemwise
13431348
def gamma(a):
13441349
"""gamma function"""
@@ -3062,6 +3067,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
30623067
"erfcx",
30633068
"erfinv",
30643069
"erfcinv",
3070+
"owens_t",
30653071
"gamma",
30663072
"gammaln",
30673073
"psi",

tests/tensor/test_math_scipy.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def scipy_special_gammal(k, x):
5353
expected_erfc = scipy.special.erfc
5454
expected_erfinv = scipy.special.erfinv
5555
expected_erfcinv = scipy.special.erfcinv
56+
expected_owenst = scipy.special.owens_t
5657
expected_gamma = scipy.special.gamma
5758
expected_gammaln = scipy.special.gammaln
5859
expected_psi = scipy.special.psi
@@ -146,6 +147,55 @@ def scipy_special_gammal(k, x):
146147
mode=mode_no_scipy,
147148
)
148149

150+
rng = np.random.default_rng(seed=utt.fetch_seed())
151+
_good_broadcast_binary_owenst = dict(
152+
normal=(
153+
random_ranged(-5, 5, (2, 3), rng=rng),
154+
random_ranged(-5, 5, (2, 3), rng=rng),
155+
),
156+
empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)),
157+
int=(
158+
integers_ranged(-5, 5, (2, 3), rng=rng),
159+
integers_ranged(-5, 5, (2, 3), rng=rng),
160+
),
161+
uint8=(
162+
integers_ranged(1, 6, (2, 3), rng=rng).astype("uint8"),
163+
integers_ranged(1, 6, (2, 3), rng=rng).astype("uint8"),
164+
),
165+
uint16=(
166+
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint16"),
167+
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint16"),
168+
),
169+
uint64=(
170+
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint64"),
171+
integers_ranged(1, 10, (2, 3), rng=rng).astype("uint64"),
172+
),
173+
)
174+
175+
_grad_broadcast_binary_owenst = dict(
176+
normal=(
177+
random_ranged(-5, 5, (2, 3), rng=rng),
178+
random_ranged(-5, 5, (2, 3), rng=rng),
179+
)
180+
)
181+
182+
TestOwensTBroadcast = makeBroadcastTester(
183+
op=at.owens_t,
184+
expected=expected_owenst,
185+
good=_good_broadcast_binary_owenst,
186+
grad=_grad_broadcast_binary_owenst,
187+
eps=2e-10,
188+
mode=mode_no_scipy,
189+
)
190+
TestOwensTInplaceBroadcast = makeBroadcastTester(
191+
op=inplace.owens_t_inplace,
192+
expected=expected_owenst,
193+
good=_good_broadcast_binary_owenst,
194+
eps=2e-10,
195+
mode=mode_no_scipy,
196+
inplace=True,
197+
)
198+
149199
rng = np.random.default_rng(seed=utt.fetch_seed())
150200
_good_broadcast_unary_gammaln = dict(
151201
normal=(random_ranged(-1 + 1e-2, 10, (2, 3), rng=rng),),

0 commit comments

Comments
 (0)