001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.xgboost;
018
019import org.tribuo.Example;
020import org.tribuo.ImmutableOutputInfo;
021import org.tribuo.Output;
022import org.tribuo.Prediction;
023
024import java.io.Serializable;
025import java.util.List;
026
027/**
028 * Converts the output of XGBoost into the appropriate prediction type.
029 */
030public interface XGBoostOutputConverter<T extends Output<T>> extends Serializable {
031
032    /**
033     * Does this converter produce probabilities?
034     * @return True if it produces probabilities.
035     */
036    public boolean generatesProbabilities();
037
038    /**
039     * Converts a list of float arrays from XGBoost Boosters into a Tribuo {@link Prediction}.
040     * @param info The output info.
041     * @param probabilities The XGBoost output.
042     * @param numValidFeatures The number of valid features used in the prediction.
043     * @param example The example this prediction was generated from.
044     * @return The prediction object.
045     */
046    public Prediction<T> convertOutput(ImmutableOutputInfo<T> info, List<float[]> probabilities, int numValidFeatures, Example<T> example);
047
048    /**
049     * Converts a list of float arrays from XGBoost Boosters into a Tribuo {@link Prediction}.
050     * @param info The output info.
051     * @param probabilities The XGBoost output, list dimension is across models, first array dimension is across examples, second array dimension is across outputs.
052     * @param numValidFeatures The number of valid features used in each prediction.
053     * @param examples The examples these predictions were generated from.
054     * @return The prediction object.
055     */
056    public List<Prediction<T>> convertBatchOutput(ImmutableOutputInfo<T> info, List<float[][]> probabilities, int[] numValidFeatures, Example<T>[] examples);
057
058}