26
26
from pymc .distributions .logprob import logp
27
27
from pymc .distributions .shape_utils import to_tuple
28
28
from pymc .model import modelcontext
29
+ from pymc .util import check_dist_not_registered
29
30
30
31
__all__ = ["Bound" ]
31
32
@@ -144,8 +145,9 @@ class Bound:
144
145
145
146
Parameters
146
147
----------
147
- distribution: pymc distribution
148
- Distribution to be transformed into a bounded distribution.
148
+ dist: PyMC unnamed distribution
149
+ Distribution to be transformed into a bounded distribution created via the
150
+ `.dist()` API.
149
151
lower: float or array like, optional
150
152
Lower bound of the distribution.
151
153
upper: float or array like, optional
@@ -156,15 +158,15 @@ class Bound:
156
158
.. code-block:: python
157
159
158
160
with pm.Model():
159
- normal_dist = Normal.dist(mu=0.0, sigma=1.0, initval=-0.5 )
160
- negative_normal = pm.Bound(normal_dist, upper=0.0)
161
+ normal_dist = pm. Normal.dist(mu=0.0, sigma=1.0)
162
+ negative_normal = pm.Bound("negative_normal", normal_dist, upper=0.0)
161
163
162
164
"""
163
165
164
166
def __new__ (
165
167
cls ,
166
168
name ,
167
- distribution ,
169
+ dist ,
168
170
lower = None ,
169
171
upper = None ,
170
172
size = None ,
@@ -174,7 +176,7 @@ def __new__(
174
176
** kwargs ,
175
177
):
176
178
177
- cls ._argument_checks (distribution , ** kwargs )
179
+ cls ._argument_checks (dist , ** kwargs )
178
180
179
181
if dims is not None :
180
182
model = modelcontext (None )
@@ -185,12 +187,12 @@ def __new__(
185
187
raise ValueError ("Given dims do not exist in model coordinates." )
186
188
187
189
lower , upper , initval = cls ._set_values (lower , upper , size , shape , initval )
188
- distribution .tag .ignore_logprob = True
190
+ dist .tag .ignore_logprob = True
189
191
190
- if isinstance (distribution .owner .op , Continuous ):
192
+ if isinstance (dist .owner .op , Continuous ):
191
193
res = _ContinuousBounded (
192
194
name ,
193
- [distribution , lower , upper ],
195
+ [dist , lower , upper ],
194
196
initval = floatX (initval ),
195
197
size = size ,
196
198
shape = shape ,
@@ -199,7 +201,7 @@ def __new__(
199
201
else :
200
202
res = _DiscreteBounded (
201
203
name ,
202
- [distribution , lower , upper ],
204
+ [dist , lower , upper ],
203
205
initval = intX (initval ),
204
206
size = size ,
205
207
shape = shape ,
@@ -210,28 +212,28 @@ def __new__(
210
212
@classmethod
211
213
def dist (
212
214
cls ,
213
- distribution ,
215
+ dist ,
214
216
lower = None ,
215
217
upper = None ,
216
218
size = None ,
217
219
shape = None ,
218
220
** kwargs ,
219
221
):
220
222
221
- cls ._argument_checks (distribution , ** kwargs )
223
+ cls ._argument_checks (dist , ** kwargs )
222
224
lower , upper , initval = cls ._set_values (lower , upper , size , shape , initval = None )
223
- distribution .tag .ignore_logprob = True
224
- if isinstance (distribution .owner .op , Continuous ):
225
+ dist .tag .ignore_logprob = True
226
+ if isinstance (dist .owner .op , Continuous ):
225
227
res = _ContinuousBounded .dist (
226
- [distribution , lower , upper ],
228
+ [dist , lower , upper ],
227
229
size = size ,
228
230
shape = shape ,
229
231
** kwargs ,
230
232
)
231
233
res .tag .test_value = floatX (initval )
232
234
else :
233
235
res = _DiscreteBounded .dist (
234
- [distribution , lower , upper ],
236
+ [dist , lower , upper ],
235
237
size = size ,
236
238
shape = shape ,
237
239
** kwargs ,
@@ -240,7 +242,7 @@ def dist(
240
242
return res
241
243
242
244
@classmethod
243
- def _argument_checks (cls , distribution , ** kwargs ):
245
+ def _argument_checks (cls , dist , ** kwargs ):
244
246
if "observed" in kwargs :
245
247
raise ValueError (
246
248
"Observed Bound distributions are not supported. "
@@ -249,34 +251,22 @@ def _argument_checks(cls, distribution, **kwargs):
249
251
"with the cumulative probability function."
250
252
)
251
253
252
- if not isinstance (distribution , TensorVariable ):
254
+ if not isinstance (dist , TensorVariable ):
253
255
raise ValueError (
254
256
"Passing a distribution class to `Bound` is no longer supported.\n "
255
257
"Please pass the output of a distribution instantiated via the "
256
258
"`.dist()` API such as:\n "
257
259
'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
258
260
)
259
261
260
- try :
261
- model = modelcontext (None )
262
- except TypeError :
263
- pass
264
- else :
265
- if distribution in model .basic_RVs :
266
- raise ValueError (
267
- f"The distribution passed into `Bound` was already registered "
268
- f"in the current model.\n You should pass an unregistered "
269
- f"(unnamed) distribution created via the `.dist()` API, such as:\n "
270
- f'`pm.Bound("bound", pm.Normal.dist(0, 1), lower=0)`'
271
- )
272
-
273
- if distribution .owner .op .ndim_supp != 0 :
262
+ check_dist_not_registered (dist )
263
+
264
+ if dist .owner .op .ndim_supp != 0 :
274
265
raise NotImplementedError ("Bounding of MultiVariate RVs is not yet supported." )
275
266
276
- if not isinstance (distribution .owner .op , (Discrete , Continuous )):
267
+ if not isinstance (dist .owner .op , (Discrete , Continuous )):
277
268
raise ValueError (
278
- f"`distribution` { distribution } must be a Discrete or Continuous"
279
- " distribution subclass"
269
+ f"`distribution` { dist } must be a Discrete or Continuous" " distribution subclass"
280
270
)
281
271
282
272
@classmethod
0 commit comments