Usage
This section presents a series of examples that illustrate different uses of JunctionTrees.jl. Load the package to run the examples.
using JunctionTrees
Example 1
Calculates the marginal distribution of each variable in the input graph when no evidence is observed (hence the empty `obsvars` and `obsvals` vectors below). The input graph should be defined in the UAI model file format.
algo = compile_algo("problems/asia/asia.uai")
eval(algo)
obsvars, obsvals = Int64[], Int64[]
marginals = run_algo(obsvars, obsvals)
8-element Vector{Factor{Float64, 1}}:
Factor{Float64, 1}((1,), [0.5, 0.5000000000000001])
Factor{Float64, 1}((2,), [0.5, 0.49999999999999994])
Factor{Float64, 1}((3,), [0.5, 0.5])
Factor{Float64, 1}((4,), [0.49999999999999994, 0.5])
Factor{Float64, 1}((5,), [0.5, 0.5])
Factor{Float64, 1}((6,), [0.5, 0.5])
Factor{Float64, 1}((7,), [0.49999999999999994, 0.5])
Factor{Float64, 1}((8,), [0.5, 0.5000000000000001])
Example 2
Calculates the posterior marginal of each variable in the input graph given some evidence. The input graph should be defined in the UAI model file format. The evidence variables and values should be given in the UAI evidence file format.
algo = compile_algo(
"problems/asia/asia.uai",
uai_evid_filepath = "problems/asia/asia.uai.evid")
eval(algo)
obsvars, obsvals = JunctionTrees.read_uai_evid_file("problems/asia/asia.uai.evid")
marginals = run_algo(obsvars, obsvals)
8-element Vector{Factor{Float64, 1}}:
Factor{Float64, 1}((1,), [1.0, 0.0])
Factor{Float64, 1}((2,), [0.9999999998365611, 1.6343887218216758e-10])
Factor{Float64, 1}((3,), [0.9999926189958698, 7.381004130341263e-6])
Factor{Float64, 1}((4,), [0.9999999999455195, 5.448042829136154e-11])
Factor{Float64, 1}((5,), [0.9999999999455195, 5.4480428291361534e-11])
Factor{Float64, 1}((6,), [0.9999999998365611, 1.634388721821676e-10])
Factor{Float64, 1}((7,), [1.0, 0.0])
Factor{Float64, 1}((8,), [0.9999926189958698, 7.381004130341263e-6])
Example 3
Same as the previous example, except that a pre-constructed junction tree is used. Its file path is passed through the `td_filepath` keyword argument. This junction tree should be defined in the PACE tree decomposition (`.td`) format.
algo = compile_algo(
"problems/asia/asia.uai",
uai_evid_filepath = "problems/asia/asia.uai.evid",
td_filepath = "problems/asia/asia.td")
eval(algo)
obsvars, obsvals = JunctionTrees.read_uai_evid_file("problems/asia/asia.uai.evid")
marginals = run_algo(obsvars, obsvals)
8-element Vector{Factor{Float64, 1}}:
Factor{Float64, 1}((1,), [1.0, 0.0])
Factor{Float64, 1}((2,), [0.9999999998365611, 1.6343887218216756e-10])
Factor{Float64, 1}((3,), [0.9999926189958698, 7.381004130341263e-6])
Factor{Float64, 1}((4,), [0.9999999999455195, 5.448042829136153e-11])
Factor{Float64, 1}((5,), [0.9999999999455195, 5.4480428291361534e-11])
Factor{Float64, 1}((6,), [0.9999999998365611, 1.6343887218216758e-10])
Factor{Float64, 1}((7,), [1.0, 0.0])
Factor{Float64, 1}((8,), [0.9999926189958698, 7.381004130341263e-6])
Example 4
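Returns the unevaluated expression of the junction tree algorithm up to the backward pass stage. The `last_stage` keyword argument selects the final stage included in the generated code.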
backward_pass_expr = compile_algo( "problems/asia/asia.uai", last_stage = BackwardPass)
:(function run_algo(obsvars::Vector{Int64}, obsvals::Vector{Int64})
begin
pot_1 = Factor{Float64, 2}((1, 7), [368.08 0.0027168; 0.0027168 368.08])
pot_2 = Factor{Float64, 2}((4, 8), [368.08 0.0027168; 0.0027168 368.08])
pot_3 = Factor{Float64, 2}((3, 4), [368.08 0.0027168; 0.0027168 368.08])
pot_4 = Factor{Float64, 3}((4, 5, 7), [135482.8864 0.999999744; 0.999999744 7.381002240000001e-6;;; 7.381002240000001e-6 0.999999744; 0.999999744 135482.8864])
pot_5 = Factor{Float64, 3}((2, 4, 6), [135482.8864 0.999999744; 7.381002240000001e-6 0.999999744;;; 0.999999744 7.381002240000001e-6; 0.999999744 135482.8864])
pot_6 = Factor{Float64, 3}((4, 5, 6), [135482.8864 7.381002240000001e-6; 0.999999744 0.999999744;;; 0.999999744 0.999999744; 7.381002240000001e-6 135482.8864])
end
begin
msg_1_4 = sum(pot_1, 1)
msg_2_4 = sum(pot_2, 8)
msg_3_5 = sum(pot_3, 3)
msg_5_6 = sum(prod(msg_3_5, pot_5), 2)
msg_6_4 = sum(prod(msg_5_6, pot_6), 6)
end
begin
msg_4_1 = sum(prod(msg_2_4, msg_6_4, pot_4), 4, 5)
msg_4_2 = sum(prod(msg_1_4, msg_6_4, pot_4), 5, 7)
msg_4_6 = sum(prod(msg_1_4, msg_2_4, pot_4), 7)
msg_6_5 = sum(prod(msg_4_6, pot_6), 5)
msg_5_3 = sum(prod(msg_6_5, pot_5), 2, 6)
end
end)
The supported stages (valid values for the `last_stage` keyword argument) are:
instances(LastStage)
(ForwardPass, BackwardPass, JointMarginals, UnnormalizedMarginals, Marginals)