# Source code for skoot.utils.iterables

# -*- coding: utf-8 -*-
#
# Author: Taylor Smith <taylor.smith@alkaline-ml.com>

import six
import types

from .compat import xrange

# Names exported by ``from <module> import *``: the four helper functions
# defined in this module.
__all__ = [
    'chunk',
    'ensure_iterable',
    'flatten_all',
    'is_iterable'
]


def ensure_iterable(element):
    """Guarantee that the returned object can be iterated over.

    An ``element`` that ``is_iterable`` already accepts is handed back
    untouched; anything else is wrapped in a one-item list. This keeps
    callers from repeating the same boilerplate everywhere::

        if not is_iterable(this):
            this = [this]

    Parameters
    ----------
    element : object
        Any object, iterable or not. Strings are treated as
        non-iterable by ``is_iterable`` and are therefore wrapped.

    Returns
    -------
    element : iterable
        The original ``element`` if it was iterable, otherwise
        ``[element]``.
    """
    return element if is_iterable(element) else [element]
def flatten_all(container):
    """Lazily flatten an arbitrarily nested iterable, depth first.

    Parameters
    ----------
    container : array_like, shape=(n_items,)
        The iterable to flatten. If ``container`` is not iterable
        (as judged by ``is_iterable``), it is yielded as the single
        value of the generator.

    Returns
    -------
    res : generator
        A generator over all of the flattened (leaf) values, in the order
        they are encountered. Empty nested iterables contribute nothing.

    Examples
    --------
    The example below produces a list of mixed results:

    >>> a = [[[], 3, 4],['1', 'a'],[[[1]]], 1, 2]
    >>> list(flatten_all(a))
    [3, 4, '1', 'a', 1, 1, 2]
    """
    # Base case: a leaf value is emitted as-is.
    if not is_iterable(container):
        yield container
        return

    # Recursive case: every member is flattened in turn. Leaves are handled
    # by the base case above, so no separate branch is needed here.
    for member in container:
        for leaf in flatten_all(member):
            yield leaf
def is_iterable(x):
    """Report whether ``x`` should be treated as an iterable.

    An object qualifies when it exposes an ``__iter__`` attribute and is
    not a string. The string exclusion is explicit because Python 3
    strings carry ``__iter__`` too, yet this module wants them handled as
    scalar values; ``six.string_types`` is used so the check covers the
    string types of the running Python version.

    Parameters
    ----------
    x : object
        The object or primitive to test.

    Returns
    -------
    bool
        ``False`` for any instance of ``six.string_types``; otherwise
        whether ``x`` has an ``__iter__`` attribute.
    """
    is_string = isinstance(x, six.string_types)
    return (not is_string) and hasattr(x, '__iter__')
def chunk(v, n):
    """Chunk a vector into ``n`` roughly equal, contiguous parts.

    Chunk sizes differ by at most one: the first ``len(v) % n`` chunks
    receive one extra element.

    Parameters
    ----------
    v : array-like, shape=(n_samples,)
        The vector of values. Generators, and any other iterable that does
        not define ``__len__`` (e.g. ``iter(...)`` or ``map`` objects), are
        first consumed into a list. Everything else is used as-is and must
        support ``len`` and slicing.

    n : int
        The number of chunks to produce. Must satisfy
        ``1 <= n <= len(v)``.

    Returns
    -------
    res : generator
        A generator that yields ``n`` slices of ``v``, in order.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1 or larger than the length of ``v``.
        Because this function is a generator, the error is raised when
        iteration starts rather than when ``chunk`` is called.

    Examples
    --------
    >>> list(chunk([1, 2, 3, 4, 5], 2))
    [[1, 2, 3], [4, 5]]
    """
    # A non-positive chunk count is meaningless; without this guard n == 0
    # would surface as a ZeroDivisionError from divmod, and a negative n
    # would silently produce no chunks at all.
    if n < 1:
        raise ValueError("n must be a positive integer")

    # Generators and other length-less iterables can be neither measured
    # nor sliced, so materialize them as a list first.
    if isinstance(v, types.GeneratorType) or not hasattr(v, '__len__'):
        v = list(v)

    len_v = len(v)

    # fail out if n > len_v
    if n > len_v:
        raise ValueError("N exceeds length of vector!")

    # k is the base chunk size; the first m chunks get one extra element.
    k, m = divmod(len_v, n)
    for i in xrange(n):
        yield v[i * k + min(i, m): (i + 1) * k + min(i + 1, m)]