Package org.tribuo.util.onnx
Class ONNXContext
java.lang.Object
org.tribuo.util.onnx.ONNXContext
Context object used to scope and manage the creation of ONNX OnnxMl.GraphProto and OnnxMl.ModelProto instances. A single instance of ONNXContext should be used to create an ONNX graph/model; mixing ONNXContext instances, or ONNXRefs produced by multiple ONNXContexts, is not supported.
The ONNXContext has all of the logic needed to produce ONNX graphs, but it is typically used explicitly only to produce the leaf nodes of a graph (inputs, outputs, and weight matrices). Those leaf nodes provide more fluent interfaces to operation(ONNXOperator, List, List, Map).
Produced ONNX protobuf objects are encapsulated by instances of ONNXRef and its subclasses.
-
Constructor Summary
ONNXContext()
    Creates an empty ONNX context.
-
Method Summary
ONNXInitializer array(String baseName, double[] parameters)
    Creates a float tensor for this ONNXContext, populated according to parameters.
ONNXInitializer array(String baseName, double[] parameters, boolean downcast)
    Creates a tensor for this ONNXContext, populated according to parameters.
ONNXInitializer array(String baseName, float[] parameters)
    Creates a float tensor for this ONNXContext, populated according to parameters.
ONNXInitializer array(String baseName, int[] parameters)
    Creates an int tensor for this ONNXContext, populated according to parameters.
ONNXInitializer array(String baseName, long[] parameters)
    Creates a long tensor for this ONNXContext, populated according to parameters.
assignTo(RHS input, LHS output)
    Creates an ONNXOperators.IDENTITY node connecting input to output, effectively permitting assignment of values.
ai.onnx.proto.OnnxMl.GraphProto buildGraph()
    Builds the ONNX graph represented by this context.
ONNXInitializer constant(String baseName, float value)
    Creates a float scalar constant for this ONNXContext.
ONNXInitializer constant(String baseName, long value)
    Creates a long scalar constant for this ONNXContext.
ONNXPlaceholder floatInput(int featureDimension)
    Creates an input node for this ONNXContext, with the name "input", of dimension [batch_size, featureDimension], and of type float32.
ONNXPlaceholder floatInput(String name, int featureDimension)
    Creates an input node for this ONNXContext, with the given name, of dimension [batch_size, featureDimension], and of type float32.
ONNXPlaceholder floatOutput(int outputDimension)
    Creates an output node for this ONNXContext, with the name "output", of dimension [batch_size, outputDimension], and of type float32.
ONNXPlaceholder floatOutput(String name, int outputDimension)
    Creates an output node for this ONNXContext, with the given name, of dimension [batch_size, outputDimension], and of type float32.
ONNXInitializer floatTensor(String baseName, List<Integer> dims, Consumer<FloatBuffer> populate)
    Creates a tensor for this ONNXContext, populated as in ONNXUtils.floatTensorBuilder(ONNXContext, String, List, Consumer).
<T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName)
    Method for creating ONNXNodes from ONNXOperator instances and inputs.
<T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName, Map<String, Object> attributes)
    Method for creating ONNXNodes from an ONNXOperator and inputs.
<T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperator op, List<T> inputs, List<String> outputs, Map<String, Object> attributes)
    Base method for creating ONNXNodes from an ONNXOperator and inputs.
void setName(String name)
    Sets the graph name.
-
Constructor Details
-
ONNXContext
public ONNXContext()
Creates an empty ONNX context.
-
Method Details
-
operation
public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperator op, List<T> inputs, List<String> outputs, Map<String, Object> attributes)
Base method for creating ONNXNodes from an ONNXOperator and its inputs. Returns an ONNXNode for each output of the ONNXOperator. The graph elements created by the operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. This is the root method for constructing ONNXNodes, which all other node-creating methods on ONNXContext and ONNXRef call.
- Type Parameters:
T - The ONNXRef type of inputs.
- Parameters:
op - An ONNXOperator to add to the graph, taking inputs as input.
inputs - A list of ONNXRefs created by this instance of ONNXContext.
outputs - A list of names that the output nodes of op should take.
attributes - A map of attributes of the operation, passed to ONNXOperator.build(ONNXContext, String, String, Map).
- Returns:
A list of ONNXNodes that are the output nodes of op.
-
operation
public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName, Map<String, Object> attributes)
Method for creating an ONNXNode from an ONNXOperator and its inputs. Returns a single ONNXNode, and throws IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
- Type Parameters:
T - The ONNXRef type of inputs.
- Parameters:
op - An ONNXOperator to add to the graph, taking inputs as input.
inputs - A list of ONNXRefs created by this instance of ONNXContext.
outputName - The name that the output node of op should take.
attributes - A map of attributes of the operation, passed to ONNXOperator.build(ONNXContext, String, String, Map).
- Returns:
An ONNXNode that is the output node of op.
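The attributes argument is a plain Map<String, Object> keyed by ONNX attribute name. As an illustration, the ONNX Gemm operator (Y = alpha * A' * B' + beta * C) declares alpha and beta as FLOAT attributes and transA/transB as INT attributes used as 0/1 flags. The minimal sketch below builds such a map; the class GemmAttributes is hypothetical, and the assumption that ONNXContext expects boxed Float and Integer values for these attribute types is ours, not stated on this page.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of Tribuo.
final class GemmAttributes {
    // Builds an attribute map for the ONNX Gemm operator.
    // Assumption: FLOAT attributes are passed as Float and INT attributes as Integer.
    static Map<String, Object> gemm(float alpha, float beta, boolean transposeB) {
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("alpha", alpha);                // autoboxed to Float
        attributes.put("beta", beta);                  // autoboxed to Float
        attributes.put("transB", transposeB ? 1 : 0);  // INT attribute used as a 0/1 flag
        return attributes;
    }
}
```

Assuming ONNXOperators provides a GEMM constant (not shown on this page), the map would be supplied as the attributes argument, for example ctx.operation(ONNXOperators.GEMM, Arrays.asList(input, weights, bias), "logits", GemmAttributes.gemm(1.0f, 1.0f, false)).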
-
operation
public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName)
Method for creating an ONNXNode from an ONNXOperator instance and its inputs. Returns a single ONNXNode, and throws IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
- Type Parameters:
T - The ONNXRef type of inputs.
- Parameters:
op - An ONNXOperator to add to the graph, taking inputs as input.
inputs - A list of ONNXRefs created by this instance of ONNXContext.
outputName - The name that the output node of op should take.
- Returns:
An ONNXNode that is the output node of op.
-
assignTo
Creates an ONNXOperators.IDENTITY node connecting input to output, effectively permitting assignment of values.
-
floatInput
public ONNXPlaceholder floatInput(String name, int featureDimension)
Creates an input node for this ONNXContext, with the given name, of dimension [batch_size, featureDimension], and of type float32.
- Parameters:
name - The name for this input node.
featureDimension - The second dimension of this input node.
- Returns:
An ONNXPlaceholder instance representing this input node.
-
floatInput
public ONNXPlaceholder floatInput(int featureDimension)
Creates an input node for this ONNXContext, with the name "input", of dimension [batch_size, featureDimension], and of type float32.
- Parameters:
featureDimension - The second dimension of this input node.
- Returns:
An ONNXPlaceholder instance representing this input node.
-
floatOutput
public ONNXPlaceholder floatOutput(String name, int outputDimension)
Creates an output node for this ONNXContext, with the given name, of dimension [batch_size, outputDimension], and of type float32.
- Parameters:
name - The name for this output node.
outputDimension - The second dimension of this output node.
- Returns:
An ONNXPlaceholder instance representing this output node.
-
floatOutput
public ONNXPlaceholder floatOutput(int outputDimension)
Creates an output node for this ONNXContext, with the name "output", of dimension [batch_size, outputDimension], and of type float32.
- Parameters:
outputDimension - The second dimension of this output node.
- Returns:
An ONNXPlaceholder instance representing this output node.
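floatInput and floatOutput fix only the second axis; the first axis is the symbolic batch_size. When checking an exported graph against known values, a plain-Java reference for what the graph should compute can help. The sketch below assumes, hypothetically, that the graph is a single linear layer built with the ONNX Gemm operator at its default attributes (alpha = beta = 1, no transposes), so Y = X * W + b maps a [batch_size, featureDimension] input to a [batch_size, outputDimension] output. LinearReference is an illustrative name, not part of Tribuo.

```java
// Hypothetical reference implementation, not part of Tribuo.
final class LinearReference {
    // Assumed graph: a single Gemm with default attributes, Y = X * W + b.
    // x: [batchSize][featureDimension], w: [featureDimension][outputDimension], b: [outputDimension]
    // Returns y: [batchSize][outputDimension].
    static float[][] forward(float[][] x, float[][] w, float[] b) {
        int featureDimension = w.length;
        int outputDimension = b.length;
        for (float[] row : w) {
            if (row.length != outputDimension) {
                throw new IllegalArgumentException("weight row length must equal bias length");
            }
        }
        float[][] y = new float[x.length][outputDimension];
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != featureDimension) {
                throw new IllegalArgumentException("input row " + i + " has the wrong feature dimension");
            }
            for (int j = 0; j < outputDimension; j++) {
                float sum = b[j]; // bias is broadcast across the batch axis
                for (int k = 0; k < featureDimension; k++) {
                    sum += x[i][k] * w[k][j];
                }
                y[i][j] = sum;
            }
        }
        return y;
    }
}
```

Comparing this reference with what an ONNX runtime returns for the same batch is one way to confirm that the weight and bias tensors were written with the intended layout.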
-
floatTensor
public ONNXInitializer floatTensor(String baseName, List<Integer> dims, Consumer<FloatBuffer> populate)
Creates a tensor for this ONNXContext, populated as in ONNXUtils.floatTensorBuilder(ONNXContext, String, List, Consumer).
- Parameters:
baseName - The name for this tensor in the ONNX graph.
dims - The dimensions of this tensor.
populate - A function populating the FloatBuffer that backs this tensor.
- Returns:
An ONNXInitializer instance representing this tensor.
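ONNX tensor data is laid out in row-major order, so a populate callback for a two-dimensional weight matrix typically writes one row after another. The minimal sketch below builds such a Consumer<FloatBuffer>; it assumes (this page does not say) that floatTensor hands the callback a buffer positioned at zero with room for exactly the product of dims elements. RowMajorPopulate is a hypothetical helper.

```java
import java.nio.FloatBuffer;
import java.util.function.Consumer;

// Hypothetical helper, not part of Tribuo.
final class RowMajorPopulate {
    // Returns a callback that writes a rows x cols matrix into the buffer in row-major order.
    // Assumption: the buffer starts at position 0 and has capacity rows * cols.
    static Consumer<FloatBuffer> rowMajor(float[][] matrix) {
        return buffer -> {
            for (float[] row : matrix) {
                // Relative bulk put: advances the position, and throws
                // BufferOverflowException if the buffer is too small.
                buffer.put(row);
            }
        };
    }
}
```

Under those assumptions the callback would be used as ctx.floatTensor("W", Arrays.asList(rows, cols), RowMajorPopulate.rowMajor(weights)), with dims listed outermost axis first to match the row-major layout.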
-
array
public ONNXInitializer array(String baseName, long[] parameters)
Creates a long tensor for this ONNXContext, populated according to parameters.
- Parameters:
baseName - The name for this tensor in the ONNX graph.
parameters - The long[] to populate the tensor.
- Returns:
An ONNXInitializer instance representing this tensor.
-
array
public ONNXInitializer array(String baseName, int[] parameters)
Creates an int tensor for this ONNXContext, populated according to parameters.
- Parameters:
baseName - The name for this tensor in the ONNX graph.
parameters - The int[] to populate the tensor.
- Returns:
An ONNXInitializer instance representing this tensor.
-
array
public ONNXInitializer array(String baseName, float[] parameters)
Creates a float tensor for this ONNXContext, populated according to parameters.
- Parameters:
baseName - The name for this tensor in the ONNX graph.
parameters - The float[] to populate the tensor.
- Returns:
An ONNXInitializer instance representing this tensor.
-
array
public ONNXInitializer array(String baseName, double[] parameters, boolean downcast)
Creates a tensor for this ONNXContext, populated according to parameters.
- Parameters:
baseName - The name for this tensor in the ONNX graph.
parameters - The double[] to populate the tensor.
downcast - Whether to downcast parameters to float32 in the ONNX graph.
- Returns:
An ONNXInitializer instance representing this tensor.
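Downcasting trades precision for size: float32 keeps a 24-bit significand where float64 keeps 53 bits. The sketch below estimates how much a weight array would lose before choosing downcast; it assumes the narrowing behaves like Java's (float) cast (round to nearest, with magnitudes beyond Float.MAX_VALUE becoming infinite), which this page does not state. Downcast and its methods are hypothetical names.

```java
// Hypothetical helper, not part of Tribuo.
final class Downcast {
    // Element-wise narrowing from float64 to float32.
    // Assumption: equivalent to what a downcast tensor would store.
    static float[] toFloat32(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i]; // rounds to the nearest float32 value
        }
        return out;
    }

    // Largest absolute difference between each double and its float32 counterpart.
    static double maxAbsError(double[] values) {
        float[] narrowed = toFloat32(values);
        double worst = 0.0;
        for (int i = 0; i < values.length; i++) {
            worst = Math.max(worst, Math.abs(values[i] - narrowed[i]));
        }
        return worst;
    }
}
```

Values exactly representable in float32, such as 0.5 or -2.0, survive unchanged, while a value like 0.1 picks up an error on the order of 1e-9.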
-
array
public ONNXInitializer array(String baseName, double[] parameters)
Creates a float tensor for this ONNXContext, populated according to parameters. As with ONNXUtils.arrayBuilder(ONNXContext, String, double[], boolean), the doubles will be downcast to float32.
- Parameters:
baseName - The name for this tensor in the ONNX graph.
parameters - The double[] to populate the tensor.
- Returns:
An ONNXInitializer instance representing this tensor.
-
constant
public ONNXInitializer constant(String baseName, float value)
Creates a float scalar constant for this ONNXContext.
- Parameters:
baseName - The name for this constant in the ONNX graph.
value - The float to populate the constant.
- Returns:
An ONNXInitializer instance representing this constant.
-
constant
public ONNXInitializer constant(String baseName, long value)
Creates a long scalar constant for this ONNXContext.
- Parameters:
baseName - The name for this constant in the ONNX graph.
value - The long to populate the constant.
- Returns:
An ONNXInitializer instance representing this constant.
-
setName
public void setName(String name)
Sets the graph name.
- Parameters:
name - The graph name.
-
buildGraph
public ai.onnx.proto.OnnxMl.GraphProto buildGraph()
Builds the ONNX graph represented by this context.
- Returns:
The ONNX graph proto.
-