001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sgd;
018
019import org.tribuo.math.la.SparseVector;
020
021import java.util.SplittableRandom;
022
023/**
024 * SGD utilities. Currently stores methods for shuffling examples and their associated labels and weights.
025 */
026public class Util {
027    /**
028     * In place shuffle of the features, labels and weights.
029     * @param features Input features.
030     * @param labels Input labels.
031     * @param weights Input weights.
032     * @param rng SplittableRandom number generator.
033     */
034    public static void shuffleInPlace(SparseVector[] features, int[] labels, double[] weights, SplittableRandom rng) {
035        int size = features.length;
036        // Shuffle array
037        for (int i = size; i > 1; i--) {
038            int j = rng.nextInt(i);
039            //swap features
040            SparseVector tmpFeature = features[i-1];
041            features[i-1] = features[j];
042            features[j] = tmpFeature;
043            //swap labels
044            int tmpLabel = labels[i-1];
045            labels[i-1] = labels[j];
046            labels[j] = tmpLabel;
047            //swap weights
048            double tmpWeight = weights[i-1];
049            weights[i-1] = weights[j];
050            weights[j] = tmpWeight;
051        }
052    }
053
054    /**
055     * In place shuffle of the features, labels, weights and indices.
056     * @param features Input features.
057     * @param labels Input labels.
058     * @param weights Input weights.
059     * @param indices Input indices.
060     * @param rng SplittableRandom number generator.
061     */
062    public static void shuffleInPlace(SparseVector[] features, int[] labels, double[] weights, int[] indices, SplittableRandom rng) {
063        int size = features.length;
064        // Shuffle array
065        for (int i = size; i > 1; i--) {
066            int j = rng.nextInt(i);
067            //swap features
068            SparseVector tmpFeature = features[i-1];
069            features[i-1] = features[j];
070            features[j] = tmpFeature;
071            //swap labels
072            int tmpLabel = labels[i-1];
073            labels[i-1] = labels[j];
074            labels[j] = tmpLabel;
075            //swap weights
076            double tmpWeight = weights[i-1];
077            weights[i-1] = weights[j];
078            weights[j] = tmpWeight;
079            //swap indices
080            int tmpIndex = indices[i-1];
081            indices[i-1] = indices[j];
082            indices[j] = tmpIndex;
083        }
084    }
085
086    /**
087     * Shuffles the features, labels and weights returning a tuple of the shuffled inputs.
088     * @param features Input features.
089     * @param labels Input labels.
090     * @param weights Input weights.
091     * @param rng SplittableRandom number generator.
092     * @return A tuple of shuffled features, labels and weights.
093     */
094    public static ExampleArray shuffle(SparseVector[] features, int[] labels, double[] weights, SplittableRandom rng) {
095        int size = features.length;
096        SparseVector[] newFeatures = new SparseVector[size];
097        int[] newLabels = new int[size];
098        double[] newWeights = new double[size];
099        for (int i = 0; i < newFeatures.length; i++) {
100            newFeatures[i] = features[i];
101            newLabels[i] = labels[i];
102            newWeights[i] = weights[i];
103        }
104        // Shuffle array
105        for (int i = size; i > 1; i--) {
106            int j = rng.nextInt(i);
107            //swap features
108            SparseVector tmpFeature = newFeatures[i-1];
109            newFeatures[i-1] = newFeatures[j];
110            newFeatures[j] = tmpFeature;
111            //swap labels
112            int tmpLabel = newLabels[i-1];
113            newLabels[i-1] = newLabels[j];
114            newLabels[j] = tmpLabel;
115            //swap weights
116            double tmpWeight = newWeights[i-1];
117            newWeights[i-1] = newWeights[j];
118            newWeights[j] = tmpWeight;
119        }
120        return new ExampleArray(newFeatures,newLabels,newWeights);
121    }
122
123    /**
124     * A nominal tuple. One day it'll be a record, but not today.
125     */
126    public static class ExampleArray {
127        public final SparseVector[] features;
128        public final int[] labels;
129        public final double[] weights;
130
131        public ExampleArray(SparseVector[] features, int[] labels, double[] weights) {
132            this.features = features;
133            this.labels = labels;
134            this.weights = weights;
135        }
136    }
137
138    /**
139     * In place shuffle used for sequence problems.
140     * @param features Input features.
141     * @param labels Input labels.
142     * @param weights Input weights.
143     * @param rng SplittableRandom number generator.
144     */
145    public static void shuffleInPlace(SparseVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) {
146        int size = features.length;
147        // Shuffle array
148        for (int i = size; i > 1; i--) {
149            int j = rng.nextInt(i);
150            //swap features
151            SparseVector[] tmpFeature = features[i-1];
152            features[i-1] = features[j];
153            features[j] = tmpFeature;
154            //swap labels
155            int[] tmpLabel = labels[i-1];
156            labels[i-1] = labels[j];
157            labels[j] = tmpLabel;
158            //swap weights
159            double tmpWeight = weights[i-1];
160            weights[i-1] = weights[j];
161            weights[j] = tmpWeight;
162        }
163    }
164
165    /**
166     * Shuffles a sequence of features, labels and weights, returning a tuple of the shuffled values.
167     * @param features Input features.
168     * @param labels Input labels.
169     * @param weights Input weights.
170     * @param rng SplittableRandom number generator.
171     * @return A tuple of shuffled features, labels and weights.
172     */
173    public static SequenceExampleArray shuffle(SparseVector[][] features, int[][] labels, double[] weights, SplittableRandom rng) {
174        int size = features.length;
175        SparseVector[][] newFeatures = new SparseVector[size][];
176        int[][] newLabels = new int[size][];
177        double[] newWeights = new double[size];
178        for (int i = 0; i < newFeatures.length; i++) {
179            newFeatures[i] = features[i];
180            newLabels[i] = labels[i];
181            newWeights[i] = weights[i];
182        }
183        // Shuffle array
184        for (int i = size; i > 1; i--) {
185            int j = rng.nextInt(i);
186            //swap features
187            SparseVector[] tmpFeature = newFeatures[i-1];
188            newFeatures[i-1] = newFeatures[j];
189            newFeatures[j] = tmpFeature;
190            //swap labels
191            int[] tmpLabel = newLabels[i-1];
192            newLabels[i-1] = newLabels[j];
193            newLabels[j] = tmpLabel;
194            //swap weights
195            double tmpWeight = newWeights[i-1];
196            newWeights[i-1] = newWeights[j];
197            newWeights[j] = tmpWeight;
198        }
199        return new SequenceExampleArray(newFeatures,newLabels,newWeights);
200    }
201
202    /**
203     * A nominal tuple. One day it'll be a record, but not today.
204     */
205    public static class SequenceExampleArray {
206        public final SparseVector[][] features;
207        public final int[][] labels;
208        public final double[] weights;
209
210        public SequenceExampleArray(SparseVector[][] features, int[][] labels, double[] weights) {
211            this.features = features;
212            this.labels = labels;
213            this.weights = weights;
214        }
215    }
216}