/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.multilabel;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableOutputInfo;
import org.tribuo.OutputFactory;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.multilabel.evaluation.MultiLabelEvaluation;
import org.tribuo.multilabel.evaluation.MultiLabelEvaluator;
import org.tribuo.provenance.OutputFactoryProvenance;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * The {@link OutputFactory} for {@link MultiLabel}: it produces MultiLabel outputs together with
 * the OutputInfo and Evaluator objects that go with them.
 */
public final class MultiLabelFactory implements OutputFactory<MultiLabel> {
    private static final long serialVersionUID = 1L;

    /**
     * The MultiLabel returned by {@link #getUnknownOutput()}, built from {@link LabelFactory#UNKNOWN_LABEL}.
     */
    public static final MultiLabel UNKNOWN_MULTILABEL = new MultiLabel(LabelFactory.UNKNOWN_LABEL);

    // One provenance and one evaluator instance are shared by every factory.
    private static final MultiLabelFactoryProvenance provenance = new MultiLabelFactoryProvenance();

    private static final MultiLabelEvaluator evaluator = new MultiLabelEvaluator();

    /**
     * Construct a MultiLabelFactory.
     */
    public MultiLabelFactory() {}

    /**
     * Builds a MultiLabel from an arbitrary input value.
     * <p>
     * If the input is a {@link Collection}, each element is converted with {@code toString} and
     * parsed by {@link MultiLabel#parseElement}, and the resulting pairs are combined through
     * {@link MultiLabel#createFromPairList}. Any other input is converted with {@code toString}
     * and handed to {@link MultiLabel#parseString}.
     * @param label An input value.
     * @param <V> The type of the input value.
     * @return A MultiLabel
     */
    @Override
    public <V> MultiLabel generateOutput(V label) {
        // Anything that is not a collection is parsed from its string form in one go.
        if (!(label instanceof Collection)) {
            return MultiLabel.parseString(label.toString());
        }

        List<Pair<String,Boolean>> parsedElements = new ArrayList<>();
        for (Object element : (Collection<?>) label) {
            parsedElements.add(MultiLabel.parseElement(element.toString()));
        }
        return MultiLabel.createFromPairList(parsedElements);
    }

    /**
     * Returns the shared unknown output, {@link #UNKNOWN_MULTILABEL}.
     * @return The unknown MultiLabel.
     */
    @Override
    public MultiLabel getUnknownOutput() {
        return UNKNOWN_MULTILABEL;
    }

    /**
     * Returns a newly constructed {@link MutableMultiLabelInfo}.
     * @return A new mutable output info.
     */
    @Override
    public MutableOutputInfo<MultiLabel> generateInfo() {
        return new MutableMultiLabelInfo();
    }

    /**
     * Builds an immutable output info from a user supplied mapping.
     * <p>
     * The mapping is first checked with {@link OutputFactory#validateMapping}, then every key is
     * observed by a new {@link MutableMultiLabelInfo}, which is combined with the mapping into an
     * {@link ImmutableMultiLabelInfo}.
     * @param mapping The MultiLabel to integer id mapping.
     * @return An immutable output info built from the mapping.
     */
    @Override
    public ImmutableOutputInfo<MultiLabel> constructInfoForExternalModel(Map<MultiLabel,Integer> mapping) {
        // Validate inputs are dense
        OutputFactory.validateMapping(mapping);

        MutableMultiLabelInfo observed = new MutableMultiLabelInfo();
        // Only the keys are needed here; the ids are passed through via the mapping itself.
        for (MultiLabel key : mapping.keySet()) {
            observed.observe(key);
        }

        return new ImmutableMultiLabelInfo(observed,mapping);
    }

    /**
     * Returns the shared {@link MultiLabelEvaluator}.
     * @return The evaluator.
     */
    @Override
    public Evaluator<MultiLabel, MultiLabelEvaluation> getEvaluator() {
        return evaluator;
    }

    /**
     * Returns the hash code of the string {@code "MultiLabelFactory"}, so it is the same for every instance.
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        return "MultiLabelFactory".hashCode();
    }

    /**
     * Any two MultiLabelFactory instances are considered equal.
     * @param obj The object to compare against.
     * @return True if {@code obj} is a MultiLabelFactory.
     */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof MultiLabelFactory;
    }

    /**
     * Returns the shared {@link MultiLabelFactoryProvenance}.
     * @return The provenance.
     */
    @Override
    public OutputFactoryProvenance getProvenance() {
        return provenance;
    }

    /**
     * Joins the names of the supplied labels, sorted by {@link String#compareTo}, into a single
     * comma separated string.
     * @param input A Set of Label objects.
     * @return A (possibly empty) comma separated string.
     * @throws IllegalStateException If a label name contains a comma.
     */
    public static String generateLabelString(Set<Label> input) {
        if (input.isEmpty()) {
            return "";
        }

        List<String> names = new ArrayList<>();
        for (Label label : input) {
            names.add(label.getLabel());
        }
        names.sort(String::compareTo);

        // Checked after sorting so the first offending name in sorted order is the one reported.
        for (String name : names) {
            if (name.contains(",")) {
                throw new IllegalStateException("MultiLabel cannot contain a label with a ',', found " + name + ".");
            }
        }
        return String.join(",", names);
    }

    /**
     * Provenance for {@link MultiLabelFactory}.
     */
    public static final class MultiLabelFactoryProvenance implements OutputFactoryProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Constructs a multi-label factory provenance.
         */
        MultiLabelFactoryProvenance() {}

        /**
         * Constructs a multi-label factory provenance from the empty marshalled form.
         * @param map An empty map.
         */
        public MultiLabelFactoryProvenance(Map<String, Provenance> map) { }

        /**
         * Returns the fully qualified class name of {@link MultiLabelFactory}.
         * @return The class name.
         */
        @Override
        public String getClassName() {
            return MultiLabelFactory.class.getName();
        }

        /**
         * Returns the result of {@code generateString("OutputFactory")}.
         * @return The string form of this provenance.
         */
        @Override
        public String toString() {
            return generateString("OutputFactory");
        }

        /**
         * Any two MultiLabelFactoryProvenance instances are considered equal.
         * @param other The object to compare against.
         * @return True if {@code other} is a MultiLabelFactoryProvenance.
         */
        @Override
        public boolean equals(Object other) {
            return other instanceof MultiLabelFactoryProvenance;
        }

        /**
         * Returns the constant 31, consistent with {@link #equals(Object)}.
         * @return The hash code.
         */
        @Override
        public int hashCode() {
            return 31;
        }
    }
}