Skip to content

mis.shared.graphs

[docs] module mis.shared.graphs

from __future__ import annotations

import itertools
import typing
from abc import ABC, abstractmethod

import networkx as nx

if typing.TYPE_CHECKING:
    from .types import Weighting


class BaseWeightPicker(ABC):
    """
    Utility class to pick the weight of a node.

    All operations are classmethods, so pickers are used as classes
    (see `for_weighting`) rather than as instances.

    Unweighted implementations optimize the methods into trivial
    operations.
    """

    @classmethod
    @abstractmethod
    def node_weight(cls, graph: nx.Graph, node: int) -> float:
        """
        Get the weight of a node.

        For a weighted cost picker, this returns attribute `weight` of the node,
        or 1. if the node doesn't specify a `weight`.

        For an unweighted cost picker, this always returns 1.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def set_node_weight(cls, graph: nx.Graph, node: int, weight: float) -> None:
        """
        Set the weight of a node.

        For a weighted cost picker, this sets attribute `weight` of the node.

        For an unweighted cost picker, raise an error.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def node_delta(cls, graph: nx.Graph, node: int, delta: float) -> float:
        """
        Apply a delta to the weight of a node.

        Raises an error in an unweighted cost picker.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def subgraph_weight(cls, graph: nx.Graph, nodes: typing.Iterable[int]) -> float:
        """
        Get the weight of a subgraph.

        See `node_weight` for the definition of weight.

        For an unweighted cost picker, this always returns `len(nodes)`.
        """
        raise NotImplementedError

    @classmethod
    def for_weighting(cls, weighting: Weighting) -> type[BaseWeightPicker]:
        """
        Pick a cost picker for an objective.

        Args:
            weighting: The kind of weighting requested.

        Returns:
            `UnweightedPicker` for `Weighting.UNWEIGHTED`,
            `WeightedPicker` for `Weighting.WEIGHTED`.

        Raises:
            ValueError: If `weighting` is neither of the values above.
        """
        # Imported here rather than at module level; the module-level import
        # is only performed under `typing.TYPE_CHECKING`.
        from .types import Weighting

        if weighting == Weighting.UNWEIGHTED:
            return UnweightedPicker
        if weighting == Weighting.WEIGHTED:
            return WeightedPicker
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an obscure AttributeError at the call site.
        raise ValueError(f"Unsupported weighting: {weighting!r}")


class WeightedPicker(BaseWeightPicker):
    """
    Weight picker backed by node attribute `weight`.

    A node that doesn't specify a `weight` is treated as having weight 1.
    """

    @classmethod
    def node_weight(cls, graph: nx.Graph, node: int) -> float:
        """
        Get the weight of a node.

        Returns attribute `weight` of the node converted to `float`,
        or 1. if the node doesn't specify a `weight`.
        """
        # Coerce instead of `assert isinstance(..., float)`: an integer weight
        # such as `weight=2` is a legitimate value, and asserts are stripped
        # when Python runs with `-O`.
        return float(graph.nodes[node].get("weight", 1.0))

    @classmethod
    def set_node_weight(cls, graph: nx.Graph, node: int, weight: float) -> None:
        """
        Set attribute `weight` of a node, in place.
        """
        graph.nodes[node]["weight"] = weight

    @classmethod
    def node_delta(cls, graph: nx.Graph, node: int, delta: float) -> float:
        """
        Add `delta` to the weight of a node, in place.

        Args:
            graph: The graph holding the node.
            node: The node to update.
            delta: The amount to add to the current weight (may be negative).

        Returns:
            The weight of the node after the update.
        """
        updated = cls.node_weight(graph, node) + delta
        cls.set_node_weight(graph, node, updated)
        return updated

    @classmethod
    def subgraph_weight(cls, graph: nx.Graph, nodes: typing.Iterable[int]) -> float:
        """
        Get the weight of a subgraph, i.e. the sum of `node_weight` over `nodes`.
        """
        return float(sum(cls.node_weight(graph, n) for n in nodes))


class UnweightedPicker(BaseWeightPicker):
    """
    Weight picker in which every node has weight 1.

    Operations that would modify a weight are not supported and raise
    `NotImplementedError`.
    """

    @classmethod
    def node_weight(cls, graph: nx.Graph, node: int) -> float:
        """
        Get the weight of a node, which is always 1.
        """
        return 1.0

    @classmethod
    def set_node_weight(cls, graph: nx.Graph, node: int, weight: float) -> None:
        """
        Unsupported: always raises `NotImplementedError`.
        """
        raise NotImplementedError("UnweightedPicker does not support operation `set_node_weight`")

    @classmethod
    def node_delta(cls, graph: nx.Graph, node: int, delta: float) -> float:
        """
        Unsupported: always raises `NotImplementedError`.
        """
        # Explicit override so that the error carries a message, consistently
        # with `set_node_weight`, instead of the bare error of the base class.
        raise NotImplementedError("UnweightedPicker does not support operation `node_delta`")

    @classmethod
    def subgraph_weight(cls, graph: nx.Graph, nodes: typing.Iterable[int]) -> float:
        """
        Get the weight of a subgraph, i.e. the number of items in `nodes`.

        If `nodes` is not sized (e.g. a generator), it is consumed by this call.
        """
        # In the unweighted picker, we can usually run this function in constant time.
        if hasattr(nodes, "__len__"):
            # Usually, we call this with a list or a set, so `len()` is fast.
            sized = typing.cast(typing.Sized, nodes)
            return float(len(sized))
        # Otherwise, we have to count the number of nodes.
        # Apparently, constructor `tuple` is optimized for this, and clearly faster
        # than calling `node_weight` for each node.
        return float(len(tuple(nodes)))


def closed_neighborhood(graph: nx.Graph, node: int) -> list[int]:
    """
    Return the closed neighbourhood of a node.

    The result lists the neighbours of `node`, in the order yielded by
    `graph.neighbors(node)`, followed by `node` itself.
    """
    return [*graph.neighbors(node), node]


def is_independent(graph: nx.Graph, nodes: typing.Iterable[int]) -> bool:
    """
    Checks if the node set is an independent set (no edges between them).

    Args:
        graph: The graph to check.
        nodes: The nodes to check. Any iterable is accepted (list, tuple,
            set, generator...); a one-shot iterable is consumed by the call.

    Returns:
        True if independent, False otherwise. An empty or single-node
        collection is always independent.
    """
    # `combinations` yields each pair of positions exactly once, in the same
    # order as a nested index loop, without requiring `nodes` to be sliceable.
    return not any(graph.has_edge(u, v) for u, v in itertools.combinations(nodes, 2))


def remove_neighborhood(graph: nx.Graph, nodes: list[int]) -> nx.Graph:
    """
    Build a copy of the graph without the given nodes and their neighbours.

    The input graph is left untouched: the removal is performed on a copy.

    Args:
        graph: The graph to reduce.
        nodes: The nodes to remove, along with all of their neighbours.

    Returns:
        The reduced copy of the graph.
    """
    pruned = graph.copy()
    # Neighbours are looked up in the original graph, then everything is
    # dropped from the copy in a single call.
    doomed = set(nodes).union(*(graph.neighbors(seed) for seed in nodes))
    pruned.remove_nodes_from(doomed)
    return pruned