001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.interop.tensorflow;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.Feature;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.Output;
025import org.tribuo.math.la.SparseVector;
026import org.tribuo.math.la.VectorTuple;
027import org.tensorflow.Tensor;
028
029import java.util.List;
030import java.util.logging.Logger;
031
032/**
033 * Converts a sparse example into a dense float vector, then wraps it in a {@link Tensor}.
034 */
035public class DenseTransformer<T extends Output<T>> implements ExampleTransformer<T> {
036    private static final long serialVersionUID = 1L;
037    private static final Logger logger = Logger.getLogger(DenseTransformer.class.getName());
038
039    /**
040     * Feature size beyond which a warning is generated (as ONNX requires dense features and large feature spaces are memory hungry).
041     */
042    public static final int THRESHOLD = 1000000;
043
044    /**
045     * Number of times the feature size warning should be printed.
046     */
047    public static final int WARNING_THRESHOLD = 10;
048
049    private int warningCount = 0;
050
051    public DenseTransformer() { }
052
053    float[] innerTransform(Example<T> example, ImmutableFeatureMap featureIDMap) {
054        if ((warningCount < WARNING_THRESHOLD) && (featureIDMap.size() > THRESHOLD)) {
055            logger.warning("Large dense example requested, featureIDMap.size() = " + featureIDMap.size() + ", example.size() = " + example.size());
056            warningCount++;
057        }
058        float[] output = new float[featureIDMap.size()];
059
060        for (Feature f : example) {
061            int id = featureIDMap.getID(f.getName());
062            if (id > -1) {
063                output[id] = (float) f.getValue();
064            }
065        }
066
067        return output;
068    }
069
070    float[] innerTransform(SparseVector vector) {
071        if ((warningCount < WARNING_THRESHOLD) && (vector.size() > THRESHOLD)) {
072            logger.warning("Large dense example requested, dimension = " + vector.size() + ", numActiveElements = " + vector.numActiveElements());
073            warningCount++;
074        }
075        float[] output = new float[vector.size()];
076
077        for (VectorTuple f : vector) {
078            output[f.index] = (float) f.value;
079        }
080
081        return output;
082    }
083
084    @Override
085    public Tensor<?> transform(Example<T> example, ImmutableFeatureMap featureIDMap) {
086        float[][] output = new float[1][];
087        output[0] = innerTransform(example,featureIDMap);
088        return Tensor.create(output);
089    }
090
091    @Override
092    public Tensor<?> transform(List<Example<T>> examples, ImmutableFeatureMap featureIDMap) {
093        float[][] output = new float[examples.size()][];
094
095        int i = 0;
096        for (Example<T> example : examples) {
097            output[i] = innerTransform(example,featureIDMap);
098            i++;
099        }
100
101        return Tensor.create(output);
102    }
103
104    @Override
105    public Tensor<?> transform(SparseVector vector) {
106        float[][] output = new float[1][];
107        output[0] = innerTransform(vector);
108        return Tensor.create(output);
109    }
110
111    @Override
112    public Tensor<?> transform(List<SparseVector> vectors) {
113        float[][] output = new float[vectors.size()][];
114
115        int i = 0;
116        for (SparseVector vector : vectors) {
117            output[i] = innerTransform(vector);
118            i++;
119        }
120
121        return Tensor.create(output);
122    }
123
124    @Override
125    public String toString() {
126        return "DenseTransformer()";
127    }
128
129    @Override
130    public ConfiguredObjectProvenance getProvenance() {
131        return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer");
132    }
133}