Public Documentation
Documentation for JunctionTrees.jl's public interface.
Index
JunctionTrees.BackwardPass
JunctionTrees.ForwardPass
JunctionTrees.JointMarginals
JunctionTrees.Marginals
JunctionTrees.UnnormalizedMarginals
Public interface
Modules
JunctionTrees (Module)
Main module for JunctionTrees.jl, a Julia implementation of the junction tree algorithm.
One main function is exported from this module for public use:
compile_algo: compiles and returns an expression that computes the posterior marginals of the model given evidence using the junction tree algorithm.
Exports
Types
JunctionTrees.Factor (Type)
struct Factor{T, N}
Fields
vars
vals
Encodes a discrete function over the set of variables vars that maps each instantiation of vars into a nonnegative number in vals.
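The indexing convention can be illustrated with a minimal stand-in type. The sketch below assumes that vals[i, j, ...] holds the value for the instantiation vars[1] = i, vars[2] = j, and so on, with states numbered from 1. This is consistent with the prod example further down this page. SketchFactor and value are hypothetical names and are not part of JunctionTrees.jl.

```julia
# SketchFactor is a hypothetical stand-in that mirrors only the two documented
# fields of JunctionTrees.Factor; it is not the package's own definition.
struct SketchFactor{T, N}
    vars::NTuple{N, Int64}  # variable indices, one per array dimension
    vals::Array{T, N}       # assumed: vals[i, j, ...] is the value for vars[1] = i, vars[2] = j, ...
end

# Hypothetical helper: look up the value of one full instantiation,
# given as a Dict mapping each variable index to its (1-based) state.
function value(f::SketchFactor, assignment::Dict{Int64, Int64})
    idx = map(v -> assignment[v], f.vars)  # states ordered like f.vars
    return f.vals[idx...]
end

# A factor over variables 2 and 3, each with two states.
A = SketchFactor{Float64, 2}((2, 3), [0.5 0.7; 0.1 0.2])

value(A, Dict(2 => 1, 3 => 2))  # 0.7: variable 2 in state 1, variable 3 in state 2
```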
JunctionTrees.LastStage (Type)
Enumerated type that selects the stage of the junction tree algorithm up to which compile_algo builds the returned expression.
Functions
JunctionTrees.compile_algo (Function)
compile_algo(
uai_filepath::AbstractString;
uai_evid_filepath,
td_filepath,
apply_partial_evaluation,
last_stage,
smart_root_selection,
factor_eltype,
use_omeinsum,
correct_fp_overflows
) -> Expr
Return an expression of the junction tree algorithm that extracts the marginals of all the variables in the model.
Arguments
uai_filepath::AbstractString: path to the model file, defined in the UAI model file format.
uai_evid_filepath::AbstractString = "": path to the evidence file, defined in the UAI evidence file format.
td_filepath::AbstractString = "": path to a pre-constructed junction tree, defined in the PACE graph format.
apply_partial_evaluation::Bool = false: optimize the algorithm using partial evaluation.
last_stage::LastStage = Marginals: return an expression up to the given stage. The options are ForwardPass, BackwardPass, JointMarginals, UnnormalizedMarginals and Marginals.
smart_root_selection::Bool = true: select the cluster with the largest state space as the root.
factor_eltype::DataType = Float64: type used to represent the factor values.
use_omeinsum::Bool = false: use the OMEinsum tensor network contraction package as the backend for the factor operations.
correct_fp_overflows::Bool = false: normalize the messages in the propagation phase that cause a floating-point overflow.
Examples
using JunctionTrees
package_root_dir = pathof(JunctionTrees) |> dirname |> dirname
uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai")
algo = compile_algo(uai_filepath)
eval(algo)  # defines the function run_algo used below
obsvars, obsvals = Int64[], Int64[]  # no evidence
marginals = run_algo(obsvars, obsvals)
# output
6-element Vector{Factor{Float64, 1}}:
Factor{Float64, 1}((1,), [0.33480077287635474, 0.33039845424729053, 0.33480077287635474])
Factor{Float64, 1}((2,), [0.378700415763991, 0.621299584236009])
Factor{Float64, 1}((3,), [0.3632859624875086, 0.6367140375124913])
Factor{Float64, 1}((4,), [0.6200692707149191, 0.37993072928508087])
Factor{Float64, 1}((5,), [0.649200314859223, 0.350799685140777])
Factor{Float64, 1}((6,), [0.5968155611613972, 0.4031844388386027])
package_root_dir = pathof(JunctionTrees) |> dirname |> dirname
uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai")
uai_evid_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai.evid")
algo = compile_algo(
uai_filepath,
uai_evid_filepath = uai_evid_filepath)
eval(algo)
obsvars, obsvals = JunctionTrees.read_uai_evid_file(uai_evid_filepath)
marginals = run_algo(obsvars, obsvals)
# output
6-element Vector{Factor{Float64, 1}}:
Factor{Float64, 1}((1,), [1.0, 0.0, 0.0])
Factor{Float64, 1}((2,), [0.0959432982733719, 0.9040567017266281])
Factor{Float64, 1}((3,), [0.07863089300137578, 0.9213691069986242])
Factor{Float64, 1}((4,), [0.8440129077674895, 0.15598709223251056])
Factor{Float64, 1}((5,), [0.9015456486772953, 0.09845435132270475])
Factor{Float64, 1}((6,), [0.6118571666785584, 0.3881428333214415])
Base.prod (Function)
prod(A::Factor{T}, B::Factor{T}) -> Factor
Compute the factor product of tables A and B.
Examples
A = Factor{Float64,2}((2, 3), [0.5 0.7; 0.1 0.2])
B = Factor{Float64,2}((1, 2), [0.5 0.8; 0.1 0.0; 0.3 0.9])
prod(A, B)
# output
Factor{Float64, 3}((1, 2, 3), [0.25 0.08000000000000002; 0.05 0.0; 0.15 0.09000000000000001;;; 0.35 0.16000000000000003; 0.06999999999999999 0.0; 0.21 0.18000000000000002])
prod(F::Factor{T}...) -> Factor
Compute a factor product of an arbitrary number of factors.
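One way to compute a factor product is to align both value arrays on the sorted union of their variables and then multiply them with broadcasting. The sketch below takes that approach, and it is not necessarily how JunctionTrees.jl implements prod. SketchFactor, align_to and sketch_prod are hypothetical names. The result matches the two-factor example above.

```julia
# SketchFactor is a hypothetical stand-in for JunctionTrees.Factor.
struct SketchFactor{T, N}
    vars::NTuple{N, Int64}
    vals::Array{T, N}
end

# Permute and reshape f.vals so that its dimensions follow the sorted variable
# list allvars, with a singleton dimension for every variable f does not contain.
function align_to(f::SketchFactor, allvars::Vector{Int64})
    perm = sortperm(collect(f.vars))       # order f's own dimensions by variable index
    vals = permutedims(f.vals, perm)
    own = collect(f.vars)[perm]
    shape = map(allvars) do v
        d = findfirst(==(v), own)
        d === nothing ? 1 : size(vals, d)  # 1 = broadcastable singleton dimension
    end
    return reshape(vals, shape...)
end

# Pairwise product: the result ranges over the sorted union of both scopes.
function sketch_prod(A::SketchFactor{T}, B::SketchFactor{T}) where {T}
    allvars = sort(union(collect(A.vars), collect(B.vars)))
    vals = align_to(A, allvars) .* align_to(B, allvars)
    return SketchFactor{T, length(allvars)}(Tuple(allvars), vals)
end

# Product of any number of factors, folded pairwise.
sketch_prod(F::SketchFactor{T}...) where {T} = reduce(sketch_prod, F)

A = SketchFactor{Float64, 2}((2, 3), [0.5 0.7; 0.1 0.2])
B = SketchFactor{Float64, 2}((1, 2), [0.5 0.8; 0.1 0.0; 0.3 0.9])
C = sketch_prod(A, B)  # C.vars == (1, 2, 3), C.vals[i, j, k] == B.vals[i, j] * A.vals[j, k]
```

Inserting singleton dimensions through reshape does not change the column-major order of the data. That is why the broadcast lines up entries that share the same variable states.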
Base.sum (Function)
sum(
A::Factor{T, ND},
V::Tuple{Vararg{Int64, N}} where N
) -> Factor
Sum out the variables in V from factor A.
Examples
A = Factor{Float64,2}((1, 2), [0.59 0.41; 0.22 0.78])
sum(A, (2,))
# output
Factor{Float64, 1}((1,), [1.0, 1.0])
sum(A::Factor, V::Int64...) -> Factor
Sum out an arbitrary number of variables from factor A.
Examples
A = Factor{Float64,3}((1, 2, 3), cat([0.25 0.08; 0.05 0.0; 0.15 0.09],
[0.35 0.16; 0.07 0.0; 0.21 0.18], dims=3))
sum(A, 1, 2)
# output
Factor{Float64, 1}((3,), [0.6199999999999999, 0.97])
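Summing out variables amounts to adding up the value array along the matching dimensions and dropping those variables from the scope. The sketch below assumes that convention and is not the package's source. SketchFactor and sketch_sum are hypothetical names. It reproduces both examples above.

```julia
# SketchFactor is a hypothetical stand-in for JunctionTrees.Factor.
struct SketchFactor{T, N}
    vars::NTuple{N, Int64}
    vals::Array{T, N}
end

# Sum out the variables in V: add up vals along their dimensions, then drop
# those dimensions and the corresponding entries of vars.
function sketch_sum(A::SketchFactor{T}, V::Tuple{Vararg{Int64}}) where {T}
    dims = Tuple(findall(v -> v in V, collect(A.vars)))  # dimensions to marginalize
    keep = Tuple(v for v in A.vars if !(v in V))        # remaining scope
    vals = dropdims(sum(A.vals; dims = dims); dims = dims)
    return SketchFactor{T, length(keep)}(keep, vals)
end

# Varargs form: sketch_sum(A, 1, 2) is the same as sketch_sum(A, (1, 2)).
sketch_sum(A::SketchFactor, V::Int64...) = sketch_sum(A, V)

A = SketchFactor{Float64, 2}((1, 2), [0.59 0.41; 0.22 0.78])
sketch_sum(A, (2,))  # scope (1,), values approximately [1.0, 1.0]
```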
JunctionTrees.redu (Function)
redu(A::Factor{T}, vars::Tuple, vals::Tuple) -> Factor
Reduce factor A to the evidence by setting to zero all entries of A that are not consistent with the evidence passed in vars and vals, where each variable in vars is assigned the corresponding value in vals.
Examples
A = Factor{Float64,3}((1, 2, 3), cat([0.25 0.08; 0.05 0.0; 0.15 0.09],
[0.35 0.16; 0.07 0.0; 0.21 0.18], dims=3))
obs_vars = (3,)
obs_vals = (1,)
redu(A, obs_vars, obs_vals)
# output
Factor{Float64, 3}((1, 2, 3), [0.25 0.08; 0.05 0.0; 0.15 0.09;;; 0.0 0.0; 0.0 0.0; 0.0 0.0])
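Judging from the example output, the reduced factor keeps its scope and shape. It only zeroes out the entries that disagree with the evidence, and states are numbered from 1. The sketch below follows that reading. SketchFactor and sketch_redu are hypothetical names and are not part of the package.

```julia
# SketchFactor is a hypothetical stand-in for JunctionTrees.Factor.
struct SketchFactor{T, N}
    vars::NTuple{N, Int64}
    vals::Array{T, N}
end

# Zero every entry of A that is inconsistent with the evidence, where each
# variable vars[i] is observed in (1-based) state vals[i].
function sketch_redu(A::SketchFactor{T, N}, vars::Tuple, vals::Tuple) where {T, N}
    out = copy(A.vals)  # leave the input factor untouched
    for (var, val) in zip(vars, vals)
        d = findfirst(==(var), A.vars)
        d === nothing && continue  # evidence on a variable outside A's scope
        for s in axes(out, d)
            s == val && continue   # keep the slice of the observed state
            selectdim(out, d, s) .= zero(T)
        end
    end
    return SketchFactor{T, N}(A.vars, out)
end

A = SketchFactor{Float64, 3}((1, 2, 3), cat([0.25 0.08; 0.05 0.0; 0.15 0.09],
                                            [0.35 0.16; 0.07 0.0; 0.21 0.18]; dims = 3))
R = sketch_redu(A, (3,), (1,))  # R.vals[:, :, 2] is all zeros
```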
JunctionTrees.norm (Function)
norm(A::Factor{T, N}) -> Factor
Normalize the values in factor A so that they sum up to 1.
Examples
A = Factor{Float64,2}((1, 2), [0.2 0.4; 0.6 0.8])
norm(A)
# output
Factor{Float64, 2}((1, 2), [0.1 0.2; 0.3 0.4])
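Normalization divides each value by the sum of all values in the factor. Below is a minimal sketch using a hypothetical stand-in type, SketchFactor, with the hypothetical function sketch_norm. It reproduces the example above.

```julia
# SketchFactor is a hypothetical stand-in for JunctionTrees.Factor.
struct SketchFactor{T, N}
    vars::NTuple{N, Int64}
    vals::Array{T, N}
end

# Divide every entry by the total so the values sum to 1
# (an all-zero factor would produce NaN entries).
sketch_norm(A::SketchFactor{T, N}) where {T, N} = SketchFactor{T, N}(A.vars, A.vals ./ sum(A.vals))

A = SketchFactor{Float64, 2}((1, 2), [0.2 0.4; 0.6 0.8])
B = sketch_norm(A)  # B.vals ≈ [0.1 0.2; 0.3 0.4]
```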
Constants
JunctionTrees.ForwardPass (Constant)
When assigned to the keyword argument last_stage of compile_algo, an expression up to and including the forward pass is returned.
JunctionTrees.BackwardPass (Constant)
When assigned to the keyword argument last_stage of compile_algo, an expression up to and including the backward pass is returned.
JunctionTrees.JointMarginals (Constant)
When assigned to the keyword argument last_stage of compile_algo, an expression that computes the cluster joint marginals is returned.
JunctionTrees.UnnormalizedMarginals (Constant)
When assigned to the keyword argument last_stage of compile_algo, an expression that computes the unnormalized marginals is returned.
JunctionTrees.Marginals (Constant)
When assigned to the keyword argument last_stage of compile_algo, an expression that computes the posterior marginals is returned (default).
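The constants can be read as members of an ordered enumeration, where selecting a stage also covers every earlier stage. The sketch below mirrors that reading with a hypothetical SketchStage enum and a hypothetical stages_up_to helper. The member order is assumed from the order in which the options are listed for last_stage, and neither name is part of the package.

```julia
# SketchStage is a hypothetical mirror of JunctionTrees.LastStage. The member
# order assumes the stages run in the order listed in the compile_algo docs.
@enum SketchStage begin
    SketchForwardPass
    SketchBackwardPass
    SketchJointMarginals
    SketchUnnormalizedMarginals
    SketchMarginals
end

# Stages covered by an expression compiled with the given last stage,
# i.e. every stage up to and including it.
stages_up_to(last_stage::SketchStage) =
    [s for s in instances(SketchStage) if Int(s) <= Int(last_stage)]

stages_up_to(SketchBackwardPass)  # [SketchForwardPass, SketchBackwardPass]
```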