001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sgd; 018 019import org.tribuo.math.la.SparseVector; 020 021import java.util.SplittableRandom; 022 023/** 024 * SGD utilities. Currently stores methods for shuffling examples and their associated labels and weights. 025 */ 026public class Util { 027 /** 028 * In place shuffle of the features, labels and weights. 029 * @param features Input features. 030 * @param labels Input labels. 031 * @param weights Input weights. 032 * @param rng SplittableRandom number generator. 033 */ 034 public static void shuffleInPlace(SparseVector[] features, int[] labels, double[] weights, SplittableRandom rng) { 035 int size = features.length; 036 // Shuffle array 037 for (int i = size; i > 1; i--) { 038 int j = rng.nextInt(i); 039 //swap features 040 SparseVector tmpFeature = features[i-1]; 041 features[i-1] = features[j]; 042 features[j] = tmpFeature; 043 //swap labels 044 int tmpLabel = labels[i-1]; 045 labels[i-1] = labels[j]; 046 labels[j] = tmpLabel; 047 //swap weights 048 double tmpWeight = weights[i-1]; 049 weights[i-1] = weights[j]; 050 weights[j] = tmpWeight; 051 } 052 } 053 054 /** 055 * In place shuffle of the features, labels, weights and indices. 056 * @param features Input features. 057 * @param labels Input labels. 058 * @param weights Input weights. 059 * @param indices Input indices. 060 * @param rng SplittableRandom number generator. 061 */ 062 public static void shuffleInPlace(SparseVector[] features, int[] labels, double[] weights, int[] indices, SplittableRandom rng) { 063 int size = features.length; 064 // Shuffle array 065 for (int i = size; i > 1; i--) { 066 int j = rng.nextInt(i); 067 //swap features 068 SparseVector tmpFeature = features[i-1]; 069 features[i-1] = features[j]; 070 features[j] = tmpFeature; 071 //swap labels 072 int tmpLabel = labels[i-1]; 073 labels[i-1] = labels[j]; 074 labels[j] = tmpLabel; 075 //swap weights 076 double tmpWeight = weights[i-1]; 077 weights[i-1] = weights[j]; 078 weights[j] = tmpWeight; 079 //swap indices 080 int tmpIndex = indices[i-1]; 081 indices[i-1] = indices[j]; 082 indices[j] = tmpIndex; 083 } 084 } 085 086 /** 087 * Shuffles the features, labels and weights returning a tuple of the shuffled inputs. 088 * @param features Input features. 089 * @param labels Input labels. 090 * @param weights Input weights. 091 * @param rng SplittableRandom number generator. 092 * @return A tuple of shuffled features, labels and weights. 093 */ 094 public static ExampleArray shuffle(SparseVector[] features, int[] labels, double[] weights, SplittableRandom rng) { 095 int size = features.length; 096 SparseVector[] newFeatures = new SparseVector[size]; 097 int[] newLabels = new int[size]; 098 double[] newWeights = new double[size]; 099 for (int i = 0; i < newFeatures.length; i++) { 100 newFeatures[i] = features[i]; 101 newLabels[i] = labels[i]; 102 newWeights[i] = weights[i]; 103 } 104 // Shuffle array 105 for (int i = size; i > 1; i--) { 106 int j = rng.nextInt(i); 107 //swap features 108 SparseVector tmpFeature = newFeatures[i-1]; 109 newFeatures[i-1] = newFeatures[j]; 110 newFeatures[j] = tmpFeature; 111 //swap labels 112 int tmpLabel = newLabels[i-1]; 113 newLabels[i-1] = newLabels[j]; 114 newLabels[j] = tmpLabel; 115 //swap weights 116 double tmpWeight = newWeights[i-1]; 117 newWeights[i-1] = newWeights[j]; 118 newWeights[j] = tmpWeight; 119 } 120 return new ExampleArray(newFeatures,newLabels,newWeights); 121 } 122 123 /** 124 * A nominal tuple. One day it'll be a record, but not today. 125 */ 126 public static class ExampleArray { 127 public final SparseVector[] features; 128 public final int[] labels; 129 public final double[] weights; 130 131 public ExampleArray(SparseVector[] features, int[] labels, double[] weights) { 132 this.features = features; 133 this.labels = labels; 134 this.weights = weights; 135 } 136 } 137 138 /** 139 * In place shuffle used for sequence problems. 140 * @param features Input features. 141 * @param labels Input labels. 142 * @param weights Input weights. 143 * @param rng SplittableRandom number generator. 144 */ 145 public static void shuffleInPlace(SparseVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) { 146 int size = features.length; 147 // Shuffle array 148 for (int i = size; i > 1; i--) { 149 int j = rng.nextInt(i); 150 //swap features 151 SparseVector[] tmpFeature = features[i-1]; 152 features[i-1] = features[j]; 153 features[j] = tmpFeature; 154 //swap labels 155 int[] tmpLabel = labels[i-1]; 156 labels[i-1] = labels[j]; 157 labels[j] = tmpLabel; 158 //swap weights 159 double tmpWeight = weights[i-1]; 160 weights[i-1] = weights[j]; 161 weights[j] = tmpWeight; 162 } 163 } 164 165 /** 166 * Shuffles a sequence of features, labels and weights, returning a tuple of the shuffled values. 167 * @param features Input features. 168 * @param labels Input labels. 169 * @param weights Input weights. 170 * @param rng SplittableRandom number generator. 171 * @return A tuple of shuffled features, labels and weights. 172 */ 173 public static SequenceExampleArray shuffle(SparseVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) { 174 int size = features.length; 175 SparseVector[][] newFeatures = new SparseVector[size][]; 176 int[][] newLabels = new int[size][]; 177 double[] newWeights = new double[size]; 178 for (int i = 0; i < newFeatures.length; i++) { 179 newFeatures[i] = features[i]; 180 newLabels[i] = labels[i]; 181 newWeights[i] = weights[i]; 182 } 183 // Shuffle array 184 for (int i = size; i > 1; i--) { 185 int j = rng.nextInt(i); 186 //swap features 187 SparseVector[] tmpFeature = newFeatures[i-1]; 188 newFeatures[i-1] = newFeatures[j]; 189 newFeatures[j] = tmpFeature; 190 //swap labels 191 int[] tmpLabel = newLabels[i-1]; 192 newLabels[i-1] = newLabels[j]; 193 newLabels[j] = tmpLabel; 194 //swap weights 195 double tmpWeight = newWeights[i-1]; 196 newWeights[i-1] = newWeights[j]; 197 newWeights[j] = tmpWeight; 198 } 199 return new SequenceExampleArray(newFeatures,newLabels,newWeights); 200 } 201 202 /** 203 * A nominal tuple. One day it'll be a record, but not today. 204 */ 205 public static class SequenceExampleArray { 206 public final SparseVector[][] features; 207 public final int[][] labels; 208 public final double[] weights; 209 210 public SequenceExampleArray(SparseVector[][] features, int[][] labels, double[] weights) { 211 this.features = features; 212 this.labels = labels; 213 this.weights = weights; 214 } 215 } 216}