001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.columnar.processors.field; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 022import org.tribuo.Feature; 023import org.tribuo.data.columnar.ColumnarFeature; 024import org.tribuo.data.columnar.FieldProcessor; 025import org.tribuo.data.text.TextPipeline; 026 027import java.util.ArrayList; 028import java.util.Collections; 029import java.util.List; 030 031/** 032 * A {@link FieldProcessor} which takes a text field and runs a {@link TextPipeline} on it 033 * to generate features. 034 */ 035public class TextFieldProcessor implements FieldProcessor { 036 037 /** 038 * The name of the field that values will be drawn from. 039 */ 040 @Config(mandatory = true,description="The field name to read.") 041 private String fieldName; 042 043 @Config(mandatory = true,description="Text processing pipeline to use.") 044 private TextPipeline pipeline; 045 046 /** 047 * Constructs a field processor which uses the supplied text pipeline to process 048 * the field value. 049 * @param fieldName The field name to read. 050 * @param pipeline The text processing pipeline to use. 051 */ 052 public TextFieldProcessor(String fieldName, TextPipeline pipeline) { 053 this.fieldName = fieldName; 054 this.pipeline = pipeline; 055 } 056 057 /** 058 * For olcut. 059 */ 060 private TextFieldProcessor() {} 061 062 @Override 063 public String getFieldName() { 064 return fieldName; 065 } 066 067 @Override 068 public List<ColumnarFeature> process(String value) { 069 if ((value == null) || (value.isEmpty())) { 070 return Collections.emptyList(); 071 } else { 072 return wrapFeatures(fieldName,pipeline.process("",value)); 073 } 074 } 075 076 @Override 077 public GeneratedFeatureType getFeatureType() { 078 return GeneratedFeatureType.TEXT; 079 } 080 081 /** 082 * Note: the copy shares the text pipeline with the original. This may induce multithreading issues if 083 * the underlying pipeline is not thread safe. Tribuo builtin pipelines are thread safe. 084 * @param newFieldName The new field name for the copy. 085 * @return A copy of this TextFieldProcessor with the new field name. 086 */ 087 @Override 088 public TextFieldProcessor copy(String newFieldName) { 089 return new TextFieldProcessor(newFieldName,pipeline); 090 } 091 092 /** 093 * Convert the {@link Feature}s from a text pipeline into {@link ColumnarFeature}s with the right field name. 094 * @param fieldName The field name to prepend. 095 * @param inputFeatures The features to convert. 096 * @return A list of columnar features. 097 */ 098 public static List<ColumnarFeature> wrapFeatures(String fieldName, List<Feature> inputFeatures) { 099 if (inputFeatures.isEmpty()) { 100 return Collections.emptyList(); 101 } else { 102 List<ColumnarFeature> list = new ArrayList<>(); 103 104 for (Feature f : inputFeatures) { 105 ColumnarFeature newF = new ColumnarFeature(fieldName, f.getName(), f.getValue()); 106 list.add(newF); 107 } 108 109 return list; 110 } 111 } 112 113 @Override 114 public String toString() { 115 return "TextFieldProcessor(fieldName=" + getFieldName() + ",textPipeline="+pipeline.toString()+")"; 116 } 117 118 @Override 119 public ConfiguredObjectProvenance getProvenance() { 120 return new ConfiguredObjectProvenanceImpl(this,"FieldProcessor"); 121 } 122}