Skip to content

Commit bafe249

Browse files
committed
Fix CDF and iCDF derivations based on monotonicity
1 parent 9a34943 commit bafe249

File tree

2 files changed

+144
-48
lines changed

2 files changed

+144
-48
lines changed

pymc/logprob/transforms.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,10 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
440440
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
441441

442442

443+
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf)
444+
MONOTONICALLY_DECREASING_OPS = (Erfc, Erfcx)
445+
446+
443447
@_logcdf.register(MeasurableTransform)
444448
def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
445449
"""Compute the log-CDF graph for a `MeasurabeTransform`."""
@@ -453,12 +457,35 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
453457
if isinstance(backward_value, tuple):
454458
raise NotImplementedError
455459

456-
input_logcdf = _logcdf_helper(measurable_input, backward_value)
460+
logcdf = _logcdf_helper(measurable_input, backward_value)
461+
logccdf = pt.log1mexp(logcdf)
462+
463+
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
464+
pass
465+
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
466+
logcdf = logccdf
467+
# mul is monotonically increasing for scale > 0, and monotonically decreasing otherwise
468+
elif isinstance(op.scalar_op, Mul):
469+
[scale] = other_inputs
470+
logcdf = pt.switch(pt.ge(scale, 0), logcdf, logccdf)
471+
# pow is increasing if pow > 0, and decreasing otherwise (even powers are rejected above)!
472+
# Care must be taken to handle negative values (https://math.stackexchange.com/a/442362/783483)
473+
elif isinstance(op.scalar_op, Pow):
474+
if op.transform_elemwise.power < 0:
475+
logcdf_zero = _logcdf_helper(measurable_input, 0)
476+
logcdf = pt.switch(
477+
pt.lt(backward_value, 0),
478+
pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
479+
pt.logaddexp(logccdf, logcdf_zero),
480+
)
481+
else:
482+
# We don't know if this Op is monotonically increasing/decreasing
483+
raise NotImplementedError
457484

458485
# The jacobian is used to ensure a value in the supported domain was provided
459486
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
460487

461-
return pt.switch(pt.isnan(jacobian), -np.inf, input_logcdf)
488+
return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
462489

463490

464491
@_icdf.register(MeasurableTransform)
@@ -467,6 +494,19 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
467494
other_inputs = list(inputs)
468495
measurable_input = other_inputs.pop(op.measurable_input_idx)
469496

497+
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
498+
pass
499+
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
500+
value = 1 - value
501+
elif isinstance(op.scalar_op, Mul):
502+
[scale] = other_inputs
503+
value = pt.switch(pt.lt(scale, 0), 1 - value, value)
504+
elif isinstance(op.scalar_op, Pow):
505+
if op.transform_elemwise.power < 0:
506+
raise NotImplementedError
507+
else:
508+
raise NotImplementedError
509+
470510
input_icdf = _icdf_helper(measurable_input, value)
471511
icdf = op.transform_elemwise.forward(input_icdf, *other_inputs)
472512

@@ -871,7 +911,7 @@ def __init__(self, power=None):
871911
super().__init__()
872912

873913
def forward(self, value, *inputs):
874-
pt.power(value, self.power)
914+
return pt.power(value, self.power)
875915

876916
def backward(self, value, *inputs):
877917
inv_power = 1 / self.power

tests/logprob/test_transforms.py

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from pytensor.graph.fg import FunctionGraph
4848
from pytensor.scan import scan
4949

50+
from pymc.distributions.continuous import Cauchy
5051
from pymc.distributions.transforms import _default_transform, log, logodds
5152
from pymc.logprob.abstract import MeasurableVariable, _logprob
5253
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
@@ -764,14 +765,24 @@ def test_exp_transform_rv():
764765
y_rv.name = "y"
765766

766767
y_vv = y_rv.clone()
767-
logprob = logp(y_rv, y_vv)
768-
logp_fn = pytensor.function([y_vv], logprob)
768+
logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
769+
logcdf_fn = pytensor.function([y_vv], logcdf(y_rv, y_vv))
770+
icdf_fn = pytensor.function([y_vv], icdf(y_rv, y_vv))
769771

770772
y_val = [-2.0, 0.1, 0.3]
773+
q_val = [0.2, 0.5, 0.9]
771774
np.testing.assert_allclose(
772775
logp_fn(y_val),
773776
sp.stats.lognorm(s=1).logpdf(y_val),
774777
)
778+
np.testing.assert_almost_equal(
779+
logcdf_fn(y_val),
780+
sp.stats.lognorm(s=1).logcdf(y_val),
781+
)
782+
np.testing.assert_almost_equal(
783+
icdf_fn(q_val),
784+
sp.stats.lognorm(s=1).ppf(q_val),
785+
)
775786

776787

777788
def test_log_transform_rv():
@@ -811,14 +822,24 @@ def test_loc_transform_rv(self, rv_size, loc_type, addition):
811822
logprob = logp(y_rv, y_vv)
812823
assert_no_rvs(logprob)
813824
logp_fn = pytensor.function([loc, y_vv], logprob)
825+
logcdf_fn = pytensor.function([loc, y_vv], logcdf(y_rv, y_vv))
826+
icdf_fn = pytensor.function([loc, y_vv], icdf(y_rv, y_vv))
814827

815828
loc_test_val = np.full(rv_size, 4.0)
816829
y_test_val = np.full(rv_size, 1.0)
817-
830+
q_test_val = np.full(rv_size, 0.7)
818831
np.testing.assert_allclose(
819832
logp_fn(loc_test_val, y_test_val),
820833
sp.stats.norm(loc_test_val, 1).logpdf(y_test_val),
821834
)
835+
np.testing.assert_allclose(
836+
logcdf_fn(loc_test_val, y_test_val),
837+
sp.stats.norm(loc_test_val, 1).logcdf(y_test_val),
838+
)
839+
np.testing.assert_allclose(
840+
icdf_fn(loc_test_val, q_test_val),
841+
sp.stats.norm(loc_test_val, 1).ppf(q_test_val),
842+
)
822843

823844
@pytest.mark.parametrize(
824845
"rv_size, scale_type, product",
@@ -840,23 +861,37 @@ def test_scale_transform_rv(self, rv_size, scale_type, product):
840861
logprob = logp(y_rv, y_vv)
841862
assert_no_rvs(logprob)
842863
logp_fn = pytensor.function([scale, y_vv], logprob)
864+
logcdf_fn = pytensor.function([scale, y_vv], logcdf(y_rv, y_vv))
865+
icdf_fn = pytensor.function([scale, y_vv], icdf(y_rv, y_vv))
843866

844867
scale_test_val = np.full(rv_size, 4.0)
845868
y_test_val = np.full(rv_size, 1.0)
846-
869+
q_test_val = np.full(rv_size, 0.3)
847870
np.testing.assert_allclose(
848871
logp_fn(scale_test_val, y_test_val),
849872
sp.stats.norm(0, scale_test_val).logpdf(y_test_val),
850873
)
874+
np.testing.assert_allclose(
875+
logcdf_fn(scale_test_val, y_test_val),
876+
sp.stats.norm(0, scale_test_val).logcdf(y_test_val),
877+
)
878+
np.testing.assert_allclose(
879+
icdf_fn(scale_test_val, q_test_val),
880+
sp.stats.norm(0, scale_test_val).ppf(q_test_val),
881+
)
851882

852883
def test_negated_rv_transform(self):
853884
x_rv = -pt.random.halfnormal()
854885
x_rv.name = "x"
855886

856887
x_vv = x_rv.clone()
857-
x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))
888+
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
889+
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))
890+
x_icdf_fn = pytensor.function([x_vv], icdf(x_rv, x_vv))
858891

859892
np.testing.assert_allclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))
893+
np.testing.assert_allclose(x_logcdf_fn(-1.5), sp.stats.halfnorm.logsf(1.5))
894+
np.testing.assert_allclose(x_icdf_fn(0.3), -sp.stats.halfnorm.ppf(1 - 0.3))
860895

861896
def test_subtracted_rv_transform(self):
862897
# Choose base RV that is asymmetric around zero
@@ -899,25 +934,55 @@ def test_reciprocal_rv_transform(self, numerator):
899934

900935
x_vv = x_rv.clone()
901936
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
937+
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))
938+
939+
with pytest.raises(NotImplementedError):
940+
icdf(x_rv, x_vv)
902941

903942
x_test_val = np.r_[-0.5, 1.5]
904943
np.testing.assert_allclose(
905944
x_logp_fn(x_test_val),
906945
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
907946
)
947+
np.testing.assert_allclose(
948+
x_logcdf_fn(x_test_val),
949+
sp.stats.invgamma(shape, scale=scale * numerator).logcdf(x_test_val),
950+
)
951+
952+
def test_reciprocal_real_rv_transform(self):
953+
# 1 / Cauchy(mu, sigma) = Cauchy(mu / (mu^2 + sigma^2), sigma / (mu^2 + sigma^2))
954+
test_value = [-0.5, 0.9]
955+
test_rv = Cauchy.dist(1, 2, size=(2,)) ** (-1)
956+
957+
np.testing.assert_allclose(
958+
logp(test_rv, test_value).eval(),
959+
sp.stats.cauchy(1 / 5, 2 / 5).logpdf(test_value),
960+
)
961+
np.testing.assert_allclose(
962+
logcdf(test_rv, test_value).eval(),
963+
sp.stats.cauchy(1 / 5, 2 / 5).logcdf(test_value),
964+
)
965+
with pytest.raises(NotImplementedError):
966+
icdf(test_rv, test_value)
908967

909968
def test_sqr_transform(self):
910-
# The square of a unit normal is a chi-square with 1 df
911-
x_rv = pt.random.normal(0, 1, size=(4,)) ** 2
969+
# The square of a normal with unit variance is a noncentral chi-square with 1 df and nc = mean ** 2
970+
x_rv = pt.random.normal(0.5, 1, size=(4,)) ** 2
912971
x_rv.name = "x"
913972

914973
x_vv = x_rv.clone()
915974
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
916975

976+
with pytest.raises(NotImplementedError):
977+
logcdf(x_rv, x_vv)
978+
979+
with pytest.raises(NotImplementedError):
980+
icdf(x_rv, x_vv)
981+
917982
x_test_val = np.r_[-0.5, 0.5, 1, 2.5]
918983
np.testing.assert_allclose(
919984
x_logp_fn(x_test_val),
920-
sp.stats.chi2(df=1).logpdf(x_test_val),
985+
sp.stats.ncx2(df=1, nc=0.5**2).logpdf(x_test_val),
921986
)
922987

923988
def test_sqrt_transform(self):
@@ -927,12 +992,29 @@ def test_sqrt_transform(self):
927992

928993
x_vv = x_rv.clone()
929994
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
995+
x_logcdf_fn = pytensor.function([x_vv], logcdf(x_rv, x_vv))
930996

931997
x_test_val = np.r_[-2.5, 0.5, 1, 2.5]
932998
np.testing.assert_allclose(
933999
x_logp_fn(x_test_val),
9341000
sp.stats.chi(df=3).logpdf(x_test_val),
9351001
)
1002+
np.testing.assert_allclose(
1003+
x_logcdf_fn(x_test_val),
1004+
sp.stats.chi(df=3).logcdf(x_test_val),
1005+
)
1006+
1007+
# ICDF is not implemented for chisquare, so we have to test with another identity
1008+
# sqrt(exponential(lam)) = rayleigh(1 / sqrt(2 * lam))
1009+
lam = 2.5
1010+
y_rv = pt.sqrt(pt.random.exponential(scale=1 / lam))
1011+
y_vv = x_rv.clone()
1012+
y_icdf_fn = pytensor.function([y_vv], icdf(y_rv, y_vv))
1013+
q_test_val = np.r_[0.2, 0.5, 0.7, 0.9]
1014+
np.testing.assert_allclose(
1015+
y_icdf_fn(q_test_val),
1016+
(1 / np.sqrt(2 * lam)) * np.sqrt(-2 * np.log(1 - q_test_val)),
1017+
)
9361018

9371019
@pytest.mark.parametrize("power", (-3, -1, 1, 5, 7))
9381020
def test_negative_value_odd_power_transform(self, power):
@@ -947,7 +1029,7 @@ def test_negative_value_odd_power_transform(self, power):
9471029
assert np.isfinite(x_logp_fn(-1))
9481030

9491031
@pytest.mark.parametrize("power", (-2, 2, 4, 6, 8))
950-
def test_negative_value_even_power_transform(self, power):
1032+
def test_negative_value_even_power_transform_logp(self, power):
9511033
# check that negative values and even powers evaluate to -inf logp
9521034
x_rv = pt.random.normal() ** power
9531035
x_rv.name = "x"
@@ -959,7 +1041,7 @@ def test_negative_value_even_power_transform(self, power):
9591041
assert np.isneginf(x_logp_fn(-1))
9601042

9611043
@pytest.mark.parametrize("power", (-1 / 3, -1 / 2, 1 / 2, 1 / 3))
962-
def test_negative_value_frac_power_transform(self, power):
1044+
def test_negative_value_frac_power_transform_logp(self, power):
9631045
# check that negative values and fractional powers evaluate to -inf logp
9641046
x_rv = pt.random.normal() ** power
9651047
x_rv.name = "x"
@@ -979,8 +1061,12 @@ def test_absolute_rv_transform(test_val):
9791061
x_vv = x_rv.clone()
9801062
y_vv = y_rv.clone()
9811063
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
982-
y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
1064+
with pytest.raises(NotImplementedError):
1065+
logcdf(x_rv, x_vv)
1066+
with pytest.raises(NotImplementedError):
1067+
icdf(x_rv, x_vv)
9831068

1069+
y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
9841070
np.testing.assert_allclose(x_logp_fn(test_val), y_logp_fn(test_val))
9851071

9861072

@@ -1022,6 +1108,10 @@ def test_cosh_rv_transform():
10221108

10231109
vv = rv.clone()
10241110
rv_logp = logp(rv, vv)
1111+
with pytest.raises(NotImplementedError):
1112+
logcdf(rv, vv)
1113+
with pytest.raises(NotImplementedError):
1114+
icdf(rv, vv)
10251115

10261116
transform = CoshTransform()
10271117
[back_neg, back_pos] = transform.backward(vv)
@@ -1083,37 +1173,3 @@ def test_invalid_broadcasted_transform_rv_fails():
10831173
# This logp derivation should fail or count only once the values that are broadcasted
10841174
logprob = logp(y_rv, y_vv)
10851175
assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()
1086-
1087-
1088-
def test_logcdf_measurable_transform():
1089-
x = pt.exp(pt.random.uniform(0, 1))
1090-
value = x.type()
1091-
logcdf_fn = pytensor.function([value], logcdf(x, value))
1092-
1093-
assert logcdf_fn(0) == -np.inf
1094-
np.testing.assert_allclose(logcdf_fn(np.exp(0.5)), np.log(0.5))
1095-
np.testing.assert_allclose(logcdf_fn(5), 0)
1096-
1097-
1098-
def test_logcdf_measurable_non_injective_fails():
1099-
x = pt.abs(pt.random.uniform(0, 1))
1100-
value = x.type()
1101-
with pytest.raises(NotImplementedError):
1102-
logcdf(x, value)
1103-
1104-
1105-
def test_icdf_measurable_transform():
1106-
x = pt.exp(pt.random.uniform(0, 1))
1107-
value = x.type()
1108-
icdf_fn = pytensor.function([value], icdf(x, value))
1109-
1110-
np.testing.assert_allclose(icdf_fn(1e-16), 1)
1111-
np.testing.assert_allclose(icdf_fn(0.5), np.exp(0.5))
1112-
np.testing.assert_allclose(icdf_fn(1 - 1e-16), np.e)
1113-
1114-
1115-
def test_icdf_measurable_non_injective_fails():
1116-
x = pt.abs(pt.random.uniform(0, 1))
1117-
value = x.type()
1118-
with pytest.raises(NotImplementedError):
1119-
icdf(x, value)

0 commit comments

Comments
 (0)