001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.text;
018
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.Logger;
041
042/**
043 * Splits data in our standard text format into training and testing portions.
044 * <p>
045 * Checks all the lines are valid before splitting.
046 */
047public class SplitTextData {
048    private static final Logger logger = Logger.getLogger(SplitTextData.class.getName());
049
050    public static class TrainTestSplitOptions implements Options {
051        @Override
052        public String getOptionsDescription() {
053            return "Splits a standard text format dataset in two.";
054        }
055        @Option(charName='s',longName="split-fraction",usage="Split fraction.")
056        public float splitFraction;
057        @Option(charName='i',longName="input-file",usage="Input data file in standard text format.")
058        public Path inputPath;
059        @Option(charName='t',longName="training-output-file",usage="Output training data file.")
060        public Path trainPath;
061        @Option(charName='v',longName="validation-output-file",usage="Output validation data file.")
062        public Path validationPath;
063        @Option(charName='r',longName="rng-seed",usage="Seed for the RNG.")
064        public long seed = 1;
065    }
066
067    public static void main(String[] args) throws IOException {
068        
069        //
070        // Use the labs format logging.
071        for (Handler h : Logger.getLogger("").getHandlers()) {
072            h.setLevel(Level.ALL);
073            h.setFormatter(new LabsLogFormatter());
074            try {
075                h.setEncoding("utf-8");
076            } catch (SecurityException | UnsupportedEncodingException ex) {
077                logger.severe("Error setting output encoding");
078            }
079        }
080
081        TrainTestSplitOptions options = new TrainTestSplitOptions();
082        ConfigurationManager cm = new ConfigurationManager(args,options);
083
084        if ((options.inputPath == null) || (options.trainPath == null) || (options.validationPath == null) || (options.splitFraction < 0.0) || (options.splitFraction > 1.0)) {
085            System.out.println("Incorrect arguments");
086            System.out.println(cm.usage());
087            return;
088        }
089
090        int n = 0;
091        int validCounter = 0;
092        int invalidCounter = 0;
093
094        BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(options.inputPath.toFile()), StandardCharsets.UTF_8));
095        
096        PrintWriter trainOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.trainPath.toFile())),StandardCharsets.UTF_8));
097        PrintWriter testOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.validationPath.toFile())),StandardCharsets.UTF_8));
098
099        ArrayList<Line> lines = new ArrayList<>();
100        while (input.ready()) {
101            n++;
102            String line = input.readLine().trim();
103            if(line.isEmpty()) {
104                invalidCounter++;
105                continue;
106            }
107            String[] fields = line.split("##");
108            if(fields.length != 2) {
109                invalidCounter++;
110                logger.warning(String.format("Bad line in %s at %d: %s",
111                        options.inputPath, n, line.substring(Math.min(50, line.length()))));
112                continue;
113            }
114            String label = fields[0].trim().toUpperCase();
115            lines.add(new Line(label,fields[1]));
116            validCounter++;
117        }
118
119        input.close();
120
121        logger.info("Found " + validCounter + " valid examples, " + invalidCounter + " invalid examples out of " + n + " lines.");
122
123        int numTraining = Math.round(options.splitFraction * validCounter);
124        int numTesting = validCounter - numTraining;
125
126        logger.info("Outputting " + numTraining + " training examples, and " + numTesting + " testing examples, with a " + options.splitFraction + " split.");
127
128        Collections.shuffle(lines,new Random(options.seed));
129        for (int i = 0; i < numTraining; i++) {
130            trainOutput.println(lines.get(i));
131        }
132        for (int i = numTraining; i < validCounter; i++) {
133            testOutput.println(lines.get(i));
134        }
135
136        trainOutput.close();
137        testOutput.close();
138    }
139
140    private static class Line {
141        public final String label;
142        public final String text;
143
144        Line(String label, String text) {
145            this.label = label;
146            this.text = text;
147        }
148
149        public String toString() {
150            return label + "##" + text;
151        }
152    }
153
154}
155