Javaでコレクションから指定した数の要素を無作為に非復元抽出
JavaでSetやListから指定した数の要素をランダムに取り出したいとき用。
Pythonのrandomライブラリを参考にした。
非復元抽出というのは、取り出した要素を元に戻さずに次の要素を取り出す、みたいなやり方。
引数として渡したコレクションへの変更は行わない。
以下ソース
import java.util.Collection; import java.util.List; import java.util.ArrayList; import java.util.Random; public class RandomUtils { /** * 受け取ったコレクションから指定された数の要素をランダムに非復元抽出する。 * * @param population 要素の母集団 * @param n 抽出する要素数 * @param random 乱数生成器 * @return 抽出された要素の集合 */ public static <E> List<E> sample(Collection<E> population, int n, Random random) { int popSize = population.size(); if (popSize < n) { throw new IllegalArgumentException("抽出する要素数(" + n + ")が母集団の数(" + popSize + ")を超えています。"); } List<E> result = new ArrayList<E>(n); List<E> copied = new ArrayList<E>(population); for (int i = 0; i < n; ++i) { int j = random.nextInt(popSize - i); result.add(copied.get(j)); copied.set(j, copied.get(popSize - i - 1)); // 抽出済み要素のインデックスを未抽出要素で埋める } return result; } /** * 受け取ったコレクションから指定された数の要素をランダムに非復元抽出する。 * * @param population 要素の母集団 * @param n 抽出する要素数 * @return 抽出された要素の集合 */ public static <E> List<E> sample(Collection<E> population, int n) { int popSize = population.size(); if (popSize < n) { throw new IllegalArgumentException("抽出する要素数(" + n + ")が母集団の数(" + popSize + ")を超えています。"); } List<E> result = new ArrayList<E>(n); List<E> copied = new ArrayList<E>(population); for (int i = 0; i < n; ++i) { int j = (int) (Math.random() * (popSize - i)); result.add(copied.get(j)); copied.set(j, copied.get(popSize - i - 1)); } return result; } public static void main(String[] args) { int popSize = 10; List<Integer> population = new ArrayList<Integer>(popSize); for (int i = 0; i < popSize; ++i) { population.add(i); } System.out.println(RandomUtils.sample(population, 5)); // 5個抽出 System.out.println(RandomUtils.sample(population, 3, new Random())); // 乱数生成器を与えて3個抽出 RandomUtils.sample(population, 11); // 母集団より多く抽出 } }
実行結果
[1, 9, 4, 5, 7] [8, 5, 9] Exception in thread "main" java.lang.IllegalArgumentException: 抽出する要素数(11)が母集団の数(10)を超えています。 at RandomUtils.sample(RandomUtils.java:54) at RandomUtils.main(RandomUtils.java:75)