"""Pulse shapers for the MIS pipeline (module ``mis.pipeline.pulse``)."""

from __future__ import annotations

from dataclasses import dataclass
from abc import ABC, abstractmethod

from networkx.classes.reportviews import DegreeView
from pulser import InterpolatedWaveform, Pulse as PulserPulse

from mis.pipeline.config import SolverConfig

import numpy as np
import networkx as nx
from scipy.spatial.distance import euclidean

from .targets import Pulse, Register


@dataclass
class BasePulseShaper(ABC):
    """
    Interface for objects that derive a pulse schedule from a MIS problem.

    A shaper maps the structure of a MISInstance onto a quantum pulse
    sequence meant to drive a physical register. The register is not stored
    on the shaper: it is handed over only when :meth:`generate` is called.
    """

    #: Pulse duration in microseconds. When left as ``None``, the maximal
    #: duration allowed by the device is used instead.
    duration_us: int | None = None

    @abstractmethod
    def generate(self, config: SolverConfig, register: Register) -> Pulse:
        """
        Build the pulse for the given register.

        Args:
            config: Configuration driving this pulse.
            register: Physical layout of the atoms.

        Returns:
            Pulse: Wrapper around the produced Pulser pulse.
        """
        ...


@dataclass
class _Bounds:
    maximum_amplitude: float
    final_detuning: float


class DefaultPulseShaper(BasePulseShaper):
    """
    A simple pulse shaper.

    The generated pulse has an amplitude interpolated through
    ``[1e-9, maximum_amplitude, 1e-9]`` and a detuning interpolated through
    ``[-final_detuning, 0, final_detuning]``. Both extremes are derived from
    the interaction strengths between the atoms of the register and capped by
    the limits of the device's ``rydberg_global`` channel.
    """

    def generate(self, config: SolverConfig, register: Register) -> Pulse:
        """
        Return a pulse built from two interpolated waveforms.

        Args:
            config: The solver configuration. Its ``device`` must be set.
            register: The physical register; ``register.graph`` is the problem
                graph and ``register.register`` holds the atom coordinates.

        Returns:
            Pulse: A wrapper around the generated Pulser pulse.
        """
        device = config.device
        assert device is not None

        graph = register.graph
        pos = register.register.sorted_coords
        # Cache mapping node value -> row index in `pos`.
        # NOTE(review): this assumes `sorted_coords` lists the atoms in the
        # same order as `graph.nodes`; confirm against how `Register` is built.
        index_by_node = {node: i for (i, node) in enumerate(graph.nodes)}

        def calculate_edge_interaction(edge: tuple[int, int]) -> float:
            # interaction_coeff / r**6, with r the Euclidean distance between
            # the two atoms of `edge`.
            pos_a, pos_b = pos[index_by_node[edge[0]]], pos[index_by_node[edge[1]]]
            return float(device.interaction_coeff / (euclidean(pos_a, pos_b) ** 6))

        # Interaction strengths for connected / disconnected node pairs.
        connected = [calculate_edge_interaction(edge) for edge in graph.edges()]
        disconnected = [
            calculate_edge_interaction(edge) for edge in nx.complement(graph).edges()
        ]

        # Weakest interaction across an edge (0 if the graph has no edge) and
        # strongest interaction across a non-edge (inf if every pair is an edge).
        u_min = float(np.min(connected)) if connected else 0.0
        u_max = float(np.max(disconnected)) if disconnected else float(np.inf)

        channel = device.channels["rydberg_global"]
        maximum_amplitude = self._maximum_amplitude(u_min, u_max, channel.max_amp or np.inf)

        d_min, d_max = self._degree_range(graph)
        final_detuning = self._final_detuning(
            d_min, d_max, u_min, disconnected, channel.max_abs_detuning or np.inf
        )

        duration_us = self.duration_us
        if duration_us is None:
            # Unspecified duration: use the maximal duration for the device.
            duration_us = device.max_sequence_duration

        amplitude = InterpolatedWaveform(
            duration_us, [1e-9, maximum_amplitude, 1e-9]
        )  # FIXME: This should be 0, investigate why it's 1e-9
        detuning = InterpolatedWaveform(duration_us, [-final_detuning, 0, final_detuning])
        rydberg_pulse = PulserPulse(amplitude, detuning, 0)
        # Pulser overrides PulserPulse.__new__ with an exotic type, so we need
        # to help mypy.
        assert isinstance(rydberg_pulse, PulserPulse)

        return Pulse(pulse=rydberg_pulse)

    @staticmethod
    def _maximum_amplitude(u_min: float, u_max: float, device_max: float) -> float:
        """
        Return the peak amplitude: 80% of the way from `u_max` towards `u_min`,
        capped by `device_max`.
        """
        if np.isinf(u_max):
            # No finite disconnected interaction: the interpolation below would
            # compute inf - inf = nan. The theoretical bound is unbounded, so
            # the device limit applies (this is also what the previous
            # `min(device_max, nan)` returned, by accident of argument order).
            return device_max
        # FIXME: Why 0.8?
        return min(device_max, u_max + 0.8 * (u_min - u_max))

    @staticmethod
    def _degree_range(graph: nx.Graph) -> tuple[int, int]:
        """Return the (minimum, maximum) node degree of a non-empty `graph`."""
        degree = graph.degree
        assert isinstance(degree, DegreeView)
        degrees: list[int] = []
        for _, deg in degree:
            assert isinstance(deg, int)
            degrees.append(deg)
        assert degrees, "cannot compute the degree range of an empty graph"
        return min(degrees), max(degrees)

    @staticmethod
    def _final_detuning(
        d_min: int,
        d_max: int,
        u_min: float,
        disconnected: list[float],
        device_max: float,
    ) -> float:
        """
        Return the final detuning.

        This is the larger of ``d_min / (d_min + 1) * u_min`` and the sum of
        the `d_max` largest disconnected interactions, capped by `device_max`.
        """
        det_max_theory = (d_min / (d_min + 1)) * u_min
        # The `d_max` largest disconnected interactions. `d_max == 0` must be
        # special-cased: `xs[-0:]` is `xs[0:]`, i.e. the WHOLE list, which made
        # edgeless graphs sum every disconnected interaction.
        # NOTE(review): for an edgeless graph this now yields a final detuning
        # of 0; confirm that matches the intended heuristic.
        largest = sorted(disconnected)[-d_max:] if d_max > 0 else []
        det_min_theory = sum(largest)
        det_final_theory = max(det_max_theory, det_min_theory)
        return min(det_final_theory, device_max)