Class ONNXContext

java.lang.Object
org.tribuo.util.onnx.ONNXContext

public final class ONNXContext extends Object
Context object used to scope and manage the creation of ONNX OnnxMl.GraphProto and OnnxMl.ModelProto instances. A single instance of ONNXContext should be used to create an ONNX graph/model, and mixing instances of ONNXContext or of ONNXRefs produced by multiple ONNXContexts is not supported.

The ONNXContext has all of the logic needed to produce ONNX graphs, but it is typically used explicitly only to produce the leaf nodes of a graph (inputs, outputs, and weight matrices). The ONNXRef objects it returns provide more fluent interfaces to operation(ONNXOperators, List, List, Map). Produced ONNX protobuf objects are encapsulated by instances of ONNXRef and its subclasses.

  • Constructor Details

    • ONNXContext

      public ONNXContext()
      Creates an empty ONNX context.
  • Method Details

    • operation

      public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperators op, List<T> inputs, List<String> outputs, Map<String,Object> attributes)
      Base method for creating ONNXNodes from ONNXOperators and inputs. Returns an instance of ONNXNode for each output of the ONNXOperator. The graph elements created by the operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. This is the root method for constructing ONNXNodes; all other node-creating methods on ONNXContext and ONNXRef call it.
      Type Parameters:
      T - The ONNXRef type of inputs
      Parameters:
      op - An ONNXOperator to add to the graph, taking inputs as input.
      inputs - A list of ONNXRefs created by this instance of ONNXContext.
      outputs - A list of names that the output nodes of op should take.
      attributes - A map of attributes of the operation, passed to ONNXOperators.build(ONNXContext, String, String, Map).
      Returns:
      a list of ONNXNodes that are the output nodes of op.
    • operation

      public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> inputs, String outputName, Map<String,Object> attributes)
      Method for creating ONNXNodes from ONNXOperators and inputs. Returns a single ONNXNode and throws IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
      Type Parameters:
      T - The ONNXRef type of inputs
      Parameters:
      op - An ONNXOperator to add to the graph, taking inputs as input.
      inputs - A list of ONNXRefs created by this instance of ONNXContext.
      outputName - Name that the output node of op should take.
      attributes - A map of attributes of the operation, passed to ONNXOperators.build(ONNXContext, String, String, Map).
      Returns:
      An ONNXNode that is the output node of op.
    • operation

      public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> inputs, String outputName)
      Method for creating ONNXNodes from ONNXOperators and inputs. Returns a single ONNXNode and throws IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
      Type Parameters:
      T - The ONNXRef type of inputs
      Parameters:
      op - An ONNXOperator to add to the graph, taking inputs as input.
      inputs - A list of ONNXRefs created by this instance of ONNXContext.
      outputName - Name that the output node of op should take.
      Returns:
      An ONNXNode that is the output node of op.
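As a rough illustration of the single-output overload, the sketch below adds one ONNXOperators.IDENTITY node that reads from a freshly created input placeholder. It assumes the Tribuo ONNX utility classes are on the classpath and that they live in the org.tribuo.util.onnx package alongside ONNXContext. The class name OperationSketch and the node names "features" and "copied" are hypothetical. Any other single-output operator would be passed the same way.

```java
import java.util.List;

import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;

public class OperationSketch {
    /** Adds a single IDENTITY node to the supplied context and returns its output node. */
    public static ONNXNode copy(ONNXContext context) {
        // The input must be created by the same context that runs the operation.
        ONNXPlaceholder features = context.floatInput("features", 4);
        // IDENTITY produces exactly one output, so the single-output overload applies.
        return context.operation(ONNXOperators.IDENTITY, List.of(features), "copied");
    }

    public static void main(String[] args) {
        ONNXContext context = new ONNXContext();
        copy(context);
        // Prints how many operator nodes the context has accumulated so far.
        System.out.println(context.buildGraph().getNodeCount());
    }
}
```

An operator with several outputs would instead go through the List<String> overload, which returns one ONNXNode per output name.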
    • assignTo

      public <LHS extends ONNXRef<?>, RHS extends ONNXRef<?>> LHS assignTo(RHS input, LHS output)
      Creates an ONNXOperators.IDENTITY node connecting input to output, effectively permitting assignment of values.
      Type Parameters:
      LHS - the ONNXRef type of the output.
      RHS - the ONNXRef type of the input.
      Parameters:
      input - The input node / right-hand side of the assignment.
      output - The output node / left-hand side of the assignment.
      Returns:
      the output node that was assigned to.
    • floatInput

      public ONNXPlaceholder floatInput(String name, int featureDimension)
      Creates an input node for this ONNXContext, with the given name, of dimension [batch_size, featureDimension], and of type float32.
      Parameters:
      name - The name for this input node.
      featureDimension - the second dimension of this input node.
      Returns:
      An ONNXPlaceholder instance representing this input node.
    • floatInput

      public ONNXPlaceholder floatInput(int featureDimension)
      Creates an input node for this ONNXContext, with the name "input", of dimension [batch_size, featureDimension], and of type float32.
      Parameters:
      featureDimension - the second dimension of this input node.
      Returns:
      An ONNXPlaceholder instance representing this input node.
    • floatOutput

      public ONNXPlaceholder floatOutput(String name, int outputDimension)
      Creates an output node for this ONNXContext, with the given name, of dimension [batch_size, outputDimension], and of type float32.
      Parameters:
      name - the name for this output node.
      outputDimension - The second dimension of this output node.
      Returns:
      An ONNXPlaceholder instance representing this output node.
    • floatOutput

      public ONNXPlaceholder floatOutput(int outputDimension)
      Creates an output node for this ONNXContext, with the name "output", of dimension [batch_size, outputDimension], and of type float32.
      Parameters:
      outputDimension - The second dimension of this output node.
      Returns:
      An ONNXPlaceholder instance representing this output node.
    • floatTensor

      public ONNXInitializer floatTensor(String baseName, List<Integer> dims, Consumer<FloatBuffer> populate)
      Creates a float tensor for this ONNXContext, populated as described in ONNXUtils.floatTensorBuilder(ONNXContext, String, List, Consumer).
      Parameters:
      baseName - The name for this tensor in the ONNX graph.
      dims - The dimensions of this tensor.
      populate - A function populating the FloatBuffer that backs this tensor.
      Returns:
      An ONNXInitializer instance representing this tensor.
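The populate callback receives the FloatBuffer that backs the tensor and is expected to write the values into it. The JDK-only sketch below shows one way such a callback might look. The fill helper is a hypothetical stand-in for the library, and it is an assumption that the buffer is sized as the product of dims. Values are written in row-major order, which is the layout ONNX uses for tensor data.

```java
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

public class PopulateSketch {
    /** Hypothetical stand-in for the library: allocates a buffer sized from dims and runs the callback. */
    public static float[] fill(List<Integer> dims, Consumer<FloatBuffer> populate) {
        int size = 1;
        for (int d : dims) {
            size *= d;
        }
        FloatBuffer buffer = FloatBuffer.allocate(size);
        populate.accept(buffer);
        return buffer.array();
    }

    /** Builds a callback that writes a matrix into the buffer one row at a time. */
    public static Consumer<FloatBuffer> rowMajor(float[][] matrix) {
        return fb -> {
            for (float[] row : matrix) {
                fb.put(row); // relative bulk put, advances the buffer position
            }
        };
    }

    public static void main(String[] args) {
        float[][] weights = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        float[] flat = fill(List.of(2, 3), rowMajor(weights));
        System.out.println(Arrays.toString(flat)); // prints [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    }
}
```

With a real context, a callback like rowMajor(weights) would be the third argument to floatTensor, with dims given as List.of(2, 3).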
    • array

      public ONNXInitializer array(String baseName, long[] parameters)
      Creates a long tensor for this ONNXContext, populated according to parameters.
      Parameters:
      baseName - The name for this tensor in the ONNX graph.
      parameters - The long[] to populate the tensor.
      Returns:
      An ONNXInitializer instance representing this tensor.
    • array

      public ONNXInitializer array(String baseName, int[] parameters)
      Creates an int tensor for this ONNXContext, populated according to parameters.
      Parameters:
      baseName - The name for this tensor in the ONNX graph.
      parameters - The int[] to populate the tensor.
      Returns:
      An ONNXInitializer instance representing this tensor.
    • array

      public ONNXInitializer array(String baseName, float[] parameters)
      Creates a float tensor for this ONNXContext, populated according to parameters.
      Parameters:
      baseName - The name for this tensor in the ONNX graph.
      parameters - The float[] to populate the tensor.
      Returns:
      An ONNXInitializer instance representing this tensor.
    • array

      public ONNXInitializer array(String baseName, double[] parameters, boolean downcast)
      Creates a tensor for this ONNXContext, populated according to parameters.
      Parameters:
      baseName - The name for this tensor in the ONNX graph.
      parameters - The double[] to populate the tensor.
      downcast - Whether to downcast parameters to float32 in the ONNX graph.
      Returns:
      An ONNXInitializer instance representing this tensor.
    • array

      public ONNXInitializer array(String baseName, double[] parameters)
      Creates a float tensor for this ONNXContext, populated according to parameters.

      As with ONNXUtils.arrayBuilder(ONNXContext, String, double[], boolean), the doubles will be downcast to float32.

      Parameters:
      baseName - The name for this tensor in the ONNX graph.
      parameters - The double[] to populate the tensor.
      Returns:
      An ONNXInitializer instance representing this tensor.
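Downcasting to float32 narrows each double to the nearest representable float, which can lose precision or range. The JDK-only sketch below shows the effect with a plain (float) cast. It is an assumption that this matches the library's conversion exactly, and DowncastSketch is a hypothetical name.

```java
public class DowncastSketch {
    /** Narrows each double to the nearest float32 value. */
    public static float[] downcast(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }

    public static void main(String[] args) {
        // 0.5 is exact in float32, 0.1 is rounded,
        // 1e-50 underflows to zero, and 1e40 overflows to infinity.
        double[] weights = {0.5, 0.1, 1e-50, 1e40};
        float[] narrowed = downcast(weights);
        for (int i = 0; i < weights.length; i++) {
            System.out.println(weights[i] + " -> " + narrowed[i]);
        }
    }
}
```

If the stored weights need the full range or precision of a double, the overload that takes a downcast flag can be called with false.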
    • constant

      public ONNXInitializer constant(String baseName, float value)
      Creates a float scalar constant for this ONNXContext.
      Parameters:
      baseName - The name for this constant in the ONNX graph.
      value - The float to populate the constant.
      Returns:
      An ONNXInitializer instance representing this constant.
    • constant

      public ONNXInitializer constant(String baseName, long value)
      Creates a long scalar constant for this ONNXContext.
      Parameters:
      baseName - The name for this constant in the ONNX graph.
      value - The long to populate the constant.
      Returns:
      An ONNXInitializer instance representing this constant.
    • setName

      public void setName(String name)
      Sets the graph name.
      Parameters:
      name - The graph name.
    • buildGraph

      public OnnxMl.GraphProto buildGraph()
      Builds the ONNX graph represented by this context.
      Returns:
      The ONNX graph proto.
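Putting the methods above together, a minimal end-to-end sketch might create one context, declare an input and an output, connect them with assignTo, and build the graph. It assumes the Tribuo ONNX utility classes are on the classpath and that OnnxMl is importable from ai.onnx.proto. The class name IdentityGraphSketch and the graph name "identity-sketch" are hypothetical.

```java
import ai.onnx.proto.OnnxMl; // assumed package for the generated ONNX protobuf classes

import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXPlaceholder;

public class IdentityGraphSketch {
    /** Builds a graph that copies a [batch_size, 3] float input straight to the output. */
    public static OnnxMl.GraphProto build() {
        // Every node below comes from this one context; mixing contexts is not supported.
        ONNXContext context = new ONNXContext();
        context.setName("identity-sketch");

        ONNXPlaceholder input = context.floatInput(3);   // uses the default name "input"
        ONNXPlaceholder output = context.floatOutput(3); // uses the default name "output"

        // assignTo inserts an IDENTITY node connecting input to output.
        context.assignTo(input, output);
        return context.buildGraph();
    }

    public static void main(String[] args) {
        OnnxMl.GraphProto graph = build();
        System.out.println(graph.getName() + " has " + graph.getNodeCount() + " node(s)");
    }
}
```

A real export would place operator nodes between the input and the final assignTo, with weights supplied through floatTensor, array, or constant.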