001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.sgd; 018 019import org.tribuo.math.la.DenseVector; 020import org.tribuo.math.la.SparseVector; 021 022import java.util.SplittableRandom; 023 024/** 025 * Utilities. Currently stores methods for shuffling examples and their associated regression dimensions and weights. 026 */ 027public class Util { 028 /** 029 * In place shuffle of the features, labels and weights. 030 * @param features Input features. 031 * @param regressors Input regressors. 032 * @param weights Input weights. 033 * @param rng SplittableRandom number generator. 034 */ 035 public static void shuffleInPlace(SparseVector[] features, DenseVector[] regressors, double[] weights, SplittableRandom rng) { 036 int size = features.length; 037 // Shuffle array 038 for (int i = size; i > 1; i--) { 039 int j = rng.nextInt(i); 040 //swap features 041 SparseVector tmpFeature = features[i-1]; 042 features[i-1] = features[j]; 043 features[j] = tmpFeature; 044 //swap regressors 045 DenseVector tmpRegressors = regressors[i-1]; 046 regressors[i-1] = regressors[j]; 047 regressors[j] = tmpRegressors; 048 //swap weights 049 double tmpWeight = weights[i-1]; 050 weights[i-1] = weights[j]; 051 weights[j] = tmpWeight; 052 } 053 } 054 055 /** 056 * In place shuffle of the features, labels and weights. 057 * @param features Input features. 058 * @param regressors Input regressors. 059 * @param weights Input weights. 060 * @param indices Input indices. 061 * @param rng SplittableRandom number generator. 062 */ 063 public static void shuffleInPlace(SparseVector[] features, DenseVector[] regressors, double[] weights, int[] indices, SplittableRandom rng) { 064 int size = features.length; 065 // Shuffle array 066 for (int i = size; i > 1; i--) { 067 int j = rng.nextInt(i); 068 //swap features 069 SparseVector tmpFeature = features[i-1]; 070 features[i-1] = features[j]; 071 features[j] = tmpFeature; 072 //swap regressors 073 DenseVector tmpLabel = regressors[i-1]; 074 regressors[i-1] = regressors[j]; 075 regressors[j] = tmpLabel; 076 //swap weights 077 double tmpWeight = weights[i-1]; 078 weights[i-1] = weights[j]; 079 weights[j] = tmpWeight; 080 //swap indices 081 int tmpIndex = indices[i-1]; 082 indices[i-1] = indices[j]; 083 indices[j] = tmpIndex; 084 } 085 } 086}