001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.columnar.processors.field;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
022import org.tribuo.data.columnar.ColumnarFeature;
023import org.tribuo.data.columnar.FieldProcessor;
024
025import java.util.ArrayList;
026import java.util.EnumSet;
027import java.util.List;
028import java.util.logging.Logger;
029import java.util.regex.Matcher;
030import java.util.regex.Pattern;
031import java.util.stream.Collectors;
032
033/**
034 * A {@link FieldProcessor} which applies a regex to a field and generates {@link ColumnarFeature}s based on the matches.
035 */
036public class RegexFieldProcessor implements FieldProcessor {
037    private static final Logger logger = Logger.getLogger(RegexFieldProcessor.class.getName());
038
039    private Pattern regex;
040
041    @Config(mandatory = true,description="Regex to apply to the field.")
042    private String regexString;
043
044    @Config(mandatory = true,description="The field name to read.")
045    private String fieldName;
046
047    @Config(mandatory = true,description="Matching mode.")
048    private EnumSet<Mode> modes;
049
050    /**
051     * Matching mode.
052     */
053    public enum Mode {
054        MATCH_ALL,
055        MATCH_CONTAINS,
056        GROUPS
057    }
058
059    /**
060     * For olcut.
061     */
062    private RegexFieldProcessor() {}
063
064    /**
065     * Constructs a field processor which emits features when the field value matches the supplied regex.
066     * @param fieldName The field name to read.
067     * @param regex The regex to use for matching.
068     * @param modes The matching mode.
069     */
070    public RegexFieldProcessor(String fieldName, Pattern regex, EnumSet<Mode> modes) {
071        this.regex = regex;
072        this.fieldName = fieldName;
073        this.regexString = regex.pattern();
074        this.modes = modes;
075    }
076
077    /**
078     * Constructs a field processor which emits features when the field value matches the supplied regex.
079     * <p>
080     * The regex is compiled on construction.
081     * @param fieldName The field name to read.
082     * @param regex The regex to use for matching.
083     * @param modes The matching mode.
084     */
085    public RegexFieldProcessor(String fieldName, String regex, EnumSet<Mode> modes) {
086        this(fieldName,Pattern.compile(regex),modes);
087    }
088
089    /**
090     * Used by the OLCUT configuration system, and should not be called by external code.
091     */
092    @Override
093    public void postConfig() {
094        this.regex = Pattern.compile(regexString);
095    }
096
097    @Override
098    public String getFieldName() {
099        return fieldName;
100    }
101
102    @Override
103    public List<ColumnarFeature> process(String value) {
104        List<ColumnarFeature> features = new ArrayList<>();
105        Matcher m = regex.matcher(value);
106        for (Mode mode : modes) {
107            switch (mode) {
108                case MATCH_ALL:
109                    if (m.matches()) {
110                        features.add(new ColumnarFeature(fieldName,"MATCHES_ALL", 1.0));
111                    }
112                    break;
113                case MATCH_CONTAINS:
114                    if (m.find()) {
115                        features.add(new ColumnarFeature(fieldName,"CONTAINS_MATCH", 1.0));
116                    }
117                    break;
118                case GROUPS:
119                    int i = 0;
120                    while (m.find()) {
121                        i++;
122                        features.add(new ColumnarFeature(fieldName, "GROUPS(" + m.group(i) + ")", 1.0));
123                    }
124                    break;
125            }
126        }
127        return features;
128    }
129
130    @Override
131    public GeneratedFeatureType getFeatureType() {
132        return GeneratedFeatureType.CATEGORICAL;
133    }
134
135    @Override
136    public RegexFieldProcessor copy(String newFieldName) {
137        return new RegexFieldProcessor(newFieldName, regex, EnumSet.copyOf(modes));
138    }
139
140    @Override
141    public String toString() {
142        return "RegexFieldProcessor(fieldName="+getFieldName()+",modes=" + modes.stream().map(Mode::name).sorted().collect(Collectors.joining(":"))+')';
143    }
144
145    @Override
146    public ConfiguredObjectProvenance getProvenance() {
147        return new ConfiguredObjectProvenanceImpl(this,"FieldProcessor");
148    }
149}