Source code for hypercontagion.utils.utilities

"""
Contains useful classes and functions for use in the hypercontagion library.
"""

import heapq
import random
from collections import Counter, defaultdict

import numpy as np

__all__ = [
    "EventQueue",
    "MockSamplableSet",
    "SamplingDict",
    "_process_trans_SIR_",
    "_process_rec_SIR_",
    "_process_trans_SIS_",
    "_process_rec_SIS_",
]


class EventQueue:
    r"""
    This class is used to store and act on a priority queue of events for
    event-driven simulations.  It is based on heapq.
    Each queue is given a tmax (default is infinity) so that any event at later
    time is ignored.

    This is a priority queue of 4-tuples of the form
                   ``(t, counter, function, function_arguments)``
    The ``'counter'`` is present just to break ties, which generally only occur when
    multiple events are put in place for the initial condition, but could also
    occur in cases where events tend to happen at discrete times.
    note that the function is understood to have its first argument be t, and
    the tuple ``function_arguments`` does not include this first t.
    So function is called as
        ``function(t, *function_arguments)``
    Previously I used a class of events, but sorting using the __lt__ function
    I wrote was significantly slower than simply using tuples.
    """

    def __init__(self, tmax=float("Inf")):
        self._Q_ = []
        self.tmax = tmax
        self.counter = 0  # tie-breaker for putting things in priority queue

    def add(self, time, function, args=()):
        """Add event to the queue.

        Parameters
        ----------
        time : float
            time of the event
        function : function name
            name of the function to run when the event is popped.
        args : keyword args
            args of the function excluding time.
        """
        if time < self.tmax:
            heapq.heappush(self._Q_, (time, self.counter, function, args))
            self.counter += 1

    def pop_and_run(self):
        """Pops the next event off the queue and performs the function"""
        t, counter, function, args = heapq.heappop(self._Q_)
        function(t, *args)

    def __len__(self):
        """this allows us to use commands like ``while Q:``

        Returns
        -------
        int
            number of events currently in the queue
        """
        return len(self._Q_)


class SamplingDict:
    """
    The Gillespie algorithm will involve a step that samples a random element
    from a set based on its weight.  This is awkward in Python.

    So I'm introducing a new class based on a stack overflow answer by
    Amber (http://stackoverflow.com/users/148870/amber)
    for a question by
    tba (http://stackoverflow.com/users/46521/tba)
    found at
    http://stackoverflow.com/a/15993515/2966723

    This will allow selecting a random element uniformly, and then use
    rejection sampling to make sure it's been selected with the appropriate
    weight.
    """

    def __init__(self, weighted=False):
        self.item_to_position = {}
        self.items = []

        self.weighted = weighted
        if self.weighted:
            self.weight = defaultdict(int)  # presume all weights positive
            self.max_weight = 0
            self._total_weight = 0
            self.max_weight_count = 0

    def __len__(self):
        """Number of items in the dict

        Returns
        -------
        int
            number of items in dict
        """
        return len(self.items)

    def __contains__(self, item):
        """Whether an item exists in the dictionary"""
        return item in self.item_to_position

    def _update_max_weight(self):
        """Internal function to help with the rejection sampling"""
        C = Counter(
            self.weight.values()
        )  # may be a faster way to do this, we only need to count the max.
        self.max_weight = max(C.keys())
        self.max_weight_count = C[self.max_weight]

    def insert(self, item, weight=None):
        """insert an item into the sampling dict

        Parameters
        ----------
        item : hashable
            the ID of the item
        weight : float, default: None
            the weight of the item, if None, unweighted.

        Notes
        -----
        If not present, then inserts the thing (with weight if appropriate)
        if already there, replaces the weight unless weight is 0

        If weight is 0, then it removes the item and doesn't replace.

        replaces weight if already present, does not increment weight.
        """

        if self.__contains__(item):
            self.remove(item)
        if weight != 0:
            self.update(item, weight_increment=weight)

    def update(self, item, weight_increment=None):
        """_summary_

        Parameters
        ----------
        item : hashable
            ID of the item
        weight_increment : float, default: None
            how much to increment the weight if weighted.

        Raises
        ------
        Exception
            if weighted and no weight increment specified.

        Notes
        -----
        If not present, then inserts the item (with weight if appropriate)
        if already there, increments weight

        increments weight if already present, cannot overwrite weight.
        """
        if (
            weight_increment is not None
        ):  # will break if passing a weight to unweighted case
            if weight_increment > 0 or self.weight[item] != self.max_weight:
                self.weight[item] = self.weight[item] + weight_increment
                self._total_weight += weight_increment
                if self.weight[item] > self.max_weight:
                    self.max_weight_count = 1
                    self.max_weight = self.weight[item]
                elif self.weight[item] == self.max_weight:
                    self.max_weight_count += 1
            else:  # it's a negative increment and was at max
                self.max_weight_count -= 1
                self.weight[item] = self.weight[item] + weight_increment
                self._total_weight += weight_increment
                self.max_weight_count -= 1
                if self.max_weight_count == 0:
                    self._update_max_weight
        elif self.weighted:
            raise Exception("if weighted, must assign weight_increment")

        if item in self:  # we've already got it, do nothing else
            return
        self.items.append(item)
        self.item_to_position[item] = len(self.items) - 1

    def remove(self, choice):
        """Remove item and update weights

        Parameters
        ----------
        choice : hashable
            item ID
        """
        position = self.item_to_position.pop(
            choice
        )  # why don't we pop off the last item and put it in the choice index?
        last_item = self.items.pop()
        if position != len(self.items):
            self.items[position] = last_item
            self.item_to_position[last_item] = position

        if self.weighted:
            weight = self.weight.pop(choice)
            self._total_weight -= weight
            if weight == self.max_weight:
                # if we find ourselves in this case often
                # it may be better just to let max_weight be the
                # largest weight *ever* encountered, even if all remaining weights are less
                #
                self.max_weight_count -= 1
                if self.max_weight_count == 0 and len(self) > 0:
                    self._update_max_weight()

    def choose_random(self):
        """chooses a random node.  If there is a weight, it will use rejection
        sampling to choose a random node until it succeeds"""
        if self.weighted:
            while True:
                choice = random.choice(self.items)
                if random.random() < self.weight[choice] / self.max_weight:
                    break
            return choice

        else:
            return random.choice(self.items)

    def random_removal(self):
        """uses other class methods to choose and then remove a random item"""
        choice = self.choose_random()
        self.remove(choice)
        return choice

    def total_weight(self):
        """Get the sum of all the weights in the dict."""
        if self.weighted:
            return self._total_weight
        else:
            return len(self)

    def update_total_weight(self):
        """Update the sum of weights."""
        self._total_weight = sum(self.weight[item] for item in self.items)


def choice(arr, p):
    """
    Returns a random element from ``arr`` with probability given in array ``p``.
    If ``arr`` is not an iterable, the function returns the index of the chosen element.
    """
    ndx = np.argmax(np.random.rand() < np.cumsum(p))
    try:
        return arr[ndx]
    except TypeError as e:
        return ndx


class MockSamplableSet:
    """
    A set of items that can be sampled with probability
    proportional to a corresponding item weight.

    Mimicks the behavior of github.com/gstonge/SamplableSet
    without being as efficient.

    Works similar to Python's set, with ``__getitem__``,
    ``__setitem__``, ``__delitem__``, ``__iter__``,
    ``__len__``, ``__contains__``.

    Parameters
    ==========
    min_weight : float
        minimum possible weight
    max_weight : float
        maximum possible weight
    weighted_elements : list, default = []
        list of 2-tuples, first entry an item, second entry a weight
    cpp_type : str, default = 'int'
        The type of the items.

    Attributes
    ==========
    min_weight : float
        minimum possible weight
    max_weight : float
        maximum possible weight
    items : numpy.ndarray
        list of items in this set
    weights : numpy.ndarray
        list of corresponding weights
    """

    def __init__(self, min_weight, max_weight, weighted_elements=[], cpp_type="int"):

        self.min_weight = min_weight
        self.max_weight = max_weight

        if type(weighted_elements) == dict:
            weighted_elements = list(weighted_elements.items())

        self.items = np.array([e[0] for e in weighted_elements], dtype=cpp_type)
        self.weights = np.array([e[1] for e in weighted_elements], dtype=float)
        sort_ndx = np.argsort(self.items)
        self.items = self.items[sort_ndx]
        self.weights = self.weights[sort_ndx]
        self._total_weight = self.weights.sum()

        if np.any(self.weights < self.min_weight):
            raise ValueError("There are weights below the limit.")

        if np.any(self.weights > self.max_weight):
            raise ValueError("There are weights above the limit.")

    def sample(self):
        """
        Random sample from the set, sampled
        with probability proportional to items' weight.

        Returns
        =======
        item : cpp_type
            An item from the set
        weight : float
            The weight of the item
        """

        # ndx = np.random.choice(len(self.items),p=self.weights/self._total_weight)
        # ndx = np.argwhere(np.random.rand()<np.cumsum(self.weights/self._total_weight))[0][0]
        ndx = choice(len(self.items), p=self.weights / self._total_weight)
        return self.items[ndx], self.weights[ndx]

    def __getitem__(self, key):
        found_key, ndx = self._find_key(key)
        if not found_key:
            raise KeyError("`", key, "` is not in this set.")
        else:
            return self.weights[ndx]

    def __delitem__(self, key):
        found_key, ndx = self._find_key(key)
        if found_key:
            self.items = np.delete(self.items, ndx)
            self.weights = np.delete(self.weights, ndx)
            self._total_weight = self.weights.sum()

    def __setitem__(self, key, value):
        if value < self.min_weight or value > self.max_weight:
            raise ValueError(
                "Inserting element-weight pair "
                + str(key)
                + " "
                + str(value)
                + " \n"
                + "has weight value out of bounds of "
                + str(self.min_weight)
                + " "
                + str(self.max_weight)
            )
        found_key, ndx = self._find_key(key)
        if not found_key:
            self.items = np.insert(self.items, ndx, key)
            self.weights = np.insert(self.weights, ndx, value)
        else:
            self.weights[ndx] = value

        self._total_weight = self.weights.sum()

    def _find_key(self, key):
        ndx = np.searchsorted(self.items, key)
        return (not ((ndx == len(self.items) or self.items[ndx] != key)), ndx)

    def __iter__(self):
        self._ndx = 0
        return self

    def __next__(self):
        if self._ndx < len(self.items):
            i, w = self.items[self._ndx], self.weights[self._ndx]
            self._ndx += 1
            return (i, w)
        else:
            raise StopIteration

    def __len__(self):
        return len(self.items)

    def __contains__(self, key):
        return self._find_key(key)[0]

    def total_weight(self):
        """Obtain the total weight of the set"""
        return self._total_weight

    def clear(self):
        """Reset the set. Not implemented yet."""
        pass


[docs]def _process_trans_SIR_( t, times, S, I, R, Q, H, status, transmission_function, gamma, tau, source, target, rec_time, pred_inf_time, events, ): if status[target] == "S": # nothing happens if already infected. status[target] = "I" times.append(t) events.append( { "time": t, "source": source, "target": target, "old_state": "S", "new_state": "I", } ) S.append(S[-1] - 1) # one less susceptible I.append(I[-1] + 1) # one more infected R.append(R[-1]) # no change to recovered rec_time[target] = t + rec_delay(gamma) if rec_time[target] < Q.tmax: Q.add( rec_time[target], _process_rec_SIR_, args=(times, S, I, R, status, target, events), ) for edge_id in H.nodes.memberships(target): edge = H.edges.members(edge_id) for nbr in edge: if status[nbr] == "S": inf_time = t + trans_delay(tau, edge) # create statuses at the time requested temp_status = defaultdict(lambda: "R") for node in edge: if status[node] == "I" and rec_time[node] > inf_time: temp_status[node] = "I" elif status[node] == "S": temp_status[node] = "S" contagion = transmission_function(nbr, temp_status, edge) if contagion != 0 and inf_time < pred_inf_time[nbr]: Q.add( inf_time, _process_trans_SIR_, args=( times, S, I, R, Q, H, status, transmission_function, gamma, tau, edge_id, nbr, rec_time, pred_inf_time, events, ), ) pred_inf_time[nbr] = inf_time
[docs]def _process_rec_SIR_(t, times, S, I, R, status, node, events): times.append(t) events.append( {"time": t, "source": None, "target": node, "old_state": "I", "new_state": "R"} ) S.append(S[-1]) # no change to number susceptible I.append(I[-1] - 1) # one less infected R.append(R[-1] + 1) # one more recovered status[node] = "R"
[docs]def _process_trans_SIS_( t, times, S, I, Q, H, status, transmission_function, gamma, tau, source, target, rec_time, pred_inf_time, events, ): if status[target] == "S": status[target] = "I" events.append( { "time": t, "source": source, "target": target, "old_state": "S", "new_state": "I", } ) I.append(I[-1] + 1) # one more infected S.append(S[-1] - 1) # one less susceptible times.append(t) rec_time[target] = t + rec_delay(gamma) if rec_time[target] < Q.tmax: Q.add( rec_time[target], _process_rec_SIS_, args=(times, S, I, status, target, events), ) for edge_id in H.nodes.memberships(target): edge = H.edges.members(edge_id) for nbr in edge: if status[nbr] == "S": inf_time = t + trans_delay(tau, edge) # create statuses at the time requested temp_status = defaultdict(lambda: "S") for node in edge: if status[node] == "I" and rec_time[node] >= inf_time: temp_status[node] = "I" contagion = transmission_function(nbr, temp_status, edge) if contagion != 0 and inf_time < pred_inf_time[nbr]: Q.add( inf_time, _process_trans_SIS_, args=( times, S, I, Q, H, status, transmission_function, gamma, tau, edge_id, nbr, rec_time, pred_inf_time, events, ), ) pred_inf_time[nbr] = inf_time
[docs]def _process_rec_SIS_(t, times, S, I, status, node, events): times.append(t) events.append( {"time": t, "source": None, "target": node, "old_state": "I", "new_state": "S"} ) S.append(S[-1] + 1) # one more susceptible I.append(I[-1] - 1) # one less infected status[node] = "S"
def rec_delay(rate): try: return random.expovariate(rate) except: return float("Inf") def trans_delay(tau, edge): try: return random.expovariate(tau[len(edge)]) except ZeroDivisionError: return np.inf