Package org.tribuo.util.onnx
Class ONNXUtils
java.lang.Object
org.tribuo.util.onnx.ONNXUtils
Helper functions for building ONNX protos.
-
Method Summary
| Modifier and Type | Method | Description |
|---|---|---|
| static OnnxMl.TensorProto | arrayBuilder(ONNXContext context, String name, double[] parameters) | Builds a TensorProto containing the array. |
| static OnnxMl.TensorProto | arrayBuilder(ONNXContext context, String name, double[] parameters, boolean downcast) | Builds a TensorProto containing the array. |
| static OnnxMl.TensorProto | arrayBuilder(ONNXContext context, String name, float[] parameters) | Builds a TensorProto containing the array. |
| static OnnxMl.TensorProto | arrayBuilder(ONNXContext context, String name, int[] parameters) | Builds a TensorProto containing the array. |
| static OnnxMl.TensorProto | arrayBuilder(ONNXContext context, String name, long[] parameters) | Builds a TensorProto containing the array. |
| static OnnxMl.TypeProto | buildTensorTypeNode(org.tribuo.util.onnx.ONNXShape shape, OnnxMl.TensorProto.DataType type) | Builds a type proto for the specified shape and tensor type. |
| static OnnxMl.TensorProto | doubleTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<DoubleBuffer> dataPopulator) | Generic method to create double OnnxMl.TensorProto instances. |
| static OnnxMl.TensorProto | floatTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<FloatBuffer> dataPopulator) | Generic method to create float OnnxMl.TensorProto instances. |
| static OnnxMl.TensorProto | scalarBuilder(ONNXContext context, String name, double value) | Builds a TensorProto containing the scalar value. |
| static OnnxMl.TensorProto | scalarBuilder(ONNXContext context, String name, float value) | Builds a TensorProto containing the scalar value. |
| static OnnxMl.TensorProto | scalarBuilder(ONNXContext context, String name, int value) | Builds a TensorProto containing the scalar value. |
| static OnnxMl.TensorProto | scalarBuilder(ONNXContext context, String name, long value) | Builds a TensorProto containing the scalar value. |
-
Method Details
-
buildTensorTypeNode
public static OnnxMl.TypeProto buildTensorTypeNode(org.tribuo.util.onnx.ONNXShape shape, OnnxMl.TensorProto.DataType type)
Builds a type proto for the specified shape and tensor type.
Parameters:
shape - The shape.
type - The tensor type.
Returns:
The type proto.
-
scalarBuilder
public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, int value)
Builds a TensorProto containing the scalar value.
Parameters:
context - The naming context.
name - The base name for the proto.
value - The value to store.
Returns:
A TensorProto containing the value as an int.
-
scalarBuilder
public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, long value)
Builds a TensorProto containing the scalar value.
Parameters:
context - The naming context.
name - The base name for the proto.
value - The value to store.
Returns:
A TensorProto containing the value as a long.
-
scalarBuilder
public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, float value)
Builds a TensorProto containing the scalar value.
Parameters:
context - The naming context.
name - The base name for the proto.
value - The value to store.
Returns:
A TensorProto containing the value as a float.
-
scalarBuilder
public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, double value)
Builds a TensorProto containing the scalar value.
Parameters:
context - The naming context.
name - The base name for the proto.
value - The value to store.
Returns:
A TensorProto containing the value as a double.
-
floatTensorBuilder
public static OnnxMl.TensorProto floatTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<FloatBuffer> dataPopulator)
Generic method to create float OnnxMl.TensorProto instances.
Parameters:
context - the naming context.
name - the base name for the proto.
dims - the dimensions of the input data.
dataPopulator - a method to populate a FloatBuffer that will be written into the TensorProto's rawData field.
Returns:
a float-typed TensorProto representation of the data.
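
The dataPopulator callback pattern can be illustrated without any ONNX classes. The following is a minimal sketch using only the JDK. The names FloatPopulatorSketch and rawFloatBytes are hypothetical, and this is not Tribuo's implementation. It assumes the buffer is sized from the product of dims and written in little-endian byte order, which is the order the ONNX format specifies for raw tensor data.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

final class FloatPopulatorSketch {

    // Hypothetical helper (not part of Tribuo): allocates a little-endian
    // buffer with one float slot per tensor element, lets the caller fill it,
    // and returns the raw bytes that a rawData field could carry.
    static byte[] rawFloatBytes(List<Integer> dims, Consumer<FloatBuffer> dataPopulator) {
        int numElements = 1;
        for (int d : dims) {
            numElements *= d;
        }
        ByteBuffer bytes = ByteBuffer.allocate(numElements * Float.BYTES)
                                     .order(ByteOrder.LITTLE_ENDIAN);
        // The float view shares storage with the byte buffer.
        FloatBuffer view = bytes.asFloatBuffer();
        dataPopulator.accept(view);
        return bytes.array();
    }

    public static void main(String[] args) {
        float[][] weights = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        // Write the 2x3 matrix row by row (row-major order).
        byte[] raw = rawFloatBytes(Arrays.asList(2, 3), fb -> {
            for (float[] row : weights) {
                fb.put(row);
            }
        });
        System.out.println(raw.length); // prints 24 (6 floats, 4 bytes each)
    }
}
```

Passing a callback rather than an array lets the caller write multi-dimensional data straight into the buffer without first flattening it into a temporary array.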
-
doubleTensorBuilder
public static OnnxMl.TensorProto doubleTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<DoubleBuffer> dataPopulator)
Generic method to create double OnnxMl.TensorProto instances.
Note that ONNX fp64 support is poor compared to fp32.
Parameters:
context - the naming context.
name - the base name for the proto.
dims - the dimensions of the input data.
dataPopulator - a method to populate a DoubleBuffer that will be written into the TensorProto's rawData field.
Returns:
a double-typed TensorProto representation of the data.
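
Because dims and the populator are supplied separately, a caller may want to confirm that the populator writes exactly as many values as dims implies. The sketch below shows one way to check this with plain JDK buffers. DoublePopulatorSketch and unfilledSlots are hypothetical names, and nothing here is taken from Tribuo's source.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

final class DoublePopulatorSketch {

    // Hypothetical helper (not part of Tribuo): runs the populator against a
    // buffer sized from dims and reports how many element slots it left empty.
    // A populator that writes too many values triggers BufferOverflowException.
    static int unfilledSlots(List<Integer> dims, Consumer<DoubleBuffer> dataPopulator) {
        int numElements = 1;
        for (int d : dims) {
            numElements *= d;
        }
        DoubleBuffer view = ByteBuffer.allocate(numElements * Double.BYTES)
                                      .order(ByteOrder.LITTLE_ENDIAN)
                                      .asDoubleBuffer();
        dataPopulator.accept(view);
        return view.remaining();
    }

    public static void main(String[] args) {
        double[][] matrix = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
        int missing = unfilledSlots(Arrays.asList(3, 2), db -> {
            for (double[] row : matrix) {
                db.put(row);
            }
        });
        System.out.println(missing); // prints 0: the 3x2 matrix fills all six slots
    }
}
```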
-
arrayBuilder
public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, float[] parameters)
Builds a TensorProto containing the array.
Parameters:
context - The naming context.
name - The base name for the proto.
parameters - The array to store in the proto.
Returns:
A TensorProto containing the array as floats.
-
arrayBuilder
public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters)
Builds a TensorProto containing the array.
Downcasts the doubles into floats as ONNX's fp64 support is poor compared to fp32.
Parameters:
context - The naming context.
name - The base name for the proto.
parameters - The array to store in the proto.
Returns:
A TensorProto containing the array as floats.
-
arrayBuilder
public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters, boolean downcast)
Builds a TensorProto containing the array.
Optionally downcasts the doubles into floats.
Parameters:
context - The naming context.
name - The base name for the proto.
parameters - The array to store in the proto.
downcast - If true, downcasts the doubles into floats.
Returns:
A TensorProto containing the array as either floats or doubles.
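
Downcasting trades precision and range for better fp32 support, so it helps to see what a narrowing cast does to individual values. The following is a small JDK-only sketch. DowncastSketch and downcast are hypothetical names used for illustration, and the loop is an assumption about what a downcast involves rather than a copy of Tribuo's code.

```java
final class DowncastSketch {

    // Hypothetical helper (not part of Tribuo): narrows each double to a float
    // using Java's standard narrowing primitive conversion.
    static float[] downcast(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }

    public static void main(String[] args) {
        float[] f = downcast(new double[] {0.5, 0.1, 1e40, 1e-50});
        System.out.println(f[0] == 0.5);  // true: 0.5 is exactly representable in fp32
        System.out.println(f[1] == 0.1);  // false: 0.1 is rounded to the nearest float
        System.out.println(f[2]);         // Infinity: beyond the fp32 range
        System.out.println(f[3]);         // 0.0: underflows to zero
    }
}
```

Values that must keep full precision or a wide dynamic range are the cases where passing downcast as false may be worth the weaker fp64 support.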
-
arrayBuilder
public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, int[] parameters)
Builds a TensorProto containing the array.
Parameters:
context - The naming context.
name - The base name for the proto.
parameters - The array to store in the proto.
Returns:
A TensorProto containing the array as ints.
-
arrayBuilder
public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, long[] parameters)
Builds a TensorProto containing the array.
Parameters:
context - The naming context.
name - The base name for the proto.
parameters - The array to store in the proto.
Returns:
A TensorProto containing the array as longs.
-