Class MlBayesIm

java.lang.Object
edu.cmu.tetrad.bayes.MlBayesIm
All Implemented Interfaces:
BayesIm, Simulator, VariableSource, Im, TetradSerializable, Serializable

public final class MlBayesIm extends Object implements BayesIm
Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate this table. The division of labor is as follows. The Dag is responsible for manipulating the basic graphical structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there is no method in either BayesPm or BayesIm to do this. BayesPm stores and manipulates the categories of each node in a DAG, considered as a variable in a Bayes net. The number of categories for a variable can be changed there, as can the names of those categories. This class, an implementation of BayesIm, stores the actual probability tables implied by the structures in the other two classes.

The implied parameters take the form of conditional probabilities, e.g., P(N=v0 | P1=v1, P2=v2, ...), for all nodes and all combinations of their parents' categories. The set of all such probabilities is organized in this class as a three-dimensional table of double values. The first dimension corresponds to the nodes in the Bayes net. For each such node, the second dimension corresponds to a flat list of combinations of parent categories for that node. The third dimension corresponds to the list of categories for that node itself.

Two methods allow these values to be retrieved and set: getProbability(int nodeIndex, int rowIndex, int colIndex) and setProbability(int nodeIndex, int rowIndex, int colIndex, double value). To determine the index of the node in question, use getNodeIndex(Node node). To determine the index of the row in question, use getRowIndex(int nodeIndex, int[] values). To determine the order of the parent values for a given node, so that you can build the values array, use getParents(int nodeIndex). To determine the index of a category, use the method getCategoryIndex(Node node) in BayesPm. The remaining methods in this class are variants of the methods above.
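As an illustration of the flat row list described above, the following standalone sketch maps a combination of parent values to a row index. It assumes that rows enumerate parent-value combinations in mixed-radix order, with the last parent varying fastest; that ordering, and the class CptLayoutSketch with its methods, are hypothetical and not taken from the Tetrad source. In MlBayesIm itself, use getRowIndex(int nodeIndex, int[] values).

```java
import java.util.Arrays;

/** Hypothetical, standalone illustration of a flat CPT row index; not Tetrad code. */
final class CptLayoutSketch {

    /**
     * Maps a combination of parent values to a row index, treating the values as
     * digits of a mixed-radix number whose radices are the parents' category counts
     * (last parent varies fastest, an assumption of this sketch).
     */
    static int rowIndex(int[] parentDims, int[] parentValues) {
        if (parentDims.length != parentValues.length) {
            throw new IllegalArgumentException("One value is needed per parent.");
        }
        int row = 0;
        for (int i = 0; i < parentDims.length; i++) {
            if (parentValues[i] < 0 || parentValues[i] >= parentDims[i]) {
                throw new IllegalArgumentException("Parent value out of range at position " + i);
            }
            row = row * parentDims[i] + parentValues[i];
        }
        return row;
    }

    /** Number of rows = product of the parents' category counts (1 for a node with no parents). */
    static int numRows(int[] parentDims) {
        int rows = 1;
        for (int dim : parentDims) {
            rows *= dim;
        }
        return rows;
    }

    public static void main(String[] args) {
        int[] parentDims = {2, 3};   // parent P1 has 2 categories, P2 has 3
        int[] parentValues = {1, 2}; // P1 = v1, P2 = v2
        System.out.println(Arrays.toString(parentValues) + " -> row "
                + rowIndex(parentDims, parentValues) + " of " + numRows(parentDims));
        // prints "[1, 2] -> row 5 of 6"
    }
}
```

A node with no parents has exactly one row (index 0), which the empty product in numRows reflects.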

This version uses a sparse method for storing the probabilities, where NaNs are not stored. This allows BayesPms with many categories per variable to be estimated from small samples without overflowing memory. The old method of storing probabilities is kept here for backward compatibility, with an internal code flag to indicate which should be used.
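The sparse idea in the previous paragraph can be sketched as a map that holds only the cells that have been set, with absent cells reading back as Double.NaN. SparseCptSketch is a hypothetical class written for illustration; it is not Tetrad's CptMap, whose internals may differ.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sparse CPT for one node: only non-NaN cells are stored. Not Tetrad's CptMap. */
final class SparseCptSketch {
    private final int numRows;
    private final int numColumns;
    private final Map<Long, Double> cells = new HashMap<>();

    SparseCptSketch(int numRows, int numColumns) {
        this.numRows = numRows;
        this.numColumns = numColumns;
    }

    /** Flattens (row, col) into a single long key after a bounds check. */
    private long key(int row, int col) {
        if (row < 0 || row >= numRows || col < 0 || col >= numColumns) {
            throw new IndexOutOfBoundsException("Cell (" + row + ", " + col + ") is out of range.");
        }
        return (long) row * numColumns + col;
    }

    /** Returns the stored value, or Double.NaN if the cell was never set. */
    double get(int row, int col) {
        return cells.getOrDefault(key(row, col), Double.NaN);
    }

    /** Stores a value; setting NaN removes the cell, so unknown cells cost no memory. */
    void set(int row, int col, double value) {
        long k = key(row, col);
        if (Double.isNaN(value)) {
            cells.remove(k);
        } else {
            cells.put(k, value);
        }
    }

    int storedCells() {
        return cells.size();
    }

    public static void main(String[] args) {
        // 1,000,000 rows x 10 columns would need 80 MB as a dense double array.
        SparseCptSketch cpt = new SparseCptSketch(1_000_000, 10);
        cpt.set(42, 3, 0.25);
        System.out.println(cpt.get(42, 3) + " " + cpt.get(7, 0) + " " + cpt.storedCells());
        // prints "0.25 NaN 1"
    }
}
```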

Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions.

Author:
josephramsey
  • Nested Class Summary

    Nested Classes
    Modifier and Type
    Class
    Description
    static enum 
    MlBayesIm.CptMapType
    An enumeration representing the different types of CptMap.
    static enum 
    MlBayesIm.InitializationMethod
    An enumeration of the methods for initializing the values of an MlBayesIm (MANUAL or RANDOM).
  • Field Summary

    Fields
    Modifier and Type
    Field
    Description
    static final int
    RANDOM
    Represents a constant value for a random number.
  • Constructor Summary

    Constructors
    Constructor
    Description
    MlBayesIm(BayesIm bayesIm)
    Copy constructor.
    MlBayesIm(BayesPm bayesPm)
    Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?").
    MlBayesIm(BayesPm bayesPm, boolean countsOnly)
    Constructs an instance of MlBayesIm.
    MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, MlBayesIm.InitializationMethod initializationMethod)
    Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values from the old BayesIm provided where possible.
    MlBayesIm(BayesPm bayesPm, MlBayesIm.InitializationMethod initializationMethod)
    Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM.
  • Method Summary

    Modifier and Type
    Method
    Description
    void
    clearRow(int nodeIndex, int rowIndex)
    Clears all values in the specified row of a table.
    void
    clearTable(int nodeIndex)
    Clears the table by clearing all rows for the given node.
    boolean
    equals(Object o)
    Determines whether the specified object is equal to this Bayes net.
    BayesPm
    getBayesPm()
    Returns the BayesPm of this IM.
    int
    getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm)
    Returns the corresponding node index in the given BayesIm based on the node index in this BayesIm.
    MlBayesIm.CptMapType
    getCptMapType()
    Returns the CptMapType, which indicates how this instance stores its probability tables.
    Graph
    getDag()
    Returns the DAG.
    List<Node>
    getMeasuredNodes()
    Returns the list of measured nodes.
    Node
    getNode(int nodeIndex)
    Retrieves the node at the specified index.
    Node
    getNode(String name)
    Returns the node with the given name.
    int
    getNodeIndex(Node node)
    Returns the index of the given node in the nodes array.
    int
    getNumColumns(int nodeIndex)
    Returns the number of columns in the table for the specified node.
    int
    getNumNodes()
    Returns the number of nodes in the model.
    int
    getNumParents(int nodeIndex)
    Returns the number of parents for the given node.
    int
    getNumRows(int nodeIndex)
    Returns the number of rows in the table for the specified node.
    static List<String>
    getParameterNames()
    Returns the list of parameter names.
    int
    getParent(int nodeIndex, int parentIndex)
    Retrieves the parent of a node at the specified index.
    int
    getParentDim(int nodeIndex, int parentIndex)
    Retrieves the value of the parent dimension for a given node and parent index.
    int[]
    getParentDims(int nodeIndex)
    Returns a copy of the parent dimensions for the specified node.
    int[]
    getParents(int nodeIndex)
    Returns an array containing the parents of the specified node.
    int
    getParentValue(int nodeIndex, int rowIndex, int colIndex)
    Retrieves the value of the parent node at the specified row and column index.
    int[]
    getParentValues(int nodeIndex, int rowIndex)
    Returns an integer array containing the parent values for a given node index and row index.
    double
    getProbability(int nodeIndex, int rowIndex, int colIndex)
    Returns the probability stored in the given cell of the table for the given node.
    int
    getRowIndex(int nodeIndex, int[] values)
    Returns the row index corresponding to the given node index and combination of parent values.
    List<String>
    getVariableNames()
    Returns the names of the variables.
    List<Node>
    getVariables()
    Returns the variables of the model.
    boolean
    isIncomplete(int nodeIndex)
    Checks if the specified table has any incomplete rows.
    boolean
    isIncomplete(int nodeIndex, int rowIndex)
    Checks if the specified row of a table is incomplete, i.e., if any of the columns have a NaN value.
    void
    normalizeAll()
    Normalizes all rows in the tables associated with each node in turn.
    void
    normalizeNode(int nodeIndex)
    Normalizes the specified node by invoking the normalizeRow(int, int) method on each row of the node.
    void
    normalizeRow(int nodeIndex, int rowIndex)
    Normalizes the probabilities of a given row in a node.
    void
    randomizeIncompleteRows(int nodeIndex)
    Randomizes the incomplete rows in the specified node's table.
    void
    randomizeRow(int nodeIndex, int rowIndex)
    Randomizes the values of a row in a table for a given node.
    void
    randomizeTable(int nodeIndex)
    Randomizes the table for a given node.
    static MlBayesIm
    serializableInstance()
    Generates a simple exemplar of this class to test serialization.
    void
    setCountMap(int nodeIndex, CptMapCounts countMap)
    Sets the count map for a specific node index in the Bayesian network.
    void
    setProbability(int nodeIndex, double[][] probMatrix)
    Sets all probabilities in the table for the given node from a matrix.
    void
    setProbability(int nodeIndex, int rowIndex, int colIndex, double value)
    Sets the probability value for a specific node, row, and column in the probability table.
    DataSet
    simulateData(int sampleSize, boolean latentDataSaved)
    Simulates a data set.
    DataSet
    simulateData(int sampleSize, boolean latentDataSaved, int[] tiers)
    Simulates a sample with the given sample size.
    DataSet
    simulateData(DataSet dataSet, boolean latentDataSaved)
    Simulates data for the given data set.
    DataSet
    simulateData(DataSet dataSet, boolean latentDataSaved, int[] tiers)
    Simulates data for the given data set, using the given tiers.
    String
    toString()
    Returns the probability table for each variable as a string.

    Methods inherited from class java.lang.Object

    getClass, hashCode, notify, notifyAll, wait, wait, wait
  • Field Details

    • RANDOM

      public static final int RANDOM
      Represents a constant value for a random number. The value of this constant is 1.
  • Constructor Details

    • MlBayesIm

      public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException
      Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?").
      Parameters:
      bayesPm - the given Bayes PM. Carries with it the underlying graph model.
      Throws:
      IllegalArgumentException - if the array of nodes provided is not a permutation of the nodes contained in the bayes parametric model provided.
    • MlBayesIm

      public MlBayesIm(BayesPm bayesPm, MlBayesIm.InitializationMethod initializationMethod) throws IllegalArgumentException
      Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM. If initialized manually, all values in each row are set to Double.NaN ("?"); if initialized randomly, the values in each row are distributed randomly.
      Parameters:
      bayesPm - the given Bayes PM. Carries with it the underlying graph model.
      initializationMethod - either MANUAL or RANDOM.
      Throws:
      IllegalArgumentException - if the array of nodes provided is not a permutation of the nodes contained in the bayes parametric model provided.
    • MlBayesIm

      public MlBayesIm(BayesPm bayesPm, boolean countsOnly)
      Constructs an instance of MlBayesIm.
      Parameters:
      bayesPm - the BayesPm object that represents the Bayesian network.
      countsOnly - must be true; passing false throws an IllegalArgumentException.
      Throws:
      IllegalArgumentException - if countsOnly is false.
      NullPointerException - if bayesPm is null.
    • MlBayesIm

      public MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, MlBayesIm.InitializationMethod initializationMethod) throws IllegalArgumentException
      Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values from the old BayesIm provided where possible. If initialized manually, all values that cannot be retrieved from oldBayesIm are set to Double.NaN ("?") in each such row; if initialized randomly, all values that cannot be retrieved from oldBayesIm are distributed randomly in each such row.
      Parameters:
      bayesPm - the given Bayes PM. Carries with it the underlying graph model.
      oldBayesIm - an already-constructed BayesIm whose values may be used where possible to initialize this BayesIm. May be null.
      initializationMethod - either MANUAL or RANDOM.
      Throws:
      IllegalArgumentException - if the array of nodes provided is not a permutation of the nodes contained in the bayes parametric model provided.
    • MlBayesIm

      public MlBayesIm(BayesIm bayesIm) throws IllegalArgumentException
      Copy constructor.
      Parameters:
      bayesIm - a BayesIm object
      Throws:
      IllegalArgumentException - if any.
  • Method Details

    • serializableInstance

      public static MlBayesIm serializableInstance()
      Generates a simple exemplar of this class to test serialization.
      Returns:
      an MlBayesIm object
    • getParameterNames

      public static List<String> getParameterNames()

      Returns the list of parameter names.

      Returns:
      a List of parameter names
    • getBayesPm

      public BayesPm getBayesPm()

      Returns the BayesPm of this IM.

      Specified by:
      getBayesPm in interface BayesIm
      Returns:
      the BayesPm.
    • getDag

      public Graph getDag()

      Returns the DAG.

      Specified by:
      getDag in interface BayesIm
      Returns:
      the DAG.
    • getNumNodes

      public int getNumNodes()

      Returns the number of nodes in the model.

      Specified by:
      getNumNodes in interface BayesIm
      Returns:
      the number of nodes in the model.
    • getNode

      public Node getNode(int nodeIndex)
      Retrieves the node at the specified index.
      Specified by:
      getNode in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      Returns:
      the node at the specified index.
    • getNode

      public Node getNode(String name)

      Returns the node with the given name.

      Specified by:
      getNode in interface BayesIm
      Parameters:
      name - the name of the node.
      Returns:
      the node with the given name.
    • getNodeIndex

      public int getNodeIndex(Node node)
      Returns the index of the given node in the nodes array.
      Specified by:
      getNodeIndex in interface BayesIm
      Parameters:
      node - the given node.
      Returns:
      the index of the node in the nodes array, or -1 if the node is not found.
    • getVariables

      public List<Node> getVariables()

      Returns the variables of the model.

      Specified by:
      getVariables in interface BayesIm
      Specified by:
      getVariables in interface VariableSource
      Returns:
      a List of the variables
    • getMeasuredNodes

      public List<Node> getMeasuredNodes()

      Returns the list of measured nodes.

      Specified by:
      getMeasuredNodes in interface BayesIm
      Returns:
      the list of measured nodes.
    • getVariableNames

      public List<String> getVariableNames()

      Returns the names of the variables.

      Specified by:
      getVariableNames in interface BayesIm
      Specified by:
      getVariableNames in interface VariableSource
      Returns:
      a List of variable names
    • getNumColumns

      public int getNumColumns(int nodeIndex)
      Returns the number of columns in the table for the specified node, i.e., the number of categories of that node.
      Specified by:
      getNumColumns in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      Returns:
      the number of columns.
    • getNumRows

      public int getNumRows(int nodeIndex)
      Returns the number of rows in the table for the specified node, i.e., the number of combinations of its parents' categories.
      Specified by:
      getNumRows in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      Returns:
      the number of rows in the node's table.
    • getNumParents

      public int getNumParents(int nodeIndex)
      Returns the number of parents for the given node.
      Specified by:
      getNumParents in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      Returns:
      the number of parents.
    • getParent

      public int getParent(int nodeIndex, int parentIndex)
      Returns the index of the parent at position parentIndex in the parent list of the given node.
      Specified by:
      getParent in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      parentIndex - the index of the parent.
      Returns:
      the node index of that parent.
    • getParentDim

      public int getParentDim(int nodeIndex, int parentIndex)
      Returns the dimension (number of categories) of the parent at position parentIndex of the given node.
      Specified by:
      getParentDim in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      parentIndex - the index of the parent.
      Returns:
      the dimension of that parent.
    • getParentDims

      public int[] getParentDims(int nodeIndex)
      Returns a copy of the array of parent dimensions for the specified node, i.e., the number of categories of each of its parents.
      Specified by:
      getParentDims in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      Returns:
      an array containing the dimension of each parent of the node.
    • getParents

      public int[] getParents(int nodeIndex)
      Returns an array containing the node indices of the parents of the specified node.
      Specified by:
      getParents in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      Returns:
      an array of the node indices of the parents of the specified node.
    • getParentValues

      public int[] getParentValues(int nodeIndex, int rowIndex)
      Returns an integer array containing the parent values for a given node index and row index.
      Specified by:
      getParentValues in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      rowIndex - the index of the row in question.
      Returns:
      an integer array containing the parent values.
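Assuming rows enumerate parent-value combinations in mixed-radix order with the last parent varying fastest (an assumption of this sketch, not something this page states), the parent values for a row can be recovered by repeated division, as in the hypothetical ParentValuesSketch below. It works on a plain array of parent dimensions rather than on an MlBayesIm.

```java
import java.util.Arrays;

/** Hypothetical inverse of a mixed-radix row index; not Tetrad code. */
final class ParentValuesSketch {

    /** Recovers the parent-value combination for a row, last parent varying fastest. */
    static int[] parentValues(int[] parentDims, int rowIndex) {
        if (rowIndex < 0) {
            throw new IllegalArgumentException("Row index out of range: " + rowIndex);
        }
        int[] values = new int[parentDims.length];
        int remainder = rowIndex;
        // Peel off digits from the fastest-varying (last) parent backwards.
        for (int i = parentDims.length - 1; i >= 0; i--) {
            values[i] = remainder % parentDims[i];
            remainder /= parentDims[i];
        }
        // Anything left over means rowIndex >= product of the dimensions.
        if (remainder != 0) {
            throw new IllegalArgumentException("Row index out of range: " + rowIndex);
        }
        return values;
    }

    public static void main(String[] args) {
        int[] dims = {2, 3};
        for (int row = 0; row < 6; row++) {
            System.out.println(row + " -> " + Arrays.toString(parentValues(dims, row)));
        }
    }
}
```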
    • getParentValue

      public int getParentValue(int nodeIndex, int rowIndex, int colIndex)
      Returns the value of one parent in the combination of parent values represented by the given row; colIndex selects the parent, in the order given by getParents(int).
      Specified by:
      getParentValue in interface BayesIm
      Parameters:
      nodeIndex - the index of the node.
      rowIndex - the index of the row in question.
      colIndex - the position of the parent in question.
      Returns:
      the value of that parent in the given row.
    • getProbability

      public double getProbability(int nodeIndex, int rowIndex, int colIndex)
      Returns the probability stored in the given cell of the table for the given node.
      Specified by:
      getProbability in interface BayesIm
      Parameters:
      nodeIndex - the index of the node in question.
      rowIndex - the row in the table for this node which represents the combination of parent values in question.
      colIndex - the column in the table for this node which represents the value of the node in question.
      Returns:
      the probability value in that cell, or Double.NaN if the value is unspecified.
    • getRowIndex

      public int getRowIndex(int nodeIndex, int[] values)
      Returns the row index corresponding to the given node index and combination of parent values.
      Specified by:
      getRowIndex in interface BayesIm
      Parameters:
      nodeIndex - the index of the node in question.
      values - the combination of parent values in question.
      Returns:
      the row index corresponding to the given node index and combination of parent values.
    • normalizeAll

      public void normalizeAll()
      Normalizes all rows in the tables associated with each node in turn.
      Specified by:
      normalizeAll in interface BayesIm
    • normalizeNode

      public void normalizeNode(int nodeIndex)
      Normalizes the specified node by invoking the normalizeRow(int, int) method on each row of the node.
      Specified by:
      normalizeNode in interface BayesIm
      Parameters:
      nodeIndex - the index of the node to be normalized.
    • normalizeRow

      public void normalizeRow(int nodeIndex, int rowIndex)
      Normalizes the probabilities in a given row of a node's table so that they sum to 1.
      Specified by:
      normalizeRow in interface BayesIm
      Parameters:
      nodeIndex - the index of the node in question.
      rowIndex - the index of the row in question.
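A minimal sketch of row normalization, operating on a plain double[] row rather than an MlBayesIm. The policies for rows containing NaN (left unchanged) and rows summing to zero (made uniform) are assumptions of this sketch, not documented behavior of normalizeRow.

```java
import java.util.Arrays;

/** Hypothetical row normalization for one CPT row; the edge-case policies are assumptions. */
final class NormalizeRowSketch {

    /** Rescales a row in place so its entries sum to 1. */
    static void normalizeRow(double[] row) {
        double total = 0.0;
        for (double p : row) {
            if (Double.isNaN(p)) {
                return; // assumption: rows with unknown ("?") cells are left untouched
            }
            total += p;
        }
        if (total == 0.0) {
            Arrays.fill(row, 1.0 / row.length); // assumption: all-zero rows become uniform
            return;
        }
        for (int i = 0; i < row.length; i++) {
            row[i] /= total;
        }
    }

    public static void main(String[] args) {
        double[] row = {2.0, 6.0};
        normalizeRow(row);
        System.out.println(Arrays.toString(row)); // prints "[0.25, 0.75]"
    }
}
```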
    • setProbability

      public void setProbability(int nodeIndex, double[][] probMatrix)
      Sets all probabilities in the table for the given node from a matrix. Row i of probMatrix corresponds to row index i of the CPT, i.e., the combination of parent values in question; column j corresponds to column index j, i.e., the value of the node in question.
      Specified by:
      setProbability in interface BayesIm
      Parameters:
      nodeIndex - The index of the node.
      probMatrix - The matrix of probabilities.
    • setCountMap

      public void setCountMap(int nodeIndex, CptMapCounts countMap)
      Sets the count map for a specific node index in the Bayesian network.
      Parameters:
      nodeIndex - the index of the node in the Bayesian network
      countMap - the count map to be set
      Throws:
      IllegalArgumentException - if the Bayesian network is not of type CptMapType.COUNT_MAP
    • setProbability

      public void setProbability(int nodeIndex, int rowIndex, int colIndex, double value)
      Sets the probability value for a specific node, row, and column in the probability table.
      Specified by:
      setProbability in interface BayesIm
      Parameters:
      nodeIndex - the index of the node in question.
      rowIndex - the row in the table for this node which represents the combination of parent values in question.
      colIndex - the column in the table for this node which represents the value of the node in question.
      value - the desired probability to be set. Must be between 0.0 and 1.0, or Double.NaN.
      Throws:
      IllegalArgumentException - if the column index is out of range for the given node, or if the probability value is not between 0.0 and 1.0 or Double.NaN.
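The contract documented for setProbability(int, int, int, double) can be expressed as a small standalone check. ProbabilityCheckSketch is hypothetical; it mirrors the stated rules (column index in range; value between 0.0 and 1.0, or Double.NaN) and is not the library's implementation.

```java
/** Hypothetical check mirroring the documented contract of setProbability; not Tetrad code. */
final class ProbabilityCheckSketch {

    /** Throws IllegalArgumentException unless colIndex is in range and value is in [0, 1] or NaN. */
    static void checkCell(int numColumns, int colIndex, double value) {
        if (colIndex < 0 || colIndex >= numColumns) {
            throw new IllegalArgumentException("Column out of range: " + colIndex);
        }
        // Every comparison with NaN is false, so inRange is false for NaN; NaN is allowed separately.
        boolean inRange = value >= 0.0 && value <= 1.0;
        if (!inRange && !Double.isNaN(value)) {
            throw new IllegalArgumentException("Probability must be in [0, 1] or NaN: " + value);
        }
    }

    public static void main(String[] args) {
        checkCell(3, 2, 0.4);        // accepted
        checkCell(3, 0, Double.NaN); // accepted: NaN marks an unknown value
        try {
            checkCell(3, 0, 1.5);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // prints "Probability must be in [0, 1] or NaN: 1.5"
        }
    }
}
```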
    • getCorrespondingNodeIndex

      public int getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm)
      Returns the corresponding node index in the given BayesIm based on the node index in this BayesIm.
      Specified by:
      getCorrespondingNodeIndex in interface BayesIm
      Parameters:
      nodeIndex - the index of the node in this BayesIm.
      otherBayesIm - the BayesIm in which the node is to be found.
      Returns:
      the corresponding node index in the given BayesIm.
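This page does not say how nodes are matched across the two models. A plausible rule, assumed here for illustration only, is matching by node name; the hypothetical sketch below applies it to plain lists of names and returns -1 when there is no match, mirroring the convention documented for getNodeIndex.

```java
import java.util.List;

/** Hypothetical name-based lookup; assumes nodes correspond when their names match. */
final class CorrespondingIndexSketch {

    /** Returns the index in otherNames of the name at thisNames.get(nodeIndex), or -1 if absent. */
    static int correspondingIndex(int nodeIndex, List<String> thisNames, List<String> otherNames) {
        return otherNames.indexOf(thisNames.get(nodeIndex));
    }

    public static void main(String[] args) {
        List<String> mine = List.of("X1", "X2", "X3");
        List<String> other = List.of("X3", "X1");
        System.out.println(correspondingIndex(0, mine, other)); // prints "1"
        System.out.println(correspondingIndex(1, mine, other)); // prints "-1"
    }
}
```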
    • clearRow

      public void clearRow(int nodeIndex, int rowIndex)
      Clears all values in the specified row of a table.
      Specified by:
      clearRow in interface BayesIm
      Parameters:
      nodeIndex - the index of the node for the table that this row belongs to
      rowIndex - the index of the row to be cleared
    • randomizeRow

      public void randomizeRow(int nodeIndex, int rowIndex)
      Randomizes the values of a row in a table for a given node.
      Specified by:
      randomizeRow in interface BayesIm
      Parameters:
      nodeIndex - the index of the node for the table that this row belongs to.
      rowIndex - the index of the row to be randomized.
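One way to produce a random row, shown as a hypothetical sketch: normalize independent Exponential(1) draws, which yields a probability vector distributed uniformly over the probability simplex (a flat Dirichlet). The scheme MlBayesIm actually uses for randomizeRow is not described on this page and may differ.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical row randomizer; Tetrad's actual randomization scheme may differ. */
final class RandomizeRowSketch {

    /**
     * Returns a random probability vector. Normalizing i.i.d. Exponential(1) draws gives a
     * point distributed uniformly on the probability simplex (Dirichlet(1, ..., 1)).
     */
    static double[] randomRow(int numColumns, Random random) {
        double[] row = new double[numColumns];
        double total = 0.0;
        for (int i = 0; i < numColumns; i++) {
            // nextDouble() is in [0, 1), so 1 - U is in (0, 1] and the log is finite.
            row[i] = -Math.log(1.0 - random.nextDouble());
            total += row[i];
        }
        for (int i = 0; i < numColumns; i++) {
            row[i] /= total;
        }
        return row;
    }

    public static void main(String[] args) {
        double[] row = randomRow(4, new Random(7));
        System.out.println(Arrays.toString(row) + " sums to " + Arrays.stream(row).sum());
    }
}
```

Normalizing independent uniform draws instead would also give a valid row, but its distribution over the simplex would not be uniform.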
    • randomizeIncompleteRows

      public void randomizeIncompleteRows(int nodeIndex)
      Randomizes the incomplete rows in the specified node's table.
      Specified by:
      randomizeIncompleteRows in interface BayesIm
      Parameters:
      nodeIndex - the index of the node for the table whose incomplete rows are to be randomized
    • randomizeTable

      public void randomizeTable(int nodeIndex)
      Randomizes the table for a given node.
      Specified by:
      randomizeTable in interface BayesIm
      Parameters:
      nodeIndex - the index of the node for the table to be randomized
    • clearTable

      public void clearTable(int nodeIndex)
      Clears the table by clearing all rows for the given node.
      Specified by:
      clearTable in interface BayesIm
      Parameters:
      nodeIndex - The index of the node for the table to be cleared.
    • isIncomplete

      public boolean isIncomplete(int nodeIndex, int rowIndex)
      Checks if the specified row of a table is incomplete, i.e., if any of the columns have a NaN value.
      Specified by:
      isIncomplete in interface BayesIm
      Parameters:
      nodeIndex - the index of the node whose table is checked.
      rowIndex - the index of the row to check.
      Returns:
      true if the row is incomplete, false otherwise.
    • isIncomplete

      public boolean isIncomplete(int nodeIndex)
      Checks if the specified table has any incomplete rows.
      Specified by:
      isIncomplete in interface BayesIm
      Parameters:
      nodeIndex - the index of the node for the table
      Returns:
      true if the table has any incomplete rows, false otherwise
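Both isIncomplete checks reduce to scanning for NaN cells. The hypothetical sketch below works on a dense double[][] table (rows by columns) rather than on MlBayesIm's own storage.

```java
/** Hypothetical completeness checks over a dense CPT (rows x columns); not Tetrad code. */
final class IncompleteSketch {

    /** A row is incomplete if any of its cells is NaN ("?"). */
    static boolean isIncomplete(double[] row) {
        for (double p : row) {
            if (Double.isNaN(p)) {
                return true;
            }
        }
        return false;
    }

    /** A table is incomplete if any of its rows is incomplete. */
    static boolean isIncomplete(double[][] table) {
        for (double[] row : table) {
            if (isIncomplete(row)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        double[][] table = {{0.3, 0.7}, {Double.NaN, Double.NaN}};
        System.out.println(isIncomplete(table[0]) + " " + isIncomplete(table)); // prints "false true"
    }
}
```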
    • simulateData

      public DataSet simulateData(int sampleSize, boolean latentDataSaved, int[] tiers)
      Simulates a sample with the given sample size.
      Parameters:
      sampleSize - the sample size.
      latentDataSaved - if true, latent variables are saved in the data set.
      tiers - an array of int values.
      Returns:
      the simulated sample as a DataSet.
    • simulateData

      public DataSet simulateData(int sampleSize, boolean latentDataSaved)
      Simulates a data set.
      Specified by:
      simulateData in interface BayesIm
      Specified by:
      simulateData in interface Simulator
      Parameters:
      sampleSize - The number of rows to simulate.
      latentDataSaved - If set to true, latent variables are saved in the data set.
      Returns:
      The simulated data set.
      Throws:
      IllegalArgumentException - If the graph contains a directed cycle.
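The standard technique for simulating from a discrete Bayes net is forward (ancestral) sampling: visit nodes in a causal order, compute each node's row index from the values already drawn for its parents, and draw a category from that row. The sketch below assumes nodes are already indexed in topological order and that rows use a mixed-radix layout with the last parent varying fastest; ForwardSamplerSketch is hypothetical and is not MlBayesIm's simulateData.

```java
import java.util.Random;

/** Hypothetical forward (ancestral) sampler for a discrete Bayes net; not Tetrad's simulateData. */
final class ForwardSamplerSketch {

    /**
     * Simulates sampleSize cases. Assumes nodes are indexed in a topological (causal) order, so
     * every parent index is smaller than its child's index; cpts[node][row][category].
     */
    static int[][] simulate(int[][] parents, int[] numCategories, double[][][] cpts,
                            int sampleSize, Random random) {
        int numNodes = numCategories.length;
        int[][] data = new int[sampleSize][numNodes];
        for (int s = 0; s < sampleSize; s++) {
            for (int node = 0; node < numNodes; node++) {
                // Row index from the already-sampled parent values (mixed radix, last parent fastest).
                int row = 0;
                for (int parent : parents[node]) {
                    if (parent >= node) {
                        throw new IllegalArgumentException("Nodes must be in topological order.");
                    }
                    row = row * numCategories[parent] + data[s][parent];
                }
                data[s][node] = sampleCategory(cpts[node][row], random);
            }
        }
        return data;
    }

    /** Draws a category index from a probability row by inverting its cumulative sum. */
    static int sampleCategory(double[] probs, Random random) {
        double u = random.nextDouble();
        double cumulative = 0.0;
        for (int k = 0; k < probs.length; k++) {
            cumulative += probs[k];
            if (u < cumulative) {
                return k;
            }
        }
        return probs.length - 1; // guards against rounding when the row sums to slightly under 1
    }

    public static void main(String[] args) {
        // Two binary nodes, X0 -> X1.
        int[][] parents = {{}, {0}};
        int[] numCategories = {2, 2};
        double[][][] cpts = {
                {{0.5, 0.5}},             // P(X0)
                {{0.9, 0.1}, {0.2, 0.8}}  // P(X1 | X0 = 0), P(X1 | X0 = 1)
        };
        int[][] data = simulate(parents, numCategories, cpts, 10_000, new Random(42));
        int ones = 0;
        for (int[] row : data) {
            ones += row[1];
        }
        // Expected value is 0.5 * 0.1 + 0.5 * 0.8 = 0.45; the sample estimate varies with the seed.
        System.out.println("Fraction of X1 = 1: " + (double) ones / data.length);
    }
}
```

Requiring a topological order is what makes a single pass sufficient: each parent is sampled before any of its children, which is also why a directed cycle makes simulation impossible.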
    • simulateData

      public DataSet simulateData(DataSet dataSet, boolean latentDataSaved, int[] tiers)

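      Simulates data for the given data set, using the given tiers.

      Parameters:
      dataSet - the data set to simulate data for.
      latentDataSaved - if true, latent variables are saved in the data set.
      tiers - an array of int values.
      Returns:
      the simulated data set.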
    • simulateData

      public DataSet simulateData(DataSet dataSet, boolean latentDataSaved)
      Simulates data for the given data set.
      Specified by:
      simulateData in interface BayesIm
      Parameters:
      dataSet - The data set to simulate data for.
      latentDataSaved - Indicates whether latent data should be saved during simulation.
      Returns:
      The modified data set after simulating the data.
    • equals

      public boolean equals(Object o)
      Determines whether the specified object is equal to this Bayes net.
      Specified by:
      equals in interface BayesIm
      Overrides:
      equals in class Object
      Parameters:
      o - the object to be compared to this Bayes net
      Returns:
      true if the specified object is equal to this Bayes net, false otherwise
    • toString

      public String toString()
      Returns the probability table for each variable as a string.
      Specified by:
      toString in interface BayesIm
      Overrides:
      toString in class Object
      Returns:
      a string containing the probability table for each variable
    • getCptMapType

      public MlBayesIm.CptMapType getCptMapType()
      Returns the CptMapType for this instance, which indicates how its probability tables are stored. CptMaps are the newer, sparse way of storing the probabilities; the probs array is kept for backward compatibility.
      Specified by:
      getCptMapType in interface BayesIm
      Returns:
      the CptMapType for this instance