
Variational quantum eigensolver

We will demonstrate how to run a Variational Quantum Eigensolver (VQE) [1] on a molecular example using Horqrux. VQE amounts to finding the parametrized state \(| \psi(\theta) \rangle\) that best approximates the molecular ground state by minimizing the energy with respect to a molecular Hamiltonian of interest, denoted \(H\):

$$\langle \psi(\theta) | H | \psi(\theta) \rangle$$
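By the Rayleigh-Ritz variational principle, this expectation value upper-bounds the true ground-state energy \(E_0\) for any choice of \(\theta\), so minimizing it over the variational parameters drives the trial state towards the ground state:

$$E_0 \leq \min_{\theta} \, \langle \psi(\theta) | H | \psi(\theta) \rangle$$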

Hamiltonian

In this example, we run VQE for the \(H_2\) molecule in the STO-3G basis at a bond length of \(0.742\, \mathring{A}\) [2]. The ground-state energy is around \(-1.137\) Hartree. Note that we need to build the Hamiltonian by hand, as no construction method is implemented in Horqrux (such methods are available in Qadence, though).

import jax
from jax import Array
import optax

import time

import horqrux
from horqrux import I, Z, X, Y
from horqrux import Scale, Observable
from horqrux.composite import OpSequence

from horqrux.api import expectation
from horqrux.circuit import hea, QuantumCircuit

H2_hamiltonian = Observable([
    Scale(I(0), -0.09963387941370971),
    Scale(Z(0), 0.17110545123720233),
    Scale(Z(1), 0.17110545123720225),
    Scale(OpSequence([Z(0), Z(1)]), 0.16859349595532533),
    Scale(OpSequence([Y(0), X(1), X(2), Y(3)]), 0.04533062254573469),
    Scale(OpSequence([Y(0), Y(1), X(2), X(3)]), -0.04533062254573469),
    Scale(OpSequence([X(0), X(1), Y(2), Y(3)]), -0.04533062254573469),
    Scale(OpSequence([X(0), Y(1), Y(2), X(3)]), 0.04533062254573469),
    Scale(Z(2), -0.22250914236600539),
    Scale(OpSequence([Z(0), Z(2)]), 0.12051027989546245),
    Scale(Z(3), -0.22250914236600539),
    Scale(OpSequence([Z(0), Z(3)]), 0.16584090244119712),
    Scale(OpSequence([Z(1), Z(2)]), 0.16584090244119712),
    Scale(OpSequence([Z(1), Z(3)]), 0.12051027989546245),
    Scale(OpSequence([Z(2), Z(3)]), 0.1743207725924201),
])
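As a quick sanity check, independent of Horqrux, we can rebuild the same Hamiltonian as a dense \(16 \times 16\) matrix with plain jax.numpy and diagonalize it; its lowest eigenvalue should reproduce the ground-state energy of roughly \(-1.137\) quoted above. This is only an illustrative sketch: the pauli_string helper and the "qubit 0 = leftmost factor" convention below are our own assumptions, not part of the Horqrux API.

# Illustrative sketch: rebuild the Hamiltonian as a dense matrix and diagonalize it.
import jax.numpy as jnp

PAULIS = {
    "I": jnp.eye(2, dtype=jnp.complex64),
    "X": jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64),
    "Y": jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64),
    "Z": jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64),
}

def pauli_string(term: str) -> jnp.ndarray:
    """Kronecker product of single-qubit Paulis, e.g. 'ZIZI' for Z(0) Z(2)."""
    mat = PAULIS[term[0]]
    for p in term[1:]:
        mat = jnp.kron(mat, PAULIS[p])
    return mat

coeffs = {
    "IIII": -0.09963387941370971, "ZIII": 0.17110545123720233,
    "IZII": 0.17110545123720225, "ZZII": 0.16859349595532533,
    "YXXY": 0.04533062254573469, "YYXX": -0.04533062254573469,
    "XXYY": -0.04533062254573469, "XYYX": 0.04533062254573469,
    "IIZI": -0.22250914236600539, "ZIZI": 0.12051027989546245,
    "IIIZ": -0.22250914236600539, "ZIIZ": 0.16584090244119712,
    "IZZI": 0.16584090244119712, "IZIZ": 0.12051027989546245,
    "IIZZ": 0.1743207725924201,
}

H_dense = sum(c * pauli_string(term) for term, c in coeffs.items())
print(jnp.linalg.eigvalsh(H_dense)[0])  # expected: approximately -1.137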

Ansatz

As an ansatz, we use the hardware-efficient ansatz [3] with \(5\) layers applied to the initial state \(| 0011 \rangle\).

init_state = horqrux.product_state("0011")
ansatz = QuantumCircuit(4, hea(4, 5))
print(f"Number of variational parameters: {ansatz.n_vparams}")
Number of variational parameters: 60
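The count of \(60\) can be related to the ansatz structure. Assuming each hea layer applies three parametrized single-qubit rotations per qubit (an assumption about the default layer composition, not something stated above), the count factorizes as:

# Hypothetical breakdown of the parameter count; the "3 rotations per qubit
# per layer" figure is an assumption about the default hea layer.
n_qubits, n_layers, rotations_per_qubit = 4, 5, 3
assert ansatz.n_vparams == n_qubits * n_layers * rotations_per_qubit  # 60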

Optimization

Our goal is to optimize the variational parameters of the ansatz using the standard Adam optimizer. Below, we set up a training function. We first consider a non-jitted version of the training loop, so that we can later compare its timing against the jitted version.

# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))
LEARNING_RATE = 0.01
N_EPOCHS = 50

optimizer = optax.adam(learning_rate=LEARNING_RATE)

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: Array) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array) -> Array:
    """The loss is the sum of the expectation values of the passed observables."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jax.numpy.sum(expectation(init_state, ansatz, observables=[H2_hamiltonian], values=values))


def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    # `i` is the loop counter required by jax.lax.fori_loop; it is unused here.
    param_vals, opt_state = param_vals_opt_state
    loss, grads = jax.value_and_grad(loss_fn)(param_vals)
    return optimize_step(param_vals, opt_state, grads)

# set the initial parameters and the optimizer state
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(0, N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

print(f"Initial loss: {loss_fn(param_vals):.3f}")

start = time.time()
param_vals, opt_state = train_unjitted(param_vals, opt_state)
end = time.time()
time_nonjit = end - start

print(f"Final loss: {loss_fn(param_vals):.3f}")
Initial loss: -0.176
Final loss: -1.040
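If you want to monitor convergence, a simple variant of the loop records the loss at every epoch. The sketch below reuses only loss_fn and optimize_step defined above; since it is a plain Python loop, it is mainly useful for the non-jitted version:

# Optional sketch: track the loss per epoch (non-jitted, reuses loss_fn/optimize_step).
loss_history = []
pv, st = init_param_vals.copy(), optimizer.init(init_param_vals)
for _ in range(N_EPOCHS):
    loss, grads = jax.value_and_grad(loss_fn)(pv)
    loss_history.append(float(loss))
    pv, st = optimize_step(pv, st, grads)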

Now, we compile the training loop by running train_step inside jax.lax.fori_loop, which traces the loop body once and executes the whole loop as compiled code. We expect a speedup of at least \(10\times\), depending on the system:

# reset the parameters and the optimizer state
param_vals = init_param_vals.copy()
opt_state = optimizer.init(param_vals)

start_jit = time.time()
param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
end_jit = time.time()
time_jit = end_jit - start_jit

print(f"Time speedup: {time_nonjit / time_jit:.3f}")
Time speedup: 15.058
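As a variation on the above, the fori_loop can also be wrapped in jax.jit so that the entire training loop is compiled once and re-executed without re-tracing. This is only a sketch reusing the objects defined above; timings will differ from the numbers shown:

# Sketch: wrap the whole training loop in jax.jit; the compiled function can be
# called repeatedly (e.g. after resetting the parameters) without re-tracing.
@jax.jit
def train_jitted(param_vals, opt_state):
    return jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))

param_vals, opt_state = train_jitted(init_param_vals.copy(), optimizer.init(init_param_vals))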