Skip to content

mis.shared.types

[docs] module mis.shared.types

from __future__ import annotations

from enum import Enum
from typing import Any
import networkx
import matplotlib.pyplot as plt


class MethodType(str, Enum):
    """
    The method used to extract the MIS (Maximum Independent Set).

    Members are also `str` instances, so each one compares equal to its
    lowercase string value (e.g. `MethodType.EAGER == "eager"`).
    """

    EAGER = "eager"
    """
    An eager solver that attempts to extract a MIS in a single
    shot.
    """

    GREEDY = "greedy"
    """
    A greedy solver that decomposes the graph into smaller subgraphs
    that can benefit from device-specific physical layouts.
    """


class Weighting(str, Enum):
    """
    The variant of the Maximum Independent Set problem addressed by the solver.

    Members are also `str` instances, so each one compares equal to its
    lowercase string value (e.g. `Weighting.WEIGHTED == "weighted"`).
    """

    UNWEIGHTED = "unweighted"
    """
    Unweighted Maximum Independent Set

    Ignore any weight attached to nodes and attempt to maximize the number
    of nodes in the resulting independent set.

    This algorithm imposes fewer restrictions on the underlying quantum
    device than the weighted algorithm and may call upon faster and more
    beneficial pre/post-processing heuristics.
    """

    WEIGHTED = "weighted"
    """
    Weighted Maximum Independent Set

    Any node in the graph may have a property `weight` (float, defaulting to
    `1.0`) specifying its weight. The algorithm attempts to maximize
    the total weight in the resulting independent set.

    This algorithm may not work on all quantum devices, as it relies upon
    specific hardware capabilities. As of this writing, pre-processing and
    post-processing heuristics are typically slower and less beneficial than
    the unweighted heuristics, with the consequence that execution on a
    device may require more qubits.
    """


class MISInstance:
    """
    A Maximum Independent Set problem instance built from a networkx graph.

    Attributes:
        original_graph (networkx.Graph): A copy of the graph passed to the
            constructor, with its original node labels.
        graph (networkx.Graph): A relabelled version of the graph whose nodes
            are consecutive integers starting from 0. Only the node property
            `weight` (when present) and the edges are carried over.
        index_to_node_label (dict[int, Any]): Maps a node of `graph` to the
            label of the matching node in `original_graph`.
        node_label_to_index (dict[Any, int]): Maps a node label of
            `original_graph` to the matching node of `graph`.
    """

    def __init__(self, graph: networkx.Graph):
        # FIXME: Make it work with pytorch geometric

        self.original_graph = graph.copy()

        # Our algorithms depend on nodes being consecutive integers, starting
        # from 0, so we rename nodes while copying them (and their weights,
        # if they exist) into a fresh graph.
        self.index_to_node_label: dict[int, Any] = {}
        self.node_label_to_index: dict[Any, int] = {}
        self.graph = networkx.Graph()
        for index, label in enumerate(graph.nodes()):
            self.index_to_node_label[index] = label
            self.node_label_to_index[label] = index
            self.graph.add_node(index)
            attributes = graph.nodes[label]
            if "weight" in attributes:
                self.graph.nodes[index]["weight"] = attributes["weight"]

        # Copy edges, translated to the new integer node names.
        for u, v in graph.edges():
            self.graph.add_edge(self.node_label_to_index[u], self.node_label_to_index[v])

    def draw(
        self,
        nodes: list[int] | None = None,
        node_size: int = 600,
        highlight_color: str = "darkgreen",
        font_family: str = "serif",
    ) -> None:
        """
        Draw instance graph with highlighted nodes.

        The graph is drawn with its original node labels.

        Parameters:

            nodes (list[int]): Indices (as returned by `node_index`) of the nodes
                to highlight. If `None` or empty, no node is highlighted.
            node_size (int): Size of drawn nodes in drawn graph. (default: 600)
            highlight_color (str): Color to highlight nodes with. (default: "darkgreen")
            font_family (str): Font family used for node labels. (default: "serif")

        Raises:

            KeyError: If `nodes` contains an index that does not match any node
                of this instance.
        """
        # Validate the indices *before* translating them to labels (and before
        # computing the layout): translating first would fail with a bare
        # KeyError that does not tell the caller which nodes are wrong.
        highlighted_labels: list[Any] = []
        if nodes:
            # De-duplicate while preserving the order of first appearance.
            invalid_nodes = list(
                dict.fromkeys(i for i in nodes if i not in self.index_to_node_label)
            )
            if invalid_nodes:
                # Only show the first 10 offenders to keep the message readable.
                bad_nodes = "[" + ", ".join(str(node) for node in invalid_nodes[:10])
                bad_nodes += ", ...]" if len(invalid_nodes) > 10 else "]"
                if len(invalid_nodes) == 1:
                    message = "node " + bad_nodes + " is not present in the problem instance"
                else:
                    message = "nodes " + bad_nodes + " are not present in the problem instance"
                # KeyError is what an unknown index has always raised here, so
                # existing handlers keep working.
                raise KeyError(message)
            highlighted_labels = [self.index_to_node_label[i] for i in nodes]

        # Obtain a view of all nodes
        all_nodes = self.original_graph.nodes
        # Compute graph layout
        node_positions = networkx.kamada_kawai_layout(self.original_graph)
        # Keyword dictionaries to customize appearance
        highlighted_node_kwds = {"node_color": highlight_color, "node_size": node_size}
        unhighlighted_node_kwds = {
            "node_color": "white",
            "edgecolors": "black",
            "node_size": node_size,
        }

        if highlighted_labels:
            # Draw highlighted nodes
            networkx.draw_networkx_nodes(
                self.original_graph,
                node_positions,
                nodelist=highlighted_labels,
                **highlighted_node_kwds,
            )
            unhighlighted_labels = list(all_nodes - set(highlighted_labels))
        else:
            unhighlighted_labels = list(all_nodes)
        # Draw unhighlighted nodes
        networkx.draw_networkx_nodes(
            self.original_graph,
            node_positions,
            nodelist=unhighlighted_labels,
            **unhighlighted_node_kwds,
        )
        # Draw node labels
        networkx.draw_networkx_labels(self.original_graph, node_positions, font_family=font_family)
        # Draw edges
        networkx.draw_networkx_edges(self.original_graph, node_positions)
        plt.tight_layout()
        plt.axis("off")
        plt.show()

    def node_index(self, node: Any) -> int:
        """
        Return the index for a node in the original graph.

        Raises:
            KeyError: If `node` is not a node of the original graph.
        """
        return self.node_label_to_index[node]

    def node_indices(self, nodes: list[Any]) -> list[int]:
        """
        Return the indices for nodes in the original graph.

        Indices are returned in the same order as `nodes`.
        """
        return [self.node_index(node) for node in nodes]


class MISSolution:
    """
    A solution to a MIS problem.

    Attributes:
        instance (MISInstance): The MIS instance to which this class represents a solution.
        size (int): The number of nodes in this solution.
        node_indices (list[int]): The indices of the nodes of `instance` picked in this solution.
        nodes (list[Any]): The nodes of `instance` picked in this solution.
        frequency (float): How often this solution showed up in the measures, where 0. represents
            a solution that never showed up in the measures and 1. a solution that showed up in all
            measures.
    """

    def __init__(self, instance: MISInstance, nodes: list[int], frequency: float):
        self.size = len(nodes)
        # Raised explicitly instead of through an `assert` statement, so that the
        # check is not stripped when Python runs with `-O`. The exception type is
        # kept as AssertionError so existing handlers keep working.
        if len(set(nodes)) != self.size:
            raise AssertionError(f"All the nodes in {nodes} should be distinct")
        self.instance = instance
        self.node_indices = nodes
        # Translate indices back to the labels used in the caller's graph.
        self.nodes = [self.instance.index_to_node_label[i] for i in nodes]
        self.frequency = frequency

        # Note: As of this writing, self.weight is still considered a work in progress, so we
        # leave it out of the documentation.
        from mis.shared.graphs import BaseWeightPicker  # Avoid cycles.

        self.weight = BaseWeightPicker.for_weighting(Weighting.WEIGHTED).subgraph_weight(
            instance.graph, nodes
        )

    def draw(
        self,
        node_size: int = 600,
        highlight_color: str = "darkgreen",
        font_family: str = "serif",
    ) -> None:
        """
        Draw instance graph with solution nodes highlighted.

        Parameters:

            node_size (int): Size of drawn nodes in drawn graph. (default: 600)
            highlight_color (str): Color to highlight solution nodes with. (default: "darkgreen")
            font_family (str): Font family used for node labels. (default: "serif")
        """
        self.instance.draw(
            nodes=self.node_indices,
            node_size=node_size,
            highlight_color=highlight_color,
            font_family=font_family,
        )