001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.xgboost;
018
019import org.tribuo.Example;
020import org.tribuo.ImmutableOutputInfo;
021import org.tribuo.Prediction;
022import org.tribuo.classification.Label;
023import org.tribuo.common.xgboost.XGBoostOutputConverter;
024
025import java.util.ArrayList;
026import java.util.LinkedHashMap;
027import java.util.List;
028
029/**
030 * Converts XGBoost outputs into {@link Label} {@link Prediction}s.
031 */
032public final class XGBoostClassificationConverter implements XGBoostOutputConverter<Label> {
033    private static final long serialVersionUID = 1L;
034
035    public XGBoostClassificationConverter() {}
036
037    @Override
038    public boolean generatesProbabilities() {
039        return true;
040    }
041
042    @Override
043    public Prediction<Label> convertOutput(ImmutableOutputInfo<Label> info, List<float[]> probabilitiesList, int numValidFeatures, Example<Label> example) {
044        if (probabilitiesList.size() != 1) {
045            throw new IllegalArgumentException("XGBoostClassificationConverter only expects a single model output.");
046        }
047        double maxScore = Double.NEGATIVE_INFINITY;
048        Label maxLabel = null;
049        LinkedHashMap<String,Label> probMap = new LinkedHashMap<>();
050        float[] probabilities = probabilitiesList.get(0);
051
052        for (int i = 0; i < probabilities.length; i++) {
053            String name = info.getOutput(i).getLabel();
054            Label label = new Label(name, probabilities[i]);
055            probMap.put(name, label);
056            if (label.getScore() > maxScore) {
057                maxScore = label.getScore();
058                maxLabel = label;
059            }
060        }
061
062        return new Prediction<>(maxLabel,probMap,numValidFeatures,example,true);
063    }
064
065    @Override
066    public List<Prediction<Label>> convertBatchOutput(ImmutableOutputInfo<Label> info, List<float[][]> probabilitiesList, int[] numValidFeatures, Example<Label>[] examples) {
067        if (probabilitiesList.size() != 1) {
068            throw new IllegalArgumentException("XGBoostClassificationConverter only expects a single model output.");
069        }
070        float[][] probabilities = probabilitiesList.get(0);
071
072        List<Prediction<Label>> predictions = new ArrayList<>();
073        for (int i = 0; i < probabilities.length; i++) {
074            double maxScore = Double.NEGATIVE_INFINITY;
075            Label maxLabel = null;
076            LinkedHashMap<String, Label> probMap = new LinkedHashMap<>();
077            for (int j = 0; j < probabilities[i].length; j++) {
078                String name = info.getOutput(j).getLabel();
079                Label label = new Label(name, probabilities[i][j]);
080                probMap.put(name, label);
081                if (label.getScore() > maxScore) {
082                    maxScore = label.getScore();
083                    maxLabel = label;
084                }
085            }
086
087            predictions.add(new Prediction<>(maxLabel, probMap, numValidFeatures[i], examples[i], true));
088        }
089
090        return predictions;
091    }
092}