Description
Following the discussion about sampling from the prior and `missing`s (#786), I put together some thoughts about the model syntax. Maybe this is helpful for your refactoring plans.
Currently, when we define a model, we have to decide in advance which random variables (RVs) we want to condition on later. It would be nice to decouple these steps:
- define the model, i.e. the joint distribution
- condition on any RV for which we have data
This is already largely possible if we manually list all RVs as arguments:
```julia
# -----------
# define model
# Note: x_det is deterministic data, not an RV!
# defines p(y, a, b, s; x_det)
@model model1(x_det, y) = begin
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    s ~ Exponential(1)
    yhat = a*x_det + b
    y ~ Normal(yhat, s)
end

# also defines p(y, a, b, s; x_det),
# with the added advantage that we
# can also condition on a, b, and s (see examples below)
@model model2(x_det, y, a, b, s) = begin
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    s ~ Exponential(1)
    yhat = a*x_det + b
    y ~ Normal(yhat, s)
end

# -----------
# sample
# p(a, b, s | y; x_det)
chain1 = sample(model1(10.0, 20.0), sampler)
# same as above: p(a, b, s | y; x_det)
chain2 = sample(model2(10.0, 20.0), sampler)
# same as above: p(a, b, s | y; x_det)
chain3 = sample(model2(10.0, 20.0, missing, missing, missing), sampler)
# with model2 we can also condition on other RVs:
# p(y, a | b, s; x_det)
chain4 = sample(model2(10.0, missing, missing, 2.2, 0.2), sampler)
# Treating x_det as unknown is not meaningful, because it is not an RV:
# p(x_det | y, a, b, s) -> MethodError
chain5 = sample(model2(missing, 3.3, 1.1, 2.2, 0.2), sampler)
```
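The examples above rely on one rule: an RV argument passed as `missing` is treated as unknown and sampled, while any other value is conditioned on. A minimal plain-Julia sketch of that rule, using a hypothetical helper `split_observed` (not part of any existing API), applied to the arguments of `chain4`:

```julia
# Hypothetical helper (not an existing API): given the value passed for each RV,
# report which RVs would be conditioned on (observed) and which sampled (latent).
function split_observed(values::NamedTuple)
    observed = Symbol[]
    latent = Symbol[]
    for (name, value) in pairs(values)
        # `missing` marks an unknown RV; anything else is treated as data.
        push!(ismissing(value) ? latent : observed, name)
    end
    return (observed = observed, latent = latent)
end

# The RV arguments of chain4: y and a are missing, b and s are given.
split_observed((y = missing, a = missing, b = 2.2, s = 0.2))
# (observed = [:b, :s], latent = [:y, :a])
```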
So `model1` and `model2` define the same joint distribution, but the latter is much more flexible. I was therefore wondering whether it would be a good idea to generate something like `model2` automatically. I could imagine a syntax like this:
```julia
# The user only specifies the deterministic variables as arguments
@model model(x_det) = begin
    # prior
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    s ~ Exponential(1)
    # likelihood
    yhat = a*x_det + b
    y ~ Normal(yhat, s)
end

# this would get translated to:
@model model(x_det; y=missing, a=missing, b=missing, s=missing) = begin
    ...
end

# Now we can condition on any RV we want.
# However, deterministic variables must always be given.
sample(model(1.1, y=3.3, s=0.1), sampler)
```
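To give a feel for the core rewriting step, here is a minimal sketch in plain Julia, independent of any existing `@model` implementation: walk the model body, collect every name on the left of `~`, and add each one as a keyword argument defaulting to `missing`. The helpers `tilde_lhs` and `add_rv_kwargs` are hypothetical; indexed left-hand sides such as `y[i]` and signatures that already have keyword arguments are deliberately not handled. The keywords come out in order of appearance (`a, b, s, y`), which does not matter for keyword arguments.

```julia
# Hypothetical helper: collect the names on the left-hand side of every
# `lhs ~ dist` statement, in order of first appearance.
function tilde_lhs(ex, found = Symbol[])
    if ex isa Expr
        if ex.head == :call && length(ex.args) == 3 && ex.args[1] == :~
            lhs = ex.args[2]
            # Only plain symbols are collected; indexed LHS like y[i] are skipped.
            if lhs isa Symbol && !(lhs in found)
                push!(found, lhs)
            end
        else
            foreach(arg -> tilde_lhs(arg, found), ex.args)
        end
    end
    return found
end

# Hypothetical helper: turn `name(args...)` into `name(args...; rv1=missing, ...)`.
# Assumes `sig` has no keyword arguments yet.
function add_rv_kwargs(sig::Expr, rvs::Vector{Symbol})
    kwargs = Expr(:parameters, [Expr(:kw, rv, :missing) for rv in rvs]...)
    return Expr(:call, sig.args[1], kwargs, sig.args[2:end]...)
end

# The model body from above, quoted (nothing is evaluated, so no packages are needed).
body = quote
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)
    s ~ Exponential(1)
    yhat = a*x_det + b
    y ~ Normal(yhat, s)
end

rvs = tilde_lhs(body)                      # [:a, :b, :s, :y]
sig = add_rv_kwargs(:(model(x_det)), rvs)  # model(x_det; a=missing, ..., y=missing)

# Define a function with the generated signature whose body just returns the RV values.
eval(Expr(:function, sig, Expr(:block, :((a, b, s, y)))))
model(1.1; y = 3.3, s = 0.1)               # (missing, missing, 0.1, 3.3)
```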
For vector-valued RVs, this would need some additional thought on how to pass the dimensions.
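One possible direction for the dimension question, sketched in plain Julia: keyword defaults may refer to earlier positional arguments, so a vector RV's default can be sized from the deterministic data. The function `vector_model_args` is hypothetical and only shows the defaulting mechanism; it assumes the length of `y` equals `length(x_det)`. The element type `Union{Missing, Float64}` is chosen so the same slots could later hold numbers.

```julia
# Hypothetical sketch: the default for the vector RV `y` is one `missing`
# per element of the deterministic input `x_det`.
function vector_model_args(x_det::AbstractVector;
                           y = Vector{Union{Missing, Float64}}(missing, length(x_det)))
    return y
end

vector_model_args([1.0, 2.0, 3.0])             # 3-element vector of missing
vector_model_args([1.0, 2.0]; y = [0.5, 1.5])  # user-supplied observations
```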
Conceptually this seems neat; however, I cannot judge how difficult it would be to implement.