# Variational quantum eigensolver
We will demonstrate how to perform a Variational quantum eigensolver (VQE)[^1] task on a molecular example using Horqrux. VQE boils down to finding the molecular ground state \(| \psi(\theta) \rangle\) that minimizes the energy with respect to a molecular Hamiltonian of interest, denoted \(H\):

$$\langle \psi(\theta) | H | \psi(\theta) \rangle$$
## Hamiltonian
In this example, we run VQE for the \(H_2\) molecule in the STO-3G basis with a bond length of \(0.742 \mathring{A}\)[^2]. The ground-state energy is around \(-1.137\) Hartree. Note that we need to build the Hamiltonian by hand, as no convenience construction method is implemented in Horqrux (such methods are available in Qadence, though).
```python
import jax
from jax import Array
import optax
import time

import horqrux
from horqrux import I, Z, X, Y
from horqrux import Scale, Observable
from horqrux.composite import OpSequence
from horqrux.api import expectation
from horqrux.circuit import hea, QuantumCircuit

H2_hamiltonian = Observable([
    Scale(I(0), -0.09963387941370971),
    Scale(Z(0), 0.17110545123720233),
    Scale(Z(1), 0.17110545123720225),
    Scale(OpSequence([Z(0), Z(1)]), 0.16859349595532533),
    Scale(OpSequence([Y(0), X(1), X(2), Y(3)]), 0.04533062254573469),
    Scale(OpSequence([Y(0), Y(1), X(2), X(3)]), -0.04533062254573469),
    Scale(OpSequence([X(0), X(1), Y(2), Y(3)]), -0.04533062254573469),
    Scale(OpSequence([X(0), Y(1), Y(2), X(3)]), 0.04533062254573469),
    Scale(Z(2), -0.22250914236600539),
    Scale(OpSequence([Z(0), Z(2)]), 0.12051027989546245),
    Scale(Z(3), -0.22250914236600539),
    Scale(OpSequence([Z(0), Z(3)]), 0.16584090244119712),
    Scale(OpSequence([Z(1), Z(2)]), 0.16584090244119712),
    Scale(OpSequence([Z(1), Z(3)]), 0.12051027989546245),
    Scale(OpSequence([Z(2), Z(3)]), 0.1743207725924201),
])
```
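As a sanity check, the quoted ground-state energy can be reproduced without any quantum circuit by assembling the same fifteen Pauli terms into a dense \(16 \times 16\) matrix and diagonalizing it exactly. The sketch below uses plain JAX only; taking qubit \(0\) as the leftmost tensor factor is a convention choice that does not affect the eigenvalues:

```python
import jax.numpy as jnp

# Single-qubit Pauli matrices.
_I = jnp.eye(2, dtype=jnp.complex64)
_X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
_Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64)
_Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)

def pauli_string(ops: str) -> jnp.ndarray:
    """Kronecker product of single-qubit Paulis, e.g. 'YXXY'; qubit 0 is leftmost."""
    mats = {"I": _I, "X": _X, "Y": _Y, "Z": _Z}
    out = jnp.array([[1.0 + 0j]])
    for p in ops:
        out = jnp.kron(out, mats[p])
    return out

# The same coefficients and Pauli words as in H2_hamiltonian above.
terms = [
    ("IIII", -0.09963387941370971),
    ("ZIII", 0.17110545123720233),
    ("IZII", 0.17110545123720225),
    ("ZZII", 0.16859349595532533),
    ("YXXY", 0.04533062254573469),
    ("YYXX", -0.04533062254573469),
    ("XXYY", -0.04533062254573469),
    ("XYYX", 0.04533062254573469),
    ("IIZI", -0.22250914236600539),
    ("ZIZI", 0.12051027989546245),
    ("IIIZ", -0.22250914236600539),
    ("ZIIZ", 0.16584090244119712),
    ("IZZI", 0.16584090244119712),
    ("IZIZ", 0.12051027989546245),
    ("IIZZ", 0.1743207725924201),
]
H = sum(coeff * pauli_string(p) for p, coeff in terms)
print(jnp.linalg.eigvalsh(H)[0])  # smallest eigenvalue, ~ -1.137
```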
## Ansatz
As an ansatz, we use the hardware-efficient ansatz[^3] with \(5\) layers applied to the initial state \(| 0011 \rangle\).
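The `ansatz` and `init_state` objects used in the training code below can be set up along the following lines. This is a minimal sketch using the `hea` and `QuantumCircuit` imports from above; the exact `hea(n_qubits, n_layers)` signature and the `product_state` helper are assumptions about the Horqrux API, not confirmed by this tutorial:

```python
N_QUBITS = 4
N_LAYERS = 5

# Assumed signature: hea(n_qubits, n_layers) returns the variational operations.
ansatz = QuantumCircuit(N_QUBITS, hea(N_QUBITS, N_LAYERS))

# Assumed helper preparing the |0011> computational basis state.
init_state = horqrux.product_state("0011")
```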
## Optimization with automatic differentiation
The objective here is to optimize the variational parameters of our ansatz using the standard Adam optimizer. Below, we show how to set up a training function. We first consider the non-jitted version of the training function, so that we can later compare its timing against the jitted version.
```python
# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))

LEARNING_RATE = 0.01
N_EPOCHS = 50

optimizer = optax.adam(learning_rate=LEARNING_RATE)

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: dict[str, Array]) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array) -> Array:
    """The loss is the sum of the expectation values of all observable components."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jax.numpy.sum(expectation(init_state, ansatz, observables=[H2_hamiltonian], values=values))

def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    param_vals, opt_state = param_vals_opt_state
    loss, grads = jax.value_and_grad(loss_fn)(param_vals)
    return optimize_step(param_vals, opt_state, grads)

# Set the initial parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

start = time.time()
param_vals, opt_state = train_unjitted(param_vals, opt_state)
end = time.time()
time_nonjit = end - start
```
Now, we will run `train_step` inside `jax.lax.fori_loop`, which jit-compiles the loop body and improves execution time (we expect at least a \(10\times\) speedup, depending on the system):
```python
# Reset the parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(param_vals)

start_jit = time.time()
param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
end_jit = time.time()
time_jit = end_jit - start_jit
print(f"Time speedup: {time_nonjit / time_jit:.3f}")
```
## Optimization with parameter-shift rule
When using the parameter-shift rule (PSR), we can either use the same `expectation` function with `diff_mode=horqrux.DiffMode.GPSR`, or the dedicated functions:

- `horqrux.differentiation.gpsr.jitted_analytical_exp` and `horqrux.differentiation.gpsr.jitted_finite_shots` as forward methods,
- `horqrux.differentiation.gpsr.analytical_gpsr_bwd` and `horqrux.differentiation.gpsr.finite_shots_gpsr_backward` as backward methods.

Depending on the case, either way may be faster, but with `expectation` we can obtain higher-order derivatives. The `expectation`-based variant is sketched right after this list.
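The `expectation`-based variant only requires passing the `diff_mode` argument mentioned above. A minimal sketch, assuming the GPSR-differentiated `expectation` can be consumed by `jax.value_and_grad` just like the AD version:

```python
# GPSR through the high-level API: the same loss as before, but gradients
# are computed with the parameter-shift rule instead of plain AD.
def gpsr_loss_fn(param_vals: Array) -> Array:
    values = dict(zip(ansatz.vparams, param_vals))
    return jax.numpy.sum(
        expectation(
            init_state,
            ansatz,
            observables=[H2_hamiltonian],
            values=values,
            diff_mode=horqrux.DiffMode.GPSR,
        )
    )

loss, grads = jax.value_and_grad(gpsr_loss_fn)(param_vals)
```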
### Analytical
Let us rewrite our example using `jitted_analytical_exp` and `analytical_gpsr_bwd` for the analytical version of PSR:
```python
from horqrux.differentiation.gpsr import jitted_analytical_exp, analytical_gpsr_bwd

# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))

optimizer = optax.adam(learning_rate=LEARNING_RATE)

# Flatten the circuit into the list of operations expected by the GPSR helpers
ansatz_ops = list(iter(ansatz))

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: dict[str, Array]) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array) -> Array:
    """The loss is the sum of the expectation values of all observable components."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jitted_analytical_exp(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values).sum()

def bwd_loss_fn(param_vals: Array) -> Array:
    """The backward pass directly returns the gradient vector via GPSR."""
    values = dict(zip(ansatz.vparams, param_vals))
    return analytical_gpsr_bwd(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values)

def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    param_vals, opt_state = param_vals_opt_state
    grads = bwd_loss_fn(param_vals)
    return optimize_step(param_vals, opt_state, grads)

# Set the initial parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

param_vals, opt_state = train_unjitted(param_vals, opt_state)
```
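As a quick check, the optimized energy can be evaluated with the forward `loss_fn` defined above:

```python
print(f"Optimized energy (analytical PSR): {float(loss_fn(param_vals)):.6f}")
```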
### With shots
Let us rewrite our example using `jitted_finite_shots` and `finite_shots_gpsr_backward` for the shot-based version of PSR:
```python
from horqrux.differentiation.gpsr import jitted_finite_shots, finite_shots_gpsr_backward

# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
init_param_vals = jax.random.uniform(key, shape=(ansatz.n_vparams,))

optimizer = optax.adam(learning_rate=LEARNING_RATE)

def optimize_step(param_vals: Array, opt_state: optax.OptState, grads: dict[str, Array]) -> tuple:
    updates, opt_state = optimizer.update(grads, opt_state, param_vals)
    param_vals = optax.apply_updates(param_vals, updates)
    return param_vals, opt_state

def loss_fn(param_vals: Array, key: Array) -> Array:
    """The loss is the sum of the shot-based expectation values of all observable components."""
    values = dict(zip(ansatz.vparams, param_vals))
    return jitted_finite_shots(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values, n_shots=10000, key=key).sum()

def bwd_loss_fn(param_vals: Array, key: Array) -> Array:
    """The backward pass directly returns the gradient vector via shot-based GPSR."""
    values = dict(zip(ansatz.vparams, param_vals))
    return finite_shots_gpsr_backward(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values, n_shots=10000, key=key)

def train_step(i: int, param_vals_opt_state: tuple) -> tuple:
    param_vals, opt_state = param_vals_opt_state
    # A fresh PRNG key per epoch so each gradient estimate uses new shots
    grads = bwd_loss_fn(param_vals, jax.random.PRNGKey(i))
    return optimize_step(param_vals, opt_state, grads)

# Set the initial parameters and the state of the optimizer
param_vals = init_param_vals.copy()
opt_state = optimizer.init(init_param_vals)

def train_unjitted(param_vals, opt_state):
    for i in range(N_EPOCHS):
        param_vals, opt_state = train_step(i, (param_vals, opt_state))
    return param_vals, opt_state

param_vals, opt_state = train_unjitted(param_vals, opt_state)
```
Since the shot-based loss is noisy, we evaluate the optimized parameters with the analytical expectation instead:

```python
def analytical_expectation(param_vals: Array) -> Array:
    values = dict(zip(ansatz.vparams, param_vals))
    return jitted_analytical_exp(init_state, ansatz_ops, observables=[H2_hamiltonian], values=values).sum()
```
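A short usage example:

```python
print(f"Optimized energy (shot-based PSR): {float(analytical_expectation(param_vals)):.6f}")
```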