Skip to content

Commit 4fc7dc9

Browse files
committed
Small fixes to pm.Bound
* Fix invalid code example in docstrings * Rename distribution parameter to dist * Use `check_dist_not_registered` helper
1 parent 2cdf282 commit 4fc7dc9

File tree

2 files changed

+26
-36
lines changed

2 files changed

+26
-36
lines changed

pymc/distributions/bound.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc.distributions.logprob import logp
2727
from pymc.distributions.shape_utils import to_tuple
2828
from pymc.model import modelcontext
29+
from pymc.util import check_dist_not_registered
2930

3031
__all__ = ["Bound"]
3132

@@ -144,8 +145,9 @@ class Bound:
144145
145146
Parameters
146147
----------
147-
distribution: pymc distribution
148-
Distribution to be transformed into a bounded distribution.
148+
dist: PyMC unnamed distribution
149+
Distribution to be transformed into a bounded distribution created via the
150+
`.dist()` API.
149151
lower: float or array like, optional
150152
Lower bound of the distribution.
151153
upper: float or array like, optional
@@ -156,15 +158,15 @@ class Bound:
156158
.. code-block:: python
157159
158160
with pm.Model():
159-
normal_dist = Normal.dist(mu=0.0, sigma=1.0, initval=-0.5)
160-
negative_normal = pm.Bound(normal_dist, upper=0.0)
161+
normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
162+
negative_normal = pm.Bound("negative_normal", normal_dist, upper=0.0)
161163
162164
"""
163165

164166
def __new__(
165167
cls,
166168
name,
167-
distribution,
169+
dist,
168170
lower=None,
169171
upper=None,
170172
size=None,
@@ -174,7 +176,7 @@ def __new__(
174176
**kwargs,
175177
):
176178

177-
cls._argument_checks(distribution, **kwargs)
179+
cls._argument_checks(dist, **kwargs)
178180

179181
if dims is not None:
180182
model = modelcontext(None)
@@ -185,12 +187,12 @@ def __new__(
185187
raise ValueError("Given dims do not exist in model coordinates.")
186188

187189
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
188-
distribution.tag.ignore_logprob = True
190+
dist.tag.ignore_logprob = True
189191

190-
if isinstance(distribution.owner.op, Continuous):
192+
if isinstance(dist.owner.op, Continuous):
191193
res = _ContinuousBounded(
192194
name,
193-
[distribution, lower, upper],
195+
[dist, lower, upper],
194196
initval=floatX(initval),
195197
size=size,
196198
shape=shape,
@@ -199,7 +201,7 @@ def __new__(
199201
else:
200202
res = _DiscreteBounded(
201203
name,
202-
[distribution, lower, upper],
204+
[dist, lower, upper],
203205
initval=intX(initval),
204206
size=size,
205207
shape=shape,
@@ -210,28 +212,28 @@ def __new__(
210212
@classmethod
211213
def dist(
212214
cls,
213-
distribution,
215+
dist,
214216
lower=None,
215217
upper=None,
216218
size=None,
217219
shape=None,
218220
**kwargs,
219221
):
220222

221-
cls._argument_checks(distribution, **kwargs)
223+
cls._argument_checks(dist, **kwargs)
222224
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
223-
distribution.tag.ignore_logprob = True
224-
if isinstance(distribution.owner.op, Continuous):
225+
dist.tag.ignore_logprob = True
226+
if isinstance(dist.owner.op, Continuous):
225227
res = _ContinuousBounded.dist(
226-
[distribution, lower, upper],
228+
[dist, lower, upper],
227229
size=size,
228230
shape=shape,
229231
**kwargs,
230232
)
231233
res.tag.test_value = floatX(initval)
232234
else:
233235
res = _DiscreteBounded.dist(
234-
[distribution, lower, upper],
236+
[dist, lower, upper],
235237
size=size,
236238
shape=shape,
237239
**kwargs,
@@ -240,7 +242,7 @@ def dist(
240242
return res
241243

242244
@classmethod
243-
def _argument_checks(cls, distribution, **kwargs):
245+
def _argument_checks(cls, dist, **kwargs):
244246
if "observed" in kwargs:
245247
raise ValueError(
246248
"Observed Bound distributions are not supported. "
@@ -249,34 +251,22 @@ def _argument_checks(cls, distribution, **kwargs):
249251
"with the cumulative probability function."
250252
)
251253

252-
if not isinstance(distribution, TensorVariable):
254+
if not isinstance(dist, TensorVariable):
253255
raise ValueError(
254256
"Passing a distribution class to `Bound` is no longer supported.\n"
255257
"Please pass the output of a distribution instantiated via the "
256258
"`.dist()` API such as:\n"
257259
'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
258260
)
259261

260-
try:
261-
model = modelcontext(None)
262-
except TypeError:
263-
pass
264-
else:
265-
if distribution in model.basic_RVs:
266-
raise ValueError(
267-
f"The distribution passed into `Bound` was already registered "
268-
f"in the current model.\nYou should pass an unregistered "
269-
f"(unnamed) distribution created via the `.dist()` API, such as:\n"
270-
f'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
271-
)
272-
273-
if distribution.owner.op.ndim_supp != 0:
262+
check_dist_not_registered(dist)
263+
264+
if dist.owner.op.ndim_supp != 0:
274265
raise NotImplementedError("Bounding of MultiVariate RVs is not yet supported.")
275266

276-
if not isinstance(distribution.owner.op, (Discrete, Continuous)):
267+
if not isinstance(dist.owner.op, (Discrete, Continuous)):
277268
raise ValueError(
278-
f"`distribution` {distribution} must be a Discrete or Continuous"
279-
" distribution subclass"
269+
f"`distribution` {dist} must be a Discrete or Continuous" " distribution subclass"
280270
)
281271

282272
@classmethod

pymc/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2701,7 +2701,7 @@ def test_arguments_checks(self):
27012701
with pytest.raises(ValueError, match=msg):
27022702
pm.Bound("bound", x, dims="random_dims")
27032703

2704-
msg = "The distribution passed into `Bound` was already registered"
2704+
msg = "The dist x was already registered in the current model"
27052705
with pm.Model() as m:
27062706
x = pm.Normal("x", 0, 1)
27072707
with pytest.raises(ValueError, match=msg):

0 commit comments

Comments
 (0)