001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.xgboost; 018 019import org.tribuo.Example; 020import org.tribuo.ImmutableOutputInfo; 021import org.tribuo.Prediction; 022import org.tribuo.common.xgboost.XGBoostOutputConverter; 023import org.tribuo.regression.Regressor; 024 025import java.util.ArrayList; 026import java.util.List; 027 028/** 029 * Converts XGBoost outputs into {@link Regressor} {@link Prediction}s. 
* <p>
 * Each element of the list passed to the conversion methods is treated as the raw
 * XGBoost output for a single regression dimension. The position of an element in that
 * list is used as the output id when looking up the dimension name in the supplied
 * {@link ImmutableOutputInfo}. Only the first column of each prediction row is read.
 * <p>
 * This converter never reports its outputs as probabilities.
 */
public final class XGBoostRegressionConverter implements XGBoostOutputConverter<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Regression outputs are real values, not probability distributions.
     *
     * @return {@code false}, always.
     */
    @Override
    public boolean generatesProbabilities() {
        return false;
    }

    /**
     * Converts the XGBoost outputs for a single example into a regression prediction.
     *
     * @param info The output info, used to map each dimension index to its dimension name.
     * @param probabilities One array per regression dimension, in dimension-id order; only
     *                      element {@code 0} of each array is read.
     * @param numValidFeatures The number of valid features, passed through to the {@link Prediction}.
     * @param example The example that was predicted.
     * @return A prediction whose {@link Regressor} has one dimension per element of {@code probabilities}.
     */
    @Override
    public Prediction<Regressor> convertOutput(ImmutableOutputInfo<Regressor> info, List<float[]> probabilities, int numValidFeatures, Example<Regressor> example) {
        Regressor.DimensionTuple[] tuples = new Regressor.DimensionTuple[probabilities.size()];
        int dim = 0;
        for (float[] dimOutput : probabilities) {
            tuples[dim] = new Regressor.DimensionTuple(info.getOutput(dim).getNames()[0], dimOutput[0]);
            dim++;
        }
        return new Prediction<>(new Regressor(tuples), numValidFeatures, example);
    }

    /**
     * Converts the XGBoost outputs for a batch of examples into regression predictions.
     *
     * @param info The output info, used to map each dimension index to its dimension name.
     * @param probabilities One array per regression dimension, in dimension-id order. Every array
     *                      must contain one row per example; only element {@code 0} of each row is read.
     * @param numValidFeatures The number of valid features for each example, passed through to
     *                         the corresponding {@link Prediction}.
     * @param examples The examples that were predicted.
     * @return One prediction per example, in the same order as {@code examples}.
     * @throws IllegalArgumentException If {@code probabilities} is empty, or if
     *                                  {@code numValidFeatures.length}, {@code examples.length} and the
     *                                  row count of every element of {@code probabilities} do not all agree.
     */
    @Override
    public List<Prediction<Regressor>> convertBatchOutput(ImmutableOutputInfo<Regressor> info, List<float[][]> probabilities, int[] numValidFeatures, Example<Regressor>[] examples) {
        // Guard before probabilities.get(0) so an empty input fails with a descriptive message.
        if (probabilities.isEmpty()) {
            throw new IllegalArgumentException("No XGBoost outputs supplied, expected one per regression dimension.");
        }
        if ((numValidFeatures.length != examples.length) || (probabilities.get(0).length != numValidFeatures.length)) {
            throw new IllegalArgumentException("Lengths not the same, numValidFeatures.length = "
                    + numValidFeatures.length + ", examples.length = " + examples.length
                    + ", probabilities.get(0).length = " + probabilities.get(0).length);
        }
        final int numExamples = numValidFeatures.length;
        final int numDimensions = probabilities.size();
        // Indexed [example][dimension]; filled one dimension (list element) at a time.
        Regressor.DimensionTuple[][] tuples = new Regressor.DimensionTuple[numExamples][numDimensions];
        int dim = 0;
        for (float[][] dimOutput : probabilities) {
            // Every dimension must cover the whole batch, not just the first one checked above.
            if (dimOutput.length != numExamples) {
                throw new IllegalArgumentException("Lengths not the same, numValidFeatures.length = "
                        + numExamples + ", probabilities.get(" + dim + ").length = " + dimOutput.length);
            }
            String dimName = info.getOutput(dim).getNames()[0];
            for (int j = 0; j < numExamples; j++) {
                tuples[j][dim] = new Regressor.DimensionTuple(dimName, dimOutput[j][0]);
            }
            dim++;
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(numExamples);
        for (int j = 0; j < numExamples; j++) {
            predictions.add(new Prediction<>(new Regressor(tuples[j]), numValidFeatures[j], examples[j]));
        }
        return predictions;
    }

}