
Variational quantum eigensolver

We will demonstrate how to perform a Variational quantum eigensolver (VQE) [1] task on a molecular example using Horqrux. VQE boils down to finding the parametrized state \(| \psi(\theta) \rangle\) that minimizes the energy with respect to a molecular Hamiltonian of interest, denoted \(H\): \(E(\theta) = \langle \psi(\theta) | H | \psi(\theta) \rangle\). By the variational principle, this expectation value upper-bounds the true ground-state energy, so minimizing it over \(\theta\) approximates the ground state.

Hamiltonian

In this example, we run VQE for the \(H_2\) molecule in the STO-3G basis at a bond length of \(0.742\,\mathring{A}\) [2]. The ground-state energy is around \(-1.137\) Hartree. Note that we need to build the Hamiltonian by hand, as Horqrux does not provide a method to construct molecular Hamiltonians (such methods are available in Qadence though).

import jax
from jax import Array
import optax

import time

import horqrux
from horqrux import I, Z, X, Y
from horqrux import Scale, Observable
from horqrux.composite import OpSequence

from horqrux.api import expectation
from horqrux.circuit import hea, QuantumCircuit

H2_hamiltonian = Observable([
    Scale(I(0), -0.09963387941370971),
    Scale(Z(0), 0.17110545123720233),
    Scale(Z(1), 0.17110545123720225),
    Scale(OpSequence([Z(0), Z(1)]), 0.16859349595532533),
    Scale(OpSequence([Y(0), X(1), X(2), Y(3)]), 0.04533062254573469),
    Scale(OpSequence([Y(0), Y(1), X(2), X(3)]), -0.04533062254573469),
    Scale(OpSequence([X(0), X(1), Y(2), Y(3)]), -0.04533062254573469),
    Scale(OpSequence([X(0), Y(1), Y(2), X(3)]), 0.04533062254573469),
    Scale(Z(2), -0.22250914236600539),
    Scale(OpSequence([Z(0), Z(2)]), 0.12051027989546245),
    Scale(Z(3), -0.22250914236600539),
    Scale(OpSequence([Z(0), Z(3)]), 0.16584090244119712),
    Scale(OpSequence([Z(1), Z(2)]), 0.16584090244119712),
    Scale(OpSequence([Z(1), Z(3)]), 0.12051027989546245),
    Scale(OpSequence([Z(2), Z(3)]), 0.1743207725924201),
])
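
As a quick sanity check (not part of the Horqrux workflow itself), the same operator can be rebuilt as a dense \(16 \times 16\) matrix with plain jax.numpy Kronecker products and diagonalized; its lowest eigenvalue should reproduce the reference ground-state energy of about \(-1.137\). This is only a sketch; the underscored names avoid shadowing the Horqrux gates imported above.

import jax.numpy as jnp
from functools import reduce

# Single-qubit Pauli matrices
I2 = jnp.eye(2, dtype=jnp.complex64)
X_ = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
Y_ = jnp.array([[0.0, -1.0j], [1.0j, 0.0]], dtype=jnp.complex64)
Z_ = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)

def pauli_string(ops: dict, n_qubits: int = 4) -> Array:
    """Kronecker product of single-qubit Paulis, with identity on the remaining qubits."""
    return reduce(jnp.kron, [ops.get(q, I2) for q in range(n_qubits)])

terms = [
    ({}, -0.09963387941370971),
    ({0: Z_}, 0.17110545123720233),
    ({1: Z_}, 0.17110545123720225),
    ({0: Z_, 1: Z_}, 0.16859349595532533),
    ({0: Y_, 1: X_, 2: X_, 3: Y_}, 0.04533062254573469),
    ({0: Y_, 1: Y_, 2: X_, 3: X_}, -0.04533062254573469),
    ({0: X_, 1: X_, 2: Y_, 3: Y_}, -0.04533062254573469),
    ({0: X_, 1: Y_, 2: Y_, 3: X_}, 0.04533062254573469),
    ({2: Z_}, -0.22250914236600539),
    ({0: Z_, 2: Z_}, 0.12051027989546245),
    ({3: Z_}, -0.22250914236600539),
    ({0: Z_, 3: Z_}, 0.16584090244119712),
    ({1: Z_, 2: Z_}, 0.16584090244119712),
    ({1: Z_, 3: Z_}, 0.12051027989546245),
    ({2: Z_, 3: Z_}, 0.1743207725924201),
]
H_dense = sum(coeff * pauli_string(ops) for ops, coeff in terms)
print(jnp.linalg.eigvalsh(H_dense).min())  # expected: close to -1.137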

Ansatz

As an ansatz, we use the hardware-efficient ansatz [3] with \(5\) layers applied on the initial state \(| 0011 \rangle\). With \(4\) qubits, \(5\) layers and \(3\) parametrized single-qubit rotations per qubit and layer, this yields the \(60\) variational parameters printed below.

init_state = horqrux.product_state("0011")
ansatz = QuantumCircuit(4, hea(4, 5))
print(f"Number of variational parameters: {ansatz.n_vparams}")
Number of variational parameters: 60

Optimization with automatic differentiation

The objective here is to optimize the variational parameters of our ansatz with the standard Adam optimizer. Below we set up the training functions. We first run a non-jitted version of the training loop, so that we can later compare its timing against the jitted version.

# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))
LEARNING_RATE = 0.01
N_EPOCHS = 50

optimizer = optax.adam(learning_rate=LEARNING_RATE)

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: Array) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array) -> Array:
    """The loss function is the sum of all expectation value for the observable components."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jax.numpy.sum(expectation(init_state, ansatz, observables=[H2_hamiltonian], values=values))


def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    param_vals, opt_state = param_vals_opt_state
    loss, grads = jax.value_and_grad(loss_fn)(param_vals)
    return optimize_step(param_vals, opt_state, grads)

# set initial parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(0, N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

start = time.time()
param_vals, opt_state = train_unjitted(param_vals, opt_state)
end = time.time()
time_nonjit = end - start
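
The values reported below can be reproduced by evaluating loss_fn at the initial and final parameter values:

print(f"Initial loss: {loss_fn(init_param_vals):.3f}")
print(f"Final loss: {loss_fn(param_vals):.3f}")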
Initial loss: -0.247
Final loss: -1.051

Now, we will run the same training loop inside jax.lax.fori_loop, which traces and compiles train_step and improves execution time (we expect a speedup of roughly \(10\times\) or more, depending on the system):

# reset state and parameters
param_vals = init_param_vals.copy()
opt_state = optimizer.init(param_vals)

start_jit = time.time()
param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
end_jit = time.time()
time_jit = end_jit - start_jit

print(f"Time speedup: {time_nonjit / time_jit:.3f}")
Time speedup: 13.478
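
A caveat on timing: JAX dispatches computations asynchronously, so wall-clock measurements like the one above may not reflect the actual compute time. A more robust measurement blocks on the result before reading the clock, for instance:

start_jit = time.time()
param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
param_vals.block_until_ready()  # wait for the device computation to actually finish
end_jit = time.time()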

Optimization with parameter-shift rule

When using the parameter-shift rule (PSR), we can either call the same expectation function with diff_mode=horqrux.DiffMode.GPSR, or use the lower-level functions:

  • horqrux.differentiation.gpsr.jitted_analytical_exp and horqrux.differentiation.gpsr.jitted_finite_shots as forward methods
  • horqrux.differentiation.gpsr.analytical_gpsr_bwd and horqrux.differentiation.gpsr.finite_shots_gpsr_backward as backward methods.

Depending on the case, either way may be faster, but with expectation we can also obtain higher-order derivatives.
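
For reference, here is a minimal sketch of the first option, reusing the loss from the previous section with the GPSR differentiation mode (we assume expectation accepts the diff_mode keyword as stated above, and that this mode registers a differentiation rule so jax.grad can be used as before):

def loss_fn_gpsr(param_vals: Array) -> Array:
    """Same loss as before, but requesting GPSR differentiation."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jax.numpy.sum(
        expectation(
            init_state, ansatz, observables=[H2_hamiltonian], values=values,
            diff_mode=horqrux.DiffMode.GPSR,
        )
    )

# Gradients are then obtained exactly as in the autodiff example above.
grads = jax.grad(loss_fn_gpsr)(init_param_vals)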

Analytical

Let us rewrite our example using jitted_analytical_exp and analytical_gpsr_bwd for the analytical version of PSR:

from horqrux.differentiation.gpsr import jitted_analytical_exp, analytical_gpsr_bwd

# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))

optimizer = optax.adam(learning_rate=LEARNING_RATE)
ansatz_ops = list(iter(ansatz))  # flat list of circuit operations, passed to the GPSR helpers below

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: Array) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array) -> Array:
    """The loss function is the sum of all expectation value for the observable components."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jitted_analytical_exp(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values).sum()

def bwd_loss_fn(param_vals: Array) -> Array:
    """The backward returns directly the gradient vector via GPSR and `jitted_analytical_exp`."""
    values = dict(zip(ansatz.vparams, param_vals))
    return analytical_gpsr_bwd(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values)

def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    param_vals, opt_state = param_vals_opt_state
    grads = bwd_loss_fn(param_vals)
    return optimize_step(param_vals, opt_state, grads)

# set initial parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(0, N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

param_vals, opt_state = train_unjitted(param_vals, opt_state)
print(f"Final loss: {loss_fn(param_vals):.3f}")
Final loss: -1.051
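
As an optional cross-check (not part of the original example), the GPSR gradients returned by bwd_loss_fn can be compared against autodiff gradients of the expectation-based loss from the first section, assuming analytical_gpsr_bwd returns a flat gradient array in the same parameter ordering as param_vals (which optax already relies on above):

def ad_loss(param_vals: Array) -> Array:
    """Autodiff reference loss built on `expectation`, as in the first section."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jax.numpy.sum(expectation(init_state, ansatz, observables=[H2_hamiltonian], values=values))

ad_grads = jax.grad(ad_loss)(init_param_vals)
gpsr_grads = bwd_loss_fn(init_param_vals)
print(jax.numpy.allclose(ad_grads, gpsr_grads, atol=1e-5))  # expected: True if the orderings match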

With shots

Let us rewrite our example using jitted_finite_shots and finite_shots_gpsr_backward for the shot-based version of PSR:

from horqrux.differentiation.gpsr import jitted_finite_shots, finite_shots_gpsr_backward


# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))

optimizer = optax.adam(learning_rate=LEARNING_RATE)

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: Array) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array, key: Array) -> Array:
    """The loss is the sum of the shot-based expectation values of the observable components."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jitted_finite_shots(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values, n_shots=10000, key=key).sum()

def bwd_loss_fn(param_vals: Array, key: Array) -> Array:
    """Return the gradient vector via shot-based GPSR."""
    values = dict(zip(ansatz.vparams, param_vals))
    return finite_shots_gpsr_backward(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values, n_shots=10000, key=key)

def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    param_vals, opt_state = param_vals_opt_state
    grads = bwd_loss_fn(param_vals, jax.random.PRNGKey(i))
    return optimize_step(param_vals, opt_state, grads)

# set initial parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(0, N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

param_vals, opt_state = train_unjitted(param_vals, opt_state)

def analytical_expectation(param_vals: Array) -> Array:
    """Evaluate the final energy analytically, i.e. without shot noise."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jitted_analytical_exp(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values).sum()

print(f"Final loss: {analytical_expectation(param_vals):.3f}")
Final loss: -1.050
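
Since finite-shot estimates fluctuate around the analytical value, it can be instructive (this snippet is an illustrative addition, not part of the original example) to evaluate the shot-based loss a few times with different keys and compare against the analytical result:

keys = jax.random.split(jax.random.PRNGKey(0), 5)
shot_estimates = [float(loss_fn(param_vals, k)) for k in keys]
print("Shot-based estimates:", shot_estimates)
print("Analytical value:", float(analytical_expectation(param_vals)))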