Package edu.cmu.tetrad.search.utils
Class MultiLayerPerceptronDjl
java.lang.Object
edu.cmu.tetrad.search.utils.MultiLayerPerceptronDjl
The MultiLayerPerceptronDjl class provides a customizable implementation of a multi-layer perceptron (MLP) for
regression or classification tasks, built on the Deep Java Library (DJL). The user defines the network architecture:
the input dimension, the sizes of the hidden layers, and the type of output.
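The description above covers both regression and classification. As a hedged illustration of how raw network outputs are commonly interpreted for the documented target types, the plain-Java sketch below (no DJL) shows the two usual output activations: a continuous target uses the raw output directly, a binary target is assumed to apply a logistic sigmoid to a single output, and a multinomial target applies softmax across the category outputs. Whether MultiLayerPerceptronDjl applies these activations inside forward, or leaves them to the caller or the loss function, is not stated on this page; OutputActivationSketch and its methods are hypothetical names.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: turning raw MLP outputs into predictions for each
 * target type. The documented class does not say whether it applies these
 * activations itself; this is an illustration only.
 */
public class OutputActivationSketch {

    /** Logistic sigmoid: maps one raw output (logit) to a probability (assumed binary case). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Softmax over raw outputs, shifted by the maximum for numerical stability (multinomial case). */
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().orElse(0.0);
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // A continuous ("regression") target would use the raw output unchanged.
        System.out.println(sigmoid(0.0));                                      // 0.5
        System.out.println(Arrays.toString(softmax(new double[]{1.0, 1.0}))); // [0.5, 0.5]
    }
}
```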
Constructor Summary
MultiLayerPerceptronDjl(int inputDim, List<Integer> hiddenLayers, String variableType, float inputScale)
  Constructs a MultiLayerPerceptronDjl object with the specified input dimension, hidden layers, variable type, and input scaling factor.
Method Summary
ai.djl.ndarray.NDArray forward(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.NDArray input)
  Computes the forward pass of the neural network for a given input.
ai.djl.ndarray.NDManager getManager()
  Returns the NDManager used for managing computational resources.
Constructor Details
MultiLayerPerceptronDjl
public MultiLayerPerceptronDjl(int inputDim, List<Integer> hiddenLayers, String variableType, float inputScale)
Constructs a MultiLayerPerceptronDjl object with the specified input dimension, hidden layers, variable type, and input scaling factor. The network architecture is built from these settings: the number of input features, the hidden-layer sizes, and the output type (continuous, binary, or multinomial).
Parameters:
inputDim - the number of input features.
hiddenLayers - a list of integers giving the number of neurons in each hidden layer.
variableType - the type of the prediction target: "continuous", "binary", or "multinomial". For a multinomial target, also give the number of categories, in the form "multinomial,numCategories" (for example, "multinomial,3").
inputScale - a scaling factor applied to the input data.
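A minimal sketch, in plain Java and under assumptions not confirmed by this page, of how the constructor arguments could translate into layer widths: the network runs from inputDim through each entry of hiddenLayers to an output layer whose width depends on variableType. Here "continuous" and "binary" are assumed to use one output unit and "multinomial,k" is assumed to use k units. MlpConfigSketch, outputDim, and layerSizes are hypothetical names, not part of the Tetrad API.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of how the constructor arguments could determine the
 * layer widths of the network. The parsing rules and the output width chosen
 * for each type are assumptions, not taken from the Tetrad source.
 */
public class MlpConfigSketch {

    /** Assumed output width: 1 for "continuous" and "binary", k for "multinomial,k". */
    static int outputDim(String variableType) {
        String[] parts = variableType.trim().split(",");
        switch (parts[0].trim().toLowerCase()) {
            case "continuous":
            case "binary":
                return 1;
            case "multinomial":
                if (parts.length < 2) {
                    throw new IllegalArgumentException(
                            "Expected \"multinomial,numCategories\", got: " + variableType);
                }
                return Integer.parseInt(parts[1].trim());
            default:
                throw new IllegalArgumentException("Unknown variable type: " + variableType);
        }
    }

    /** Layer widths from input to output: [inputDim, hidden..., outputDim]. */
    static List<Integer> layerSizes(int inputDim, List<Integer> hiddenLayers, String variableType) {
        List<Integer> sizes = new ArrayList<>();
        sizes.add(inputDim);
        sizes.addAll(hiddenLayers);
        sizes.add(outputDim(variableType));
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(layerSizes(5, List.of(32, 16), "multinomial,3")); // [5, 32, 16, 3]
    }
}
```

Failing fast on a malformed "multinomial" string keeps a missing category count from silently producing a one-unit output layer.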
Method Details
forward
public ai.djl.ndarray.NDArray forward(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.NDArray input) throws ai.djl.translate.TranslateException
Computes the forward pass of the neural network for a given input.
Parameters:
manager - the NDManager used to manage computational resources.
input - the input NDArray to pass through the network.
Returns:
the NDArray produced by the forward pass through the network.
Throws:
ai.djl.translate.TranslateException - if an error occurs during computation or data translation.
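To make the forward pass concrete, here is a plain-Java sketch for a single input vector. It assumes that inputScale multiplies each input component, that hidden layers use ReLU, and that the output layer is linear; none of these details is confirmed by this page, and the real method operates on DJL NDArrays rather than double arrays. MlpForwardSketch and its parameters are hypothetical.

```java
/**
 * Plain-Java sketch of an MLP forward pass for one input vector, assuming
 * inputs are multiplied by inputScale, hidden layers use ReLU, and the output
 * layer is linear. These choices are assumptions; DJL is not used here.
 */
public class MlpForwardSketch {

    /**
     * @param x          input vector of length inputDim
     * @param inputScale factor applied to every input component (assumed multiplicative)
     * @param weights    weights[l][j][i] connects unit i of layer l to unit j of layer l + 1
     * @param biases     biases[l][j] is the bias of unit j in layer l + 1
     */
    static double[] forward(double[] x, double inputScale, double[][][] weights, double[][] biases) {
        // Scale the raw inputs.
        double[] a = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            a[i] = x[i] * inputScale;
        }
        // Apply each dense layer: z = W a + b, then ReLU except on the last layer.
        for (int l = 0; l < weights.length; l++) {
            boolean isOutputLayer = (l == weights.length - 1);
            double[] z = new double[weights[l].length];
            for (int j = 0; j < z.length; j++) {
                double sum = biases[l][j];
                for (int i = 0; i < a.length; i++) {
                    sum += weights[l][j][i] * a[i];
                }
                z[j] = isOutputLayer ? sum : Math.max(0.0, sum);
            }
            a = z;
        }
        return a;
    }

    public static void main(String[] args) {
        // 2 inputs -> 2 hidden units -> 1 output
        double[][][] w = {
            {{1.0, -1.0}, {0.5, 0.5}},
            {{2.0, 1.0}}
        };
        double[][] b = {{0.0, 0.0}, {0.1}};
        double[] y = forward(new double[]{1.0, 3.0}, 1.0, w, b);
        System.out.println(y[0]); // 2.1
    }
}
```

With the weights in main, the first hidden unit receives 1 - 3 = -2 and is clipped to 0 by ReLU, the second receives 2, so the output is 0.1 + 2.0 * 0 + 1.0 * 2 = 2.1.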
getManager
public ai.djl.ndarray.NDManager getManager()
Returns the NDManager used for managing computational resources.
Returns:
the NDManager instance.