001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.text; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 023 024import java.io.BufferedOutputStream; 025import java.io.BufferedReader; 026import java.io.FileInputStream; 027import java.io.FileOutputStream; 028import java.io.IOException; 029import java.io.InputStreamReader; 030import java.io.OutputStreamWriter; 031import java.io.PrintWriter; 032import java.io.UnsupportedEncodingException; 033import java.nio.charset.StandardCharsets; 034import java.nio.file.Path; 035import java.util.ArrayList; 036import java.util.Collections; 037import java.util.Random; 038import java.util.logging.Handler; 039import java.util.logging.Level; 040import java.util.logging.Logger; 041 042/** 043 * Splits data in our standard text format into training and testing portions. 044 * <p> 045 * Checks all the lines are valid before splitting. 046 */ 047public class SplitTextData { 048 private static final Logger logger = Logger.getLogger(SplitTextData.class.getName()); 049 050 public static class TrainTestSplitOptions implements Options { 051 @Override 052 public String getOptionsDescription() { 053 return "Splits a standard text format dataset in two."; 054 } 055 @Option(charName='s',longName="split-fraction",usage="Split fraction.") 056 public float splitFraction; 057 @Option(charName='i',longName="input-file",usage="Input data file in standard text format.") 058 public Path inputPath; 059 @Option(charName='t',longName="training-output-file",usage="Output training data file.") 060 public Path trainPath; 061 @Option(charName='v',longName="validation-output-file",usage="Output validation data file.") 062 public Path validationPath; 063 @Option(charName='r',longName="rng-seed",usage="Seed for the RNG.") 064 public long seed = 1; 065 } 066 067 public static void main(String[] args) throws IOException { 068 069 // 070 // Use the labs format logging. 071 for (Handler h : Logger.getLogger("").getHandlers()) { 072 h.setLevel(Level.ALL); 073 h.setFormatter(new LabsLogFormatter()); 074 try { 075 h.setEncoding("utf-8"); 076 } catch (SecurityException | UnsupportedEncodingException ex) { 077 logger.severe("Error setting output encoding"); 078 } 079 } 080 081 TrainTestSplitOptions options = new TrainTestSplitOptions(); 082 ConfigurationManager cm = new ConfigurationManager(args,options); 083 084 if ((options.inputPath == null) || (options.trainPath == null) || (options.validationPath == null) || (options.splitFraction < 0.0) || (options.splitFraction > 1.0)) { 085 System.out.println("Incorrect arguments"); 086 System.out.println(cm.usage()); 087 return; 088 } 089 090 int n = 0; 091 int validCounter = 0; 092 int invalidCounter = 0; 093 094 BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(options.inputPath.toFile()), StandardCharsets.UTF_8)); 095 096 PrintWriter trainOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.trainPath.toFile())),StandardCharsets.UTF_8)); 097 PrintWriter testOutput = new PrintWriter(new OutputStreamWriter(new BufferedOutputStream(new FileOutputStream(options.validationPath.toFile())),StandardCharsets.UTF_8)); 098 099 ArrayList<Line> lines = new ArrayList<>(); 100 while (input.ready()) { 101 n++; 102 String line = input.readLine().trim(); 103 if(line.isEmpty()) { 104 invalidCounter++; 105 continue; 106 } 107 String[] fields = line.split("##"); 108 if(fields.length != 2) { 109 invalidCounter++; 110 logger.warning(String.format("Bad line in %s at %d: %s", 111 options.inputPath, n, line.substring(Math.min(50, line.length())))); 112 continue; 113 } 114 String label = fields[0].trim().toUpperCase(); 115 lines.add(new Line(label,fields[1])); 116 validCounter++; 117 } 118 119 input.close(); 120 121 logger.info("Found " + validCounter + " valid examples, " + invalidCounter + " invalid examples out of " + n + " lines."); 122 123 int numTraining = Math.round(options.splitFraction * validCounter); 124 int numTesting = validCounter - numTraining; 125 126 logger.info("Outputting " + numTraining + " training examples, and " + numTesting + " testing examples, with a " + options.splitFraction + " split."); 127 128 Collections.shuffle(lines,new Random(options.seed)); 129 for (int i = 0; i < numTraining; i++) { 130 trainOutput.println(lines.get(i)); 131 } 132 for (int i = numTraining; i < validCounter; i++) { 133 testOutput.println(lines.get(i)); 134 } 135 136 trainOutput.close(); 137 testOutput.close(); 138 } 139 140 private static class Line { 141 public final String label; 142 public final String text; 143 144 Line(String label, String text) { 145 this.label = label; 146 this.text = text; 147 } 148 149 public String toString() { 150 return label + "##" + text; 151 } 152 } 153 154} 155