Source code for odatse.util.resampling

# SPDX-License-Identifier: MPL-2.0
#
# ODAT-SE -- an open framework for data analysis
# Copyright (C) 2020- The University of Tokyo
#
# This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
# If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

import typing
from typing import Union, List, Iterable

import abc

import collections
import itertools
import numpy as np


class Resampler(abc.ABC):
    """
    Abstract interface for drawing indices according to a set of weights.

    Concrete subclasses must implement both :meth:`reset` (install new
    weights) and :meth:`sample` (draw one or several indices).
    """

    @abc.abstractmethod
    def reset(self, weights: Iterable):
        """
        Install a new set of weights for subsequent calls to :meth:`sample`.

        Parameters
        ----------
        weights : Iterable
            An iterable of weights.
        """

    @abc.abstractmethod
    def sample(self, rs: np.random.RandomState, size=None) -> Union[int, np.ndarray]:
        """
        Draw indices according to the installed weights.

        Parameters
        ----------
        rs : np.random.RandomState
            Source of random numbers.
        size :
            Number of samples to draw; ``None`` requests a single index.

        Returns
        -------
        int or np.ndarray
            A single index when ``size`` is ``None``, otherwise an array of indices.
        """
class BinarySearch(Resampler):
    """
    A resampler that uses binary search to sample based on given weights.

    The cumulative sums ``acc`` of the weights are stored; a uniform number
    ``r`` in ``[0, acc[-1])`` is mapped to the index ``i`` whose interval
    ``[acc[i-1], acc[i])`` contains it, so index ``i`` is drawn with
    probability ``weights[i] / sum(weights)``.
    """

    weights_accumulated: np.ndarray
    wmax: float

    def __init__(self, weights: Iterable):
        """
        Initialize the BinarySearch resampler with the given weights.

        Parameters
        ----------
        weights : Iterable
            An iterable of weights.
        """
        self.reset(weights)

    def reset(self, weights: Iterable):
        """
        Reset the resampler with new weights.

        Parameters
        ----------
        weights : Iterable
            An iterable of non-negative weights (a multi-dimensional array is
            flattened).

        Raises
        ------
        ValueError
            If ``weights`` is empty, contains a negative value, or does not sum
            to a positive finite number. The previous state is kept in that case.
        """
        w = np.asarray(list(weights), dtype=np.float64).ravel()
        if w.size == 0:
            raise ValueError("weights must not be empty")
        if np.any(w < 0.0):
            raise ValueError("weights must be non-negative")
        accumulated = np.cumsum(w)
        total = float(accumulated[-1])
        if not (np.isfinite(total) and total > 0.0):
            raise ValueError("weights must have a positive, finite sum")
        # Store as an ndarray so searchsorted does not re-convert a list per draw.
        self.weights_accumulated = accumulated
        self.wmax = total

    @typing.overload
    def sample(self, rs: np.random.RandomState) -> int:
        """
        Sample a single index based on the weights.

        Parameters
        ----------
        rs : np.random.RandomState
            A random state for generating random numbers.

        Returns
        -------
        int
            A single sampled index.
        """
        ...

    @typing.overload
    def sample(self, rs: np.random.RandomState, size) -> np.ndarray:
        """
        Sample multiple indices based on the weights.

        Parameters
        ----------
        rs : np.random.RandomState
            A random state for generating random numbers.
        size :
            The number of samples to generate.

        Returns
        -------
        np.ndarray
            An array of sampled indices.
        """
        ...

    def sample(self, rs: np.random.RandomState, size=None) -> Union[int, np.ndarray]:
        """
        Sample indices based on the weights.

        Parameters
        ----------
        rs : np.random.RandomState
            A random state for generating random numbers.
        size :
            The number of samples to generate. If None, a single sample is generated.

        Returns
        -------
        int or np.ndarray
            A single sampled index or an array of sampled indices.
        """
        if size is None:
            return self._sample(self.wmax * rs.rand())
        # One vectorized binary search for all draws instead of a Python loop.
        r = self.wmax * rs.rand(size)
        return np.searchsorted(self.weights_accumulated, r, side="right")

    def _sample(self, r: float) -> int:
        """
        Perform a binary search to find the index corresponding to the given random number.

        Parameters
        ----------
        r : float
            A random number in ``[0, wmax)``.

        Returns
        -------
        int
            The index corresponding to the random number.
        """
        # side="right" returns the first index whose cumulative weight is
        # strictly greater than r, so index i covers [acc[i-1], acc[i]).
        # Zero-weight entries have empty intervals and are never returned,
        # including a leading zero weight when r == 0.0.
        return int(np.searchsorted(self.weights_accumulated, r, side="right"))
class WalkerTable(Resampler):
    """
    A resampler that uses Walker's alias method to sample based on given weights.

    :meth:`reset` builds two tables of length ``N``: an acceptance threshold
    ``ptable[i]`` in ``[0, 1]`` and an alias index ``itable[i]``. A draw picks
    a cell ``i`` uniformly, then returns ``i`` if a second uniform number is
    below ``ptable[i]`` and ``itable[i]`` otherwise, so each draw is O(1).
    """

    N: int
    itable: np.ndarray
    ptable: np.ndarray

    def __init__(self, weights: Iterable):
        """
        Initialize the WalkerTable resampler with the given weights.

        Parameters
        ----------
        weights : Iterable
            An iterable of weights.
        """
        self.reset(weights)

    def reset(self, weights: Iterable):
        """
        Reset the resampler with new weights.

        Parameters
        ----------
        weights : Iterable
            Non-negative weights accepted by ``np.array`` (e.g. a list or an
            array; multi-dimensional input is flattened).

        Raises
        ------
        ValueError
            If ``weights`` is empty, contains a negative value, or its mean is
            not positive and finite. The previous tables are kept in that case.
        """
        ptable = np.array(weights).astype(np.float64).flatten()
        if ptable.size == 0:
            raise ValueError("weights must not be empty")
        if np.any(ptable < 0.0):
            raise ValueError("weights must be non-negative")
        mean = ptable.mean()
        if not (np.isfinite(mean) and mean > 0.0):
            raise ValueError("weights must have a positive, finite sum")
        n = len(ptable)
        itable = np.full(n, -1)

        # Work with each weight's excess (>= 0) or deficit (< 0) relative to the mean.
        ptable -= mean
        shorter = collections.deque(np.flatnonzero(ptable < 0.0).tolist())
        longer = collections.deque(np.flatnonzero(ptable >= 0.0).tolist())
        while longer and shorter:
            ilong = longer[0]
            ishort = shorter.popleft()
            # Cover the deficit of cell `ishort` with mass taken from `ilong`.
            itable[ishort] = ilong
            ptable[ilong] += ptable[ishort]
            if ptable[ilong] <= 0.0:
                # `ilong` has given away all its excess and now needs covering itself.
                longer.popleft()
                shorter.append(ilong)

        # Turn the excess/deficit values back into acceptance thresholds in [0, 1].
        ptable += mean
        ptable /= mean

        # Cells still queued when the loop stops were never given an alias;
        # their excess/deficit should be zero up to rounding error, but the
        # rescaled threshold can land just below 1.0. Make them always accept
        # so a draw can never fall through to the placeholder alias -1.
        for i in itertools.chain(longer, shorter):
            ptable[i] = 1.0
            itable[i] = i

        self.ptable = ptable
        self.itable = itable
        self.N = n

    @typing.overload
    def sample(self, rs: np.random.RandomState) -> int:
        """
        Sample a single index based on the weights.

        Parameters
        ----------
        rs : np.random.RandomState
            A random state for generating random numbers.

        Returns
        -------
        int
            A single sampled index.
        """
        ...

    @typing.overload
    def sample(self, rs: np.random.RandomState, size) -> np.ndarray:
        """
        Sample multiple indices based on the weights.

        Parameters
        ----------
        rs : np.random.RandomState
            A random state for generating random numbers.
        size :
            The number of samples to generate.

        Returns
        -------
        np.ndarray
            An array of sampled indices.
        """
        ...

    def sample(self, rs: np.random.RandomState, size=None) -> Union[int, np.ndarray]:
        """
        Sample indices based on the weights.

        Parameters
        ----------
        rs : np.random.RandomState
            A random state for generating random numbers.
        size :
            The number of samples to generate. If None, a single sample is generated.

        Returns
        -------
        int or np.ndarray
            A single sampled index or an array of sampled indices.
        """
        if size is None:
            r = rs.rand() * self.N
            return self._sample(r)
        # A single uniform number per draw: its integer part picks the cell,
        # its fractional part is compared against that cell's threshold.
        r = rs.rand(size) * self.N
        i = np.floor(r).astype(np.int64)
        p = r - i
        return np.where(p < self.ptable[i], i, self.itable[i])

    def _sample(self, r: float) -> int:
        """
        Perform a sampling operation based on the given random number.

        Parameters
        ----------
        r : float
            A random number scaled by the number of weights, in ``[0, N)``.

        Returns
        -------
        int
            The index corresponding to the random number.
        """
        i = int(np.floor(r))
        p = r - i
        if p < self.ptable[i]:
            return i
        return int(self.itable[i])
if __name__ == "__main__":
    import argparse

    # Quick self-check: draw many samples and compare the observed
    # frequency of each index against its exact probability.
    cli = argparse.ArgumentParser(prog="resampling test")
    cli.add_argument(
        "-s", "--seed", type=int, default=12345, help="random number seed"
    )
    cli.add_argument("-m", "--method", default="walker", help="method to resample")
    cli.add_argument("-N", type=int, default=100000, help="Number of samples")
    options = cli.parse_args()

    rng = np.random.RandomState(options.seed)
    weights = [0.0, 1.0, 2.0, 3.0]  # alternative input: rng.rand(5)
    total = np.sum(weights)

    # "walker" selects the alias method; any other value falls back to binary search.
    sampler_class = WalkerTable if options.method == "walker" else BinarySearch
    sampler = sampler_class(weights)
    drawn = sampler.sample(rng, options.N)

    print("#i result exact diff")
    for idx, weight in enumerate(weights):
        observed = np.count_nonzero(drawn == idx) / options.N
        expected = weight / total
        print(f"{idx} {observed} {expected} {observed - expected}")