/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.xgboost;

import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.xgboost.XGBoostOutputConverter;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Converts XGBoost outputs into {@link Label} {@link Prediction}s.
 * <p>
 * Each XGBoost score vector is read as one score per label, where position {@code i}
 * in the vector is the label with id {@code i} in the supplied {@link ImmutableOutputInfo}.
 * The predicted label is the one with the highest score, and every label's score is
 * recorded in the prediction's score map.
 */
public final class XGBoostClassificationConverter implements XGBoostOutputConverter<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Constructs a classification output converter.
     */
    public XGBoostClassificationConverter() {}

    /**
     * {@inheritDoc}
     * <p>
     * Always returns {@code true}; the produced predictions are flagged as probabilistic.
     */
    @Override
    public boolean generatesProbabilities() {
        return true;
    }

    /**
     * Converts the score vector for a single example into a prediction.
     *
     * @param info The output info used to map score indices to labels.
     * @param probabilitiesList The model outputs; must contain exactly one score vector.
     * @param numValidFeatures The number of features used to make this prediction.
     * @param example The example that was scored.
     * @return The prediction, whose output is the highest scoring label.
     * @throws IllegalArgumentException If {@code probabilitiesList} does not contain exactly one
     *                                  element, or the score vector is empty.
     */
    @Override
    public Prediction<Label> convertOutput(ImmutableOutputInfo<Label> info, List<float[]> probabilitiesList, int numValidFeatures, Example<Label> example) {
        checkSingleOutput(probabilitiesList);
        return buildPrediction(info, probabilitiesList.get(0), numValidFeatures, example);
    }

    /**
     * Converts a batch of score vectors into predictions, one per example.
     *
     * @param info The output info used to map score indices to labels.
     * @param probabilitiesList The model outputs; must contain exactly one score matrix with
     *                          one row per example.
     * @param numValidFeatures The number of features used for each example.
     * @param examples The examples that were scored.
     * @return The predictions, in the same order as {@code examples}.
     * @throws IllegalArgumentException If {@code probabilitiesList} does not contain exactly one
     *                                  element, if the number of score rows, feature counts and
     *                                  examples differ, or if any score vector is empty.
     */
    @Override
    public List<Prediction<Label>> convertBatchOutput(ImmutableOutputInfo<Label> info, List<float[][]> probabilitiesList, int[] numValidFeatures, Example<Label>[] examples) {
        checkSingleOutput(probabilitiesList);
        float[][] probabilities = probabilitiesList.get(0);

        // Without this check a longer score matrix throws ArrayIndexOutOfBoundsException,
        // and a shorter one silently drops predictions for trailing examples.
        if (probabilities.length != examples.length || numValidFeatures.length != examples.length) {
            throw new IllegalArgumentException("Expected one score vector and one feature count per example, found "
                    + probabilities.length + " score vectors, " + numValidFeatures.length
                    + " feature counts and " + examples.length + " examples.");
        }

        List<Prediction<Label>> predictions = new ArrayList<>(probabilities.length);
        for (int i = 0; i < probabilities.length; i++) {
            predictions.add(buildPrediction(info, probabilities[i], numValidFeatures[i], examples[i]));
        }
        return predictions;
    }

    /**
     * Checks that the model produced exactly one output.
     *
     * @param outputs The model outputs.
     * @throws IllegalArgumentException If there is not exactly one output.
     */
    private static void checkSingleOutput(List<?> outputs) {
        if (outputs.size() != 1) {
            throw new IllegalArgumentException("XGBoostClassificationConverter only expects a single model output.");
        }
    }

    /**
     * Builds a label prediction from one score vector.
     * <p>
     * Ties go to the lowest label id, because only a strictly greater score replaces the
     * current best. NaN scores never compare greater, so they are never selected.
     *
     * @param info The output info used to map score indices to labels.
     * @param scores The per-label scores, indexed by label id.
     * @param numValidFeatures The number of features used to make this prediction.
     * @param example The example that was scored.
     * @return The prediction.
     * @throws IllegalArgumentException If {@code scores} is empty.
     */
    private static Prediction<Label> buildPrediction(ImmutableOutputInfo<Label> info, float[] scores, int numValidFeatures, Example<Label> example) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("XGBoost returned an empty score vector, cannot select a label.");
        }
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        // Insertion order follows label id order.
        Map<String, Label> probMap = new LinkedHashMap<>();
        for (int i = 0; i < scores.length; i++) {
            String name = info.getOutput(i).getLabel();
            Label label = new Label(name, scores[i]);
            probMap.put(name, label);
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }
        return new Prediction<>(maxLabel, probMap, numValidFeatures, example, true);
    }
}