001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.onnx;
018
019import ai.onnxruntime.OnnxTensor;
020import ai.onnxruntime.OrtEnvironment;
021import ai.onnxruntime.OrtException;
022import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
024import org.tribuo.math.la.SparseVector;
025import org.tribuo.math.la.VectorTuple;
026
027import java.util.List;
028import java.util.logging.Logger;
029
030/**
031 * Converts a sparse Tribuo example into a dense float vector, then wraps it in an {@link OnnxTensor}.
032 */
033public class DenseTransformer implements ExampleTransformer {
034    private static final long serialVersionUID = 1L;
035    private static final Logger logger = Logger.getLogger(DenseTransformer.class.getName());
036
037    /**
038     * Feature size beyond which a warning is generated (as ONNX requires dense features and large feature spaces are memory hungry).
039     */
040    public static final int THRESHOLD = 1000000;
041
042    /**
043     * Number of times the feature size warning should be printed.
044     */
045    public static final int WARNING_THRESHOLD = 10;
046
047    private int warningCount = 0;
048
049    private float[] innerTransform(SparseVector vector) {
050        if ((warningCount < WARNING_THRESHOLD) && (vector.size() > THRESHOLD)) {
051            logger.warning("Large dense example requested, dimension = " + vector.size() + ", numActiveElements = " + vector.numActiveElements());
052            warningCount++;
053        }
054        float[] output = new float[vector.size()];
055
056        for (VectorTuple f : vector) {
057            output[f.index] = (float) f.value;
058        }
059
060        return output;
061    }
062
063    @Override
064    public OnnxTensor transform(OrtEnvironment env, SparseVector vector) throws OrtException {
065        float[][] output = new float[1][];
066        output[0] = innerTransform(vector);
067        return OnnxTensor.createTensor(env,output);
068    }
069
070    @Override
071    public OnnxTensor transform(OrtEnvironment env, List<SparseVector> vectors) throws OrtException {
072        float[][] output = new float[vectors.size()][];
073
074        int i = 0;
075        for (SparseVector vector : vectors) {
076            output[i] = innerTransform(vector);
077            i++;
078        }
079
080        return OnnxTensor.createTensor(env,output);
081    }
082
083    @Override
084    public String toString() {
085        return "DenseTransformer()";
086    }
087
088    @Override
089    public ConfiguredObjectProvenance getProvenance() {
090        return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer");
091    }
092}