001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence;
018
019import org.tribuo.ImmutableFeatureMap;
020import org.tribuo.ImmutableOutputInfo;
021import org.tribuo.Prediction;
022import org.tribuo.classification.Label;
023import org.tribuo.provenance.ModelProvenance;
024import org.tribuo.sequence.SequenceExample;
025import org.tribuo.sequence.SequenceModel;
026
027import java.io.Serializable;
028import java.util.ArrayList;
029import java.util.List;
030
031/**
032 * A Sequence model which can provide confidence predictions for subsequence predictions.
033 * <p>
034 * Used to provide confidence scores on a per subsequence level.
035 * <p>
036 * The exemplar of this is providing a confidence score for each Named Entity present
037 * in a SequenceExample.
038 */
039public abstract class ConfidencePredictingSequenceModel extends SequenceModel<Label> {
040    private static final long serialVersionUID = 1L;
041
042    protected ConfidencePredictingSequenceModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
043        super(name,description,featureIDMap,labelIDMap);
044    }
045
046    /**
047     * The scoring function for the subsequences. Provides the scores which should be assigned to each subsequence.
048     * @param example The input sequence example.
049     * @param predictions The predictions produced by this model.
050     * @param subsequences The subsequences to score.
051     * @param <SUB> The subsequence type.
052     * @return The scores for the subsequences.
053     */
054    public abstract <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences);
055
056    /**
057     * A scoring method which multiplies together the per prediction scores.
058     * @param predictions The element level predictions.
059     * @param subsequences The subsequences denoting prediction boundaries.
060     * @param <SUB> The subsequence type.
061     * @return A list of scores for each subsequence.
062     */
063    public static <SUB extends Subsequence> List<Double> multiplyWeights(List<Prediction<Label>> predictions, List<SUB> subsequences) {
064        List<Double> scores = new ArrayList<>(subsequences.size());
065        for(Subsequence subsequence : subsequences) {
066            scores.add(multiplyWeights(predictions, subsequence));
067        }
068        return scores;
069    }
070
071    private static <SUB extends Subsequence> Double multiplyWeights(List<Prediction<Label>> predictions, SUB subsequence) {
072        double counter = 1.0;
073        for (int i=subsequence.begin; i<subsequence.end; i++) {
074            counter *= predictions.get(i).getOutput().getScore();
075        }
076        return counter;
077    }
078
079    /**
080     * A range class used to define a subsequence of a SequenceExample.
081     */
082    public static class Subsequence implements Serializable {
083        private static final long serialVersionUID = 1L;
084        public final int begin;
085        public final int end;
086
087        /**
088         * Constructs a subsequence for the fixed range, exclusive of the end.
089         * @param begin The start element.
090         * @param end The end element.
091         */
092        public Subsequence(int begin, int end) {
093            this.begin = begin;
094            this.end = end;
095        }
096
097        /**
098         * Returns the number of elements in this subsequence.
099         * @return The length of the subsequence.
100         */
101        public int length() {
102            return end - begin;
103        }
104
105        @Override
106        public String toString() {
107            return "("+begin+","+end+")";
108        }
109    }
110
111}