001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.interop.tensorflow; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021import org.tribuo.Example; 022import org.tribuo.Feature; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.Output; 025import org.tribuo.math.la.SparseVector; 026import org.tribuo.math.la.VectorTuple; 027import org.tensorflow.Tensor; 028 029import java.util.List; 030import java.util.logging.Logger; 031 032/** 033 * Converts a sparse example into a dense float vector, then wraps it in a {@link Tensor}. 034 */ 035public class DenseTransformer<T extends Output<T>> implements ExampleTransformer<T> { 036 private static final long serialVersionUID = 1L; 037 private static final Logger logger = Logger.getLogger(DenseTransformer.class.getName()); 038 039 /** 040 * Feature size beyond which a warning is generated (as ONNX requires dense features and large feature spaces are memory hungry). 041 */ 042 public static final int THRESHOLD = 1000000; 043 044 /** 045 * Number of times the feature size warning should be printed. 046 */ 047 public static final int WARNING_THRESHOLD = 10; 048 049 private int warningCount = 0; 050 051 public DenseTransformer() { } 052 053 float[] innerTransform(Example<T> example, ImmutableFeatureMap featureIDMap) { 054 if ((warningCount < WARNING_THRESHOLD) && (featureIDMap.size() > THRESHOLD)) { 055 logger.warning("Large dense example requested, featureIDMap.size() = " + featureIDMap.size() + ", example.size() = " + example.size()); 056 warningCount++; 057 } 058 float[] output = new float[featureIDMap.size()]; 059 060 for (Feature f : example) { 061 int id = featureIDMap.getID(f.getName()); 062 if (id > -1) { 063 output[id] = (float) f.getValue(); 064 } 065 } 066 067 return output; 068 } 069 070 float[] innerTransform(SparseVector vector) { 071 if ((warningCount < WARNING_THRESHOLD) && (vector.size() > THRESHOLD)) { 072 logger.warning("Large dense example requested, dimension = " + vector.size() + ", numActiveElements = " + vector.numActiveElements()); 073 warningCount++; 074 } 075 float[] output = new float[vector.size()]; 076 077 for (VectorTuple f : vector) { 078 output[f.index] = (float) f.value; 079 } 080 081 return output; 082 } 083 084 @Override 085 public Tensor<?> transform(Example<T> example, ImmutableFeatureMap featureIDMap) { 086 float[][] output = new float[1][]; 087 output[0] = innerTransform(example,featureIDMap); 088 return Tensor.create(output); 089 } 090 091 @Override 092 public Tensor<?> transform(List<Example<T>> examples, ImmutableFeatureMap featureIDMap) { 093 float[][] output = new float[examples.size()][]; 094 095 int i = 0; 096 for (Example<T> example : examples) { 097 output[i] = innerTransform(example,featureIDMap); 098 i++; 099 } 100 101 return Tensor.create(output); 102 } 103 104 @Override 105 public Tensor<?> transform(SparseVector vector) { 106 float[][] output = new float[1][]; 107 output[0] = innerTransform(vector); 108 return Tensor.create(output); 109 } 110 111 @Override 112 public Tensor<?> transform(List<SparseVector> vectors) { 113 float[][] output = new float[vectors.size()][]; 114 115 int i = 0; 116 for (SparseVector vector : vectors) { 117 output[i] = innerTransform(vector); 118 i++; 119 } 120 121 return Tensor.create(output); 122 } 123 124 @Override 125 public String toString() { 126 return "DenseTransformer()"; 127 } 128 129 @Override 130 public ConfiguredObjectProvenance getProvenance() { 131 return new ConfiguredObjectProvenanceImpl(this,"ExampleTransformer"); 132 } 133}