Class HybridCgModel.HybridCgPm

java.lang.Object
edu.cmu.tetrad.hybridcg.HybridCgModel.HybridCgPm
All Implemented Interfaces:
Pm, Serializable
Enclosing class:
HybridCgModel

public static final class HybridCgModel.HybridCgPm extends Object implements Pm, Serializable
The HybridCgPm class describes the structure of a hybrid Bayesian network, which may contain both discrete and continuous variables. It provides methods for accessing the graph and node order, querying each node's discrete and continuous parents, and discretizing continuous parent values using cutpoints. It also computes the dimensions and row indices of the local tables used for conditional probability computations over mixed data types.
  • Constructor Summary

    Constructors
    Constructor
    Description
    HybridCgPm(Graph dag, List<Node> nodeOrder, Map<Node,Boolean> discreteFlags, Map<Node,List<String>> categoryMap)
    Constructs a HybridCgPm instance based on the provided directed acyclic graph (DAG), node ordering, discrete flags for nodes, and a mapping of node categories.
  • Method Summary

    Modifier and Type
    Method
    Description
    void
    autoCutpointsForDiscreteChild(Node child, DataSet data, int binsPerParent)
    Populates cutpoints for each continuous parent of a DISCRETE child using equal-frequency binning.
    int
    discretizeFor(Node child, Node contParent, double value)
    Discretizes the given value of a continuous parent node for a specific discrete child node into a bin index based on predefined cutpoints.
    int
    getCardinality(int nodeIndex)
    Retrieves the cardinality (number of categories) for a discrete node at a given index.
    List<String>
    getCategories(int nodeIndex)
    Retrieves the list of categories for a discrete node at a given index.
    int[]
    getContinuousParents(int nodeIndex)
    Retrieves the continuous parents of a node at a given index.
    Optional<double[][]>
    getContParentCutpointsForDiscreteChild(int nodeIndex)
    Retrieves the cutpoints for continuous parents of a discrete child node.
    int[]
    getDiscreteParents(int nodeIndex)
    Retrieves the discrete parents of a node at a given index.
    Graph
    getGraph()
    Retrieves the directed acyclic graph (DAG) associated with this model.
    Node[]
    getNodes()
    Retrieves the ordered array of nodes in this probabilistic model.
    int
    getNumRows(int nodeIndex)
    Retrieves the number of rows in the local table for a given node index.
    int[]
    getParents(int y)
    Retrieves the parents of a node in this hybrid model.
    int[]
    getRowDims(int nodeIndex)
    Computes and returns the row dimensions for a specific node in the network based on its discrete and continuous parents.
    int
    getRowIndex(int nodeIndex, int[] discVals, int[] contBinVals)
    Computes the row index for a given node in the model based on the state indices of its discrete and continuous parents.
    int
    indexOf(Node v)
    Retrieves the index of a given node within the ordered list of nodes.
    boolean
    isDiscrete(int nodeIndex)
    Checks if a node at a given index is discrete.
    int
    rowIndexForCase(int nodeIndex, double[] contParentValues)
    Convenience overload when the child has NO discrete parents.
    int
    rowIndexForCase(int nodeIndex, int[] discParentStates, double[] contParentValues)
    Computes the local-table row index for a single data case.
    int
    rowIndexForCase(int nodeIndex, DataSet data, int row)
    Convenience overload that reads parent states/values from a DataSet row.
    void
    setContParentCutpointsForDiscreteChild(Node child, Map<Node,double[]> cutpointsByContParent)
    Sets the continuous parent cutpoints for a specified discrete child node.

    Methods inherited from class java.lang.Object

    equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Constructor Details

    • HybridCgPm

      public HybridCgPm(Graph dag, List<Node> nodeOrder, Map<Node,Boolean> discreteFlags, Map<Node,List<String>> categoryMap)
      Constructs a HybridCgPm instance based on the provided directed acyclic graph (DAG), node ordering, discrete flags for nodes, and a mapping of node categories.

      A HybridCgPm describes the structure of a model over a mix of discrete and continuous variables. The constructor builds that structure from the supplied DAG, orders the nodes as given by nodeOrder, classifies each node as discrete or continuous according to discreteFlags, and records the categories of each discrete node from categoryMap.

      Parameters:
      dag - the directed acyclic graph representing dependencies; must not be null
      nodeOrder - the ordered list of nodes, defining the sequence used; must not be null
      discreteFlags - a map indicating whether each node is discrete (true) or continuous (false)
      categoryMap - a map providing a list of category strings for discrete nodes
  • Method Details

    • getGraph

      public Graph getGraph()
      Retrieves the directed acyclic graph (DAG) associated with this model.
      Returns:
      the DAG representing the structure of this probabilistic model
    • getNodes

      public Node[] getNodes()
      Retrieves the ordered array of nodes in this probabilistic model.
      Returns:
      the nodes, in model order
    • indexOf

      public int indexOf(Node v)
      Retrieves the index of a given node within the ordered list of nodes.
      Parameters:
      v - the node to find the index for
      Returns:
      the index of the node, or -1 if not found
    • isDiscrete

      public boolean isDiscrete(int nodeIndex)
      Checks if a node at a given index is discrete.
      Parameters:
      nodeIndex - the index of the node to check
      Returns:
      true if the node is discrete, false otherwise
    • getCategories

      public List<String> getCategories(int nodeIndex)
      Retrieves the list of categories for a discrete node at a given index.
      Parameters:
      nodeIndex - the index of the discrete node
      Returns:
      the list of categories for the node
    • getCardinality

      public int getCardinality(int nodeIndex)
      Retrieves the cardinality (number of categories) for a discrete node at a given index.
      Parameters:
      nodeIndex - the index of the discrete node
      Returns:
      the number of categories for the node
    • getDiscreteParents

      public int[] getDiscreteParents(int nodeIndex)
      Retrieves the discrete parents of a node at a given index.
      Parameters:
      nodeIndex - the index of the node
      Returns:
      an array of indices representing discrete parents
    • getContinuousParents

      public int[] getContinuousParents(int nodeIndex)
      Retrieves the continuous parents of a node at a given index.
      Parameters:
      nodeIndex - the index of the node
      Returns:
      an array of indices representing continuous parents
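      As a rough illustration of how a node's parents might be split into the two groups returned by getDiscreteParents and getContinuousParents, the sketch below partitions parent indices by a per-node discreteness flag. The class ParentSplitSketch and its array arguments are hypothetical stand-ins, not Tetrad's internals; it assumes parents are given as indices into the PM's node order and that the original parent order is preserved within each group.

```java
/** Hypothetical sketch: split a node's parent indices into discrete and continuous groups. */
final class ParentSplitSketch {

    /** Returns {discreteParents, continuousParents}, preserving the original parent order. */
    static int[][] split(int[] parents, boolean[] isDiscrete) {
        int nDisc = 0;
        for (int p : parents) {
            if (isDiscrete[p]) nDisc++;
        }
        int[] disc = new int[nDisc];
        int[] cont = new int[parents.length - nDisc];
        int d = 0, c = 0;
        for (int p : parents) {
            if (isDiscrete[p]) {
                disc[d++] = p;   // discrete parent
            } else {
                cont[c++] = p;   // continuous parent
            }
        }
        return new int[][]{disc, cont};
    }
}
```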
    • setContParentCutpointsForDiscreteChild

      public void setContParentCutpointsForDiscreteChild(Node child, Map<Node,double[]> cutpointsByContParent)
      Sets the continuous parent cutpoints for a specified discrete child node. The method checks that the child is discrete, that cutpoints are supplied for every continuous parent of the child, and that each cutpoint array is strictly increasing.
      Parameters:
      child - the discrete child node for which the cutpoints will be set
      cutpointsByContParent - a map from each continuous parent node to its array of cutpoints; each array must be strictly increasing
      Throws:
      IllegalArgumentException - if the specified child is not discrete, if cutpoints are missing for any continuous parent, or if any cutpoint array is not strictly increasing
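      The checks listed above can be sketched as follows. This is a hypothetical stand-in, not Tetrad's code: CutpointCheckSketch is an illustrative name, and String parent names replace Node objects so the example runs without Tetrad on the classpath.

```java
/** Hypothetical sketch of the validation described for setContParentCutpointsForDiscreteChild. */
final class CutpointCheckSketch {

    /**
     * Throws IllegalArgumentException unless the child is discrete, every continuous
     * parent has a cutpoint array, and each array is strictly increasing.
     */
    static void validate(boolean childIsDiscrete, String[] contParents,
                         java.util.Map<String, double[]> cutpointsByContParent) {
        if (!childIsDiscrete) {
            throw new IllegalArgumentException("Child must be discrete.");
        }
        for (String parent : contParents) {
            double[] cuts = cutpointsByContParent.get(parent);
            if (cuts == null) {
                throw new IllegalArgumentException("Missing cutpoints for " + parent);
            }
            for (int i = 1; i < cuts.length; i++) {
                // Rejects ties and decreasing pairs.
                if (!(cuts[i] > cuts[i - 1])) {
                    throw new IllegalArgumentException(
                            "Cutpoints for " + parent + " are not strictly increasing.");
                }
            }
        }
    }
}
```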
    • getContParentCutpointsForDiscreteChild

      public Optional<double[][]> getContParentCutpointsForDiscreteChild(int nodeIndex)
      Retrieves the cutpoints for continuous parents of a discrete child node.
      Parameters:
      nodeIndex - the index of the discrete child node
      Returns:
      an Optional containing the cutpoints array, or empty if not set
    • getNumRows

      public int getNumRows(int nodeIndex)
      Retrieves the number of rows in the local table for a given node index.
      Parameters:
      nodeIndex - the index of the node
      Returns:
      the number of rows in the local table
    • getRowDims

      public int[] getRowDims(int nodeIndex)
      Computes the row dimensions of the local table for the given node. The dimensions are derived from the cardinalities of its discrete parents and, for a discrete child only, the number of bins of each continuous parent.
      Parameters:
      nodeIndex - the index of the node for which the row dimensions are to be computed
      Returns:
      an array of integers representing the row dimensions for the given node
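      One plausible way to derive the row dimensions and the row count is sketched below. It assumes discrete-parent cardinalities come first, followed (for a discrete child only) by one dimension of cutpoints.length + 1 bins per continuous parent, and that the number of rows is the product of the dimensions. RowDimsSketch and its parameters are hypothetical; Tetrad's actual ordering is not stated in this documentation.

```java
/** Hypothetical sketch of deriving row dimensions and the local-table row count. */
final class RowDimsSketch {

    /**
     * Assumed layout: one dimension per discrete parent (its cardinality), followed,
     * for a discrete child only, by one dimension per continuous parent.
     */
    static int[] rowDims(int[] discParentCards, double[][] contParentCutpoints, boolean childIsDiscrete) {
        int nCont = childIsDiscrete ? contParentCutpoints.length : 0;
        int[] dims = new int[discParentCards.length + nCont];
        System.arraycopy(discParentCards, 0, dims, 0, discParentCards.length);
        for (int j = 0; j < nCont; j++) {
            // k cutpoints split the real line into k + 1 bins.
            dims[discParentCards.length + j] = contParentCutpoints[j].length + 1;
        }
        return dims;
    }

    /** Product of the dimensions; a node with no row dimensions has exactly one row. */
    static int numRows(int[] dims) {
        int rows = 1;
        for (int d : dims) {
            rows = Math.multiplyExact(rows, d);   // fail loudly on overflow
        }
        return rows;
    }
}
```

      For example, a discrete child with discrete parents of cardinality 2 and 3 and one continuous parent with two cutpoints would have dimensions {2, 3, 3} and 18 rows under these assumptions.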
    • getRowIndex

      public int getRowIndex(int nodeIndex, int[] discVals, int[] contBinVals)
      Computes the row index for a given node from the state indices of its discrete parents and the bin indices of its continuous parents, using the dimensions reported by getRowDims(nodeIndex).
      Parameters:
      nodeIndex - the index of the node whose row index is to be computed
      discVals - an array of indices representing the states of the discrete parents of the node; the length must match the number of discrete parents
      contBinVals - an array of indices representing the binned states of the continuous parents of the node; required only for discrete nodes with continuous parents, and its length must match the number of continuous parents for the node
      Returns:
      the computed row index, in the range [0, getNumRows(nodeIndex) - 1]
      Throws:
      IllegalArgumentException - if the lengths of discVals or contBinVals do not match the required dimensions
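      A row index of this kind is typically a mixed-radix number whose digits are the parent states. The sketch below shows that encoding under the assumption that discrete-parent states come first and the last dimension varies fastest; RowIndexSketch is a hypothetical name, and any consistent digit order works as long as the local tables are built with the same one.

```java
/** Hypothetical sketch: mixed-radix row index, last dimension varying fastest. */
final class RowIndexSketch {

    static int rowIndex(int[] dims, int[] discVals, int[] contBinVals) {
        if (discVals.length + contBinVals.length != dims.length) {
            throw new IllegalArgumentException("State arrays do not match the row dimensions.");
        }
        // Concatenate discrete states and continuous bin indices into one digit vector.
        int[] states = new int[dims.length];
        System.arraycopy(discVals, 0, states, 0, discVals.length);
        System.arraycopy(contBinVals, 0, states, discVals.length, contBinVals.length);

        int row = 0;
        for (int i = 0; i < dims.length; i++) {
            if (states[i] < 0 || states[i] >= dims[i]) {
                throw new IllegalArgumentException("State " + states[i] + " out of range for dimension " + i);
            }
            row = row * dims[i] + states[i];   // Horner-style accumulation
        }
        return row;
    }
}
```

      With dims {2, 3, 3}, discrete states {1, 2} and bin {0}, the sketch gives ((1 * 3) + 2) * 3 + 0 = 15; the largest state combination maps to 17, the last of 18 rows.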
    • discretizeFor

      public int discretizeFor(Node child, Node contParent, double value)
      Discretizes the given value of a continuous parent node for a specific discrete child node into a bin index based on predefined cutpoints.
      Parameters:
      child - the discrete child node whose continuous parent's value will be discretized
      contParent - the continuous parent node associated with the child node
      value - the value of the continuous parent node to discretize
      Returns:
      the bin index, in the range [0, k] where k is the number of cutpoints for that parent (k cutpoints define k + 1 bins)
      Throws:
      IllegalArgumentException - if the specified parent node is not a continuous parent of the child node
      IllegalStateException - if the cutpoints for the child node have not been set
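      Binning against strictly increasing cutpoints can be done with a binary search. In the sketch below, a value below the first cutpoint maps to bin 0 and a value at or above the last maps to bin k, matching the documented return range. It assumes a value exactly equal to a cutpoint goes to the upper bin; Tetrad's boundary convention is not stated here, and BinSketch is a hypothetical name.

```java
/** Hypothetical sketch: map a continuous value to a bin index using sorted cutpoints. */
final class BinSketch {

    static int bin(double[] cutpoints, double value) {
        if (Double.isNaN(value)) {
            throw new IllegalArgumentException("Cannot bin NaN.");
        }
        int pos = java.util.Arrays.binarySearch(cutpoints, value);
        // Exact hit on cutpoint i -> upper bin i + 1; otherwise the insertion point
        // equals the number of cutpoints below the value, which is the bin index.
        return pos >= 0 ? pos + 1 : -pos - 1;
    }
}
```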
    • autoCutpointsForDiscreteChild

      public void autoCutpointsForDiscreteChild(Node child, DataSet data, int binsPerParent)
      Populates cutpoints for each continuous parent of a DISCRETE child using equal-frequency binning.
      Parameters:
      child - the discrete child node
      data - a data set containing all model variables, matched by name
      binsPerParent - number of bins to use for each continuous parent (2 or more recommended)
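      Equal-frequency binning places cutpoints at empirical quantiles so that each bin holds roughly the same number of cases. The sketch below takes the sorted value at position floor(b * n / bins) for b = 1..bins-1; this quantile rule is an assumption, and EqualFrequencySketch is a hypothetical name rather than Tetrad's implementation. Duplicate cutpoints caused by tied values are dropped so the result stays strictly increasing, as setContParentCutpointsForDiscreteChild requires.

```java
/** Hypothetical sketch of equal-frequency (quantile) cutpoints for one continuous parent. */
final class EqualFrequencySketch {

    /**
     * Returns at most bins - 1 strictly increasing cutpoints. Tied values can produce
     * duplicate quantiles, which are dropped, so fewer bins than requested may result.
     */
    static double[] cutpoints(double[] values, int bins) {
        if (bins < 2) throw new IllegalArgumentException("Need at least 2 bins.");
        if (values.length == 0) throw new IllegalArgumentException("No data.");
        double[] sorted = values.clone();
        java.util.Arrays.sort(sorted);
        int n = sorted.length;
        double[] cuts = new double[bins - 1];
        int m = 0;
        for (int b = 1; b < bins; b++) {
            int idx = Math.min(n - 1, (int) ((long) b * n / bins));
            double c = sorted[idx];
            if (m == 0 || c > cuts[m - 1]) {
                cuts[m++] = c;   // keep the array strictly increasing
            }
        }
        return java.util.Arrays.copyOf(cuts, m);
    }
}
```

      For the values 1 through 8 and four bins, the sketch yields cutpoints {3, 5, 7}, which (with values equal to a cutpoint going to the upper bin) put two values in each bin.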
    • rowIndexForCase

      public int rowIndexForCase(int nodeIndex, int[] discParentStates, double[] contParentValues)
      Computes the local-table row index for a single data case.

      Rules:

      • Discrete parents contribute their category indices directly.
      • If the child is DISCRETE and has continuous parents, each continuous parent is discretized using the stored cutpoints to a bin index (0..bins-1), and those bin indices extend the row index.
      • If the child is CONTINUOUS, only discrete parents contribute (continuous parents do not add dimensions for continuous children).
      Parameters:
      nodeIndex - index of the child (in this PM's node order)
      discParentStates - length must equal getDiscreteParents(nodeIndex).length; each entry in [0, card-1]
      contParentValues - raw values for the child's continuous parents; for a DISCRETE child length must equal getContinuousParents(nodeIndex).length; ignored for CONTINUOUS child
      Returns:
      row index in [0, getNumRows(nodeIndex)-1]
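      The rules above can be combined into one sketch: discrete parents contribute their states, and for a discrete child each continuous value is binned and appended as a further digit. CaseRowSketch is a hypothetical name; it assumes the same mixed-radix layout (discrete parents first, last digit varying fastest) and upper-bin boundary convention as an illustration, not as Tetrad's confirmed behavior.

```java
/** Hypothetical sketch of rowIndexForCase: index discrete states, then binned continuous values. */
final class CaseRowSketch {

    static int rowIndexForCase(boolean childIsDiscrete, int[] discCards, int[] discParentStates,
                               double[][] contCutpoints, double[] contParentValues) {
        if (discParentStates.length != discCards.length) {
            throw new IllegalArgumentException("Wrong number of discrete parent states.");
        }
        int row = 0;
        // Discrete parents contribute their category indices directly.
        for (int i = 0; i < discCards.length; i++) {
            int s = discParentStates[i];
            if (s < 0 || s >= discCards[i]) throw new IllegalArgumentException("State out of range.");
            row = row * discCards[i] + s;
        }
        if (!childIsDiscrete) {
            return row;   // continuous parents add no dimensions for a continuous child
        }
        if (contParentValues.length != contCutpoints.length) {
            throw new IllegalArgumentException("Wrong number of continuous parent values.");
        }
        // Each continuous parent is binned with its cutpoints and extends the index.
        for (int j = 0; j < contCutpoints.length; j++) {
            int pos = java.util.Arrays.binarySearch(contCutpoints[j], contParentValues[j]);
            int bin = pos >= 0 ? pos + 1 : -pos - 1;
            row = row * (contCutpoints[j].length + 1) + bin;
        }
        return row;
    }
}
```

      Passing empty discrete arrays corresponds to the convenience overload for a child with no discrete parents: with cutpoints {0.0, 1.0} and a value of 2.5, the sketch returns bin 2 as the row index.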
    • rowIndexForCase

      public int rowIndexForCase(int nodeIndex, double[] contParentValues)
      Convenience overload when the child has NO discrete parents. Useful for tests like: pm.rowIndexForCase(yIdx, new double[]{ xVal }).
      Parameters:
      nodeIndex - child index
      contParentValues - raw values for the child's continuous parents
      Returns:
      the row index in [0, getNumRows(nodeIndex) - 1]
    • rowIndexForCase

      public int rowIndexForCase(int nodeIndex, DataSet data, int row)
      Convenience overload that reads parent states and values from a DataSet row. The child's discrete-parent states are taken from integer columns; continuous-parent values are taken from double columns and binned (if the child is discrete).
      Parameters:
      nodeIndex - child index
      data - dataset containing all variables (by name)
      row - row index into the dataset
      Returns:
      the local-table row index in [0, getNumRows(nodeIndex) - 1] for that data row
    • getParents

      public int[] getParents(int y)
      Retrieves the parents of a node in this hybrid model.
      Parameters:
      y - the index of the node
      Returns:
      an array of indices representing the parents of the node