001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence; 018 019import org.tribuo.ImmutableFeatureMap; 020import org.tribuo.ImmutableOutputInfo; 021import org.tribuo.Prediction; 022import org.tribuo.classification.Label; 023import org.tribuo.provenance.ModelProvenance; 024import org.tribuo.sequence.SequenceExample; 025import org.tribuo.sequence.SequenceModel; 026 027import java.io.Serializable; 028import java.util.ArrayList; 029import java.util.List; 030 031/** 032 * A Sequence model which can provide confidence predictions for subsequence predictions. 033 * <p> 034 * Used to provide confidence scores on a per subsequence level. 035 * <p> 036 * The exemplar of this is providing a confidence score for each Named Entity present 037 * in a SequenceExample. 038 */ 039public abstract class ConfidencePredictingSequenceModel extends SequenceModel<Label> { 040 private static final long serialVersionUID = 1L; 041 042 protected ConfidencePredictingSequenceModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) { 043 super(name,description,featureIDMap,labelIDMap); 044 } 045 046 /** 047 * The scoring function for the subsequences. Provides the scores which should be assigned to each subsequence. 048 * @param example The input sequence example. 049 * @param predictions The predictions produced by this model. 050 * @param subsequences The subsequences to score. 051 * @param <SUB> The subsequence type. 052 * @return The scores for the subsequences. 053 */ 054 public abstract <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences); 055 056 /** 057 * A scoring method which multiplies together the per prediction scores. 058 * @param predictions The element level predictions. 059 * @param subsequences The subsequences denoting prediction boundaries. 060 * @param <SUB> The subsequence type. 061 * @return A list of scores for each subsequence. 062 */ 063 public static <SUB extends Subsequence> List<Double> multiplyWeights(List<Prediction<Label>> predictions, List<SUB> subsequences) { 064 List<Double> scores = new ArrayList<>(subsequences.size()); 065 for(Subsequence subsequence : subsequences) { 066 scores.add(multiplyWeights(predictions, subsequence)); 067 } 068 return scores; 069 } 070 071 private static <SUB extends Subsequence> Double multiplyWeights(List<Prediction<Label>> predictions, SUB subsequence) { 072 double counter = 1.0; 073 for (int i=subsequence.begin; i<subsequence.end; i++) { 074 counter *= predictions.get(i).getOutput().getScore(); 075 } 076 return counter; 077 } 078 079 /** 080 * A range class used to define a subsequence of a SequenceExample. 081 */ 082 public static class Subsequence implements Serializable { 083 private static final long serialVersionUID = 1L; 084 public final int begin; 085 public final int end; 086 087 /** 088 * Constructs a subsequence for the fixed range, exclusive of the end. 089 * @param begin The start element. 090 * @param end The end element. 091 */ 092 public Subsequence(int begin, int end) { 093 this.begin = begin; 094 this.end = end; 095 } 096 097 /** 098 * Returns the number of elements in this subsequence. 099 * @return The length of the subsequence. 100 */ 101 public int length() { 102 return end - begin; 103 } 104 105 @Override 106 public String toString() { 107 return "("+begin+","+end+")"; 108 } 109 } 110 111}