Classification with QNN

In this tutorial we show how to use Qadence to solve a basic classification task with a hybrid quantum-classical model: a quantum neural network (QNN) followed by a classical linear layer.

Dataset

We use the Iris dataset, split into a training set and a test set. The task is to classify iris plants, each described by 4 numerical features (sepal length, sepal width, petal length, petal width), into 3 classes: Iris Setosa, Iris Versicolour, or Iris Virginica. Machine learning models, and neural networks in particular, usually train better on normalized inputs, so we apply scikit-learn's StandardScaler: each feature \(x\) is transformed to \(z = (x - u) / s\), where \(u\) and \(s\) are the mean and standard deviation of that feature over the training samples. The scaler is fitted on the training set only and then reused to transform the test set.
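
The standardization step can be sketched in plain Python. The helpers fit_scaler and transform below are hypothetical stand-ins for StandardScaler.fit and StandardScaler.transform, written only to show the arithmetic; like scikit-learn they use the population standard deviation, and they assume every feature has a non-zero standard deviation. The toy numbers are made up and are not Iris measurements.

```python
import statistics


def fit_scaler(rows):
    """Per-feature mean u and standard deviation s of the training rows.

    Uses the population standard deviation (ddof=0), as StandardScaler does.
    """
    columns = list(zip(*rows))
    means = [statistics.fmean(col) for col in columns]
    stds = [statistics.pstdev(col) for col in columns]
    return means, stds


def transform(rows, means, stds):
    """Apply z = (x - u) / s to every feature; assumes every s > 0."""
    return [[(x - u) / s for x, u, s in zip(row, means, stds)] for row in rows]


# Made-up 2-feature data, not the Iris measurements.
train = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
test = [[4.0, 40.0]]

u, s = fit_scaler(train)          # statistics come from the training set only
train_z = transform(train, u, s)  # each column now has mean 0 and std 1
test_z = transform(test, u, s)    # the test set reuses the training statistics
print(train_z)
print(test_z)
```

Note that the test sample is scaled with the training statistics, which is why its standardized value can fall outside the range seen during training.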

import random

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer

class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is applied prior to applying models.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)

n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1
dataset = IrisDataset()

dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
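
As a quick check on the sizes involved, the arithmetic below (plain Python, no scikit-learn or PyTorch needed) reproduces how train_test_split with test_size=0.33 divides the 150 Iris samples, and how many batches of 20 the DataLoader yields per pass over the training set. It assumes scikit-learn's rounding rule for a float test_size (the test-set size is rounded up) and DataLoader's default drop_last=False.

```python
import math

n_samples = 150   # the Iris dataset has 150 samples
test_size = 0.33  # same fraction as in IrisDataset above
batch_size = 20   # same batch size as the DataLoader above

# scikit-learn rounds the test-set size up when test_size is a float.
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test

# With drop_last=False (the default), a smaller final batch is kept.
batches_per_epoch = math.ceil(n_train / batch_size)

print(n_train, n_test, batches_per_epoch)  # 100 50 5
```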

Hybrid QNN

We first build the QNN from several feature map layers, each followed by a variational layer. The feature map encodes feature \(x_i\) on qubit \(i\) with an \(RX(x_i)\) rotation, and the variational layer is a hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components. The output of the QNN, denoted \(o\), is the expectation value of the \(Z\) observable on qubit \(0\).

We then add a simple linear layer as the classification head. It applies a weight matrix \(W\) and a bias vector \(b\) to the QNN output, giving one logit per class: \(l = W o + b\). To obtain probabilities, we can apply the softmax function \(p_i = \exp(l_i) / \sum_{j=1}^{3} \exp(l_j)\). Softmax is not applied during training, because nn.CrossEntropyLoss expects the raw logits \(l\) and applies a log-softmax internally.
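
The classification head and the softmax can be sketched for a single sample in plain Python. Because the QNN returns one scalar \(o\), nn.Linear(1, 3) holds a 3-by-1 weight matrix and 3 biases. The names linear_head and softmax and all numbers below are hypothetical, chosen only to illustrate \(l = W o + b\) and \(p_i = \exp(l_i) / \sum_j \exp(l_j)\). Subtracting \(\max_j l_j\) before exponentiating leaves the probabilities unchanged but avoids overflow for large logits.

```python
import math


def linear_head(o, W, b):
    """One logit per class from the scalar QNN output: l_i = W_i * o + b_i."""
    return [w_i * o + b_i for w_i, b_i in zip(W, b)]


def softmax(logits):
    """p_i = exp(l_i) / sum_j exp(l_j), computed after subtracting max(l)."""
    m = max(logits)
    exps = [math.exp(l_i - m) for l_i in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical parameters of nn.Linear(1, 3): a 3-by-1 weight (flattened) and 3 biases.
W = [2.0, -1.0, 0.5]
b = [0.1, 0.0, -0.2]
o = 0.3  # hypothetical QNN output, an expectation value of Z(0) in [-1, 1]

logits = linear_head(o, W, b)
probs = softmax(logits)
print("logits:", logits)
print("probabilities:", probs)
```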

feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]

# Feature map: encode feature x_i on qubit i with an RX(x_i) rotation.
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

# One depth-1 HEA per layer, each with its own parameter prefix.
ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]

# Alternate the feature map and the HEA blocks: fm, ansatz_0, fm, ansatz_1, ...
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

# QNN returning <Z> on qubit 0, followed by a linear classification head.
qc = QuantumCircuit(n_features, blocks)
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))
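
To see what kind of number the QNN produces, consider the smallest possible case: one qubit in the state \(|0\rangle\), a single \(RX(x)\) encoding gate and no variational layer. Then \(\langle Z \rangle = \cos^2(x/2) - \sin^2(x/2) = \cos x\). The plain-Python sketch below (no Qadence; rx, apply and expect_z are hypothetical helpers) checks this numerically. It is a deliberate simplification of the 4-qubit circuit above, but it illustrates why the QNN output always lies in \([-1, 1]\), which the linear head then rescales into logits.

```python
import math


def rx(theta):
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


def apply(gate, state):
    """Multiply a 2x2 gate with a single-qubit state vector."""
    return [
        gate[0][0] * state[0] + gate[0][1] * state[1],
        gate[1][0] * state[0] + gate[1][1] * state[1],
    ]


def expect_z(state):
    """<Z> = |amplitude of |0>|^2 - |amplitude of |1>|^2."""
    return abs(state[0]) ** 2 - abs(state[1]) ** 2


zero = [1 + 0j, 0j]  # the |0> state
for x in [-1.5, 0.0, 0.5, 2.0]:  # e.g. standardized feature values
    z = expect_z(apply(rx(x), zero))
    print(f"x = {x:+.1f}   <Z> = {z:+.4f}   cos(x) = {math.cos(x):+.4f}")
```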

Below is a visualization of the QNN:


[Figure: circuit diagram of the QNN on qubits 0 to 3. The feature map RX(x₀), RX(x₁), RX(x₂), RX(x₃) and an HEA block (RX, RY, RX rotations followed by CNOT entanglers) are repeated three times, and the observable Z is measured on qubit 0.]

Training

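Next, we set up training. We use the Adam optimizer and the cross-entropy loss, wrapped in a loss function cross_entropy that takes the model and a batch (x, y) and returns the loss together with a dictionary of extra metrics (empty here). The optimizer, the loss function and a TrainConfig are passed to Qadence's Trainer, and trainer.fit runs the optimization over the dataloader: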

opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: Tensor) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    loss = criterion(out, y)
    return loss, {}

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)


res_train = trainer.fit(dataloader)
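
The loss that nn.CrossEntropyLoss computes from the raw logits can be written as \(\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N} \big(\log \sum_{j} \exp(l_{n,j}) - l_{n,y_n}\big)\), i.e. the mean negative log-softmax probability of the true class \(y_n\). The plain-Python sketch below spells this out; cross_entropy_from_logits is a hypothetical helper, the logits are made up, and the match with PyTorch assumes the loss's defaults (reduction="mean", no class weights, no label smoothing). It also shows why the model itself does not end with a softmax: the normalization already happens inside the loss.

```python
import math


def log_sum_exp(logits):
    """Numerically stable log(sum_j exp(l_j))."""
    m = max(logits)
    return m + math.log(sum(math.exp(l - m) for l in logits))


def cross_entropy_from_logits(batch_logits, targets):
    """Mean over the batch of -log softmax(l)[y] = logsumexp(l) - l[y].

    Takes raw logits, so no softmax is applied before calling it.
    """
    losses = [log_sum_exp(l) - l[y] for l, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical logits for a batch of two samples and three classes.
batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
print(cross_entropy_from_logits(batch_logits, targets))
```

With uniform logits the loss equals \(\log 3\), the value of an uninformed 3-class guess, which is a handy reference point when reading the training log.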

Inference

Finally, we apply the trained model to the test set and compute its accuracy.

X_test, y_test = dataset.X_test, dataset.y_test
preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
# Should reach higher than 0.9
print(f"Test Accuracy: {accuracy_test.item()}")
Test Accuracy: 0.9200000166893005
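
Two details of the inference step can be checked in plain Python. First, softmax preserves the ordering of the logits within each row, so the argmax of the probabilities gives the same predictions as the argmax of the raw model outputs; the torch.softmax call is only needed if you want the probabilities themselves. Second, the accuracy is the fraction of correct predictions: with the 50-sample test set, 0.92 corresponds to 46 correctly classified flowers. The helpers and logits below are hypothetical.

```python
import math


def softmax(logits):
    """Row-wise softmax with the usual max shift for stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest entry (first one on ties), like torch.argmax on a row."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(preds, labels):
    """Fraction of predictions equal to the labels."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Hypothetical logits for four test samples and their true labels.
logits = [[1.0, 0.2, -0.5], [-0.3, 0.8, 0.1], [0.0, 0.1, 2.5], [0.9, 1.1, 0.4]]
labels = [0, 1, 2, 0]

preds_from_logits = [argmax(l) for l in logits]
preds_from_probs = [argmax(softmax(l)) for l in logits]
print(preds_from_logits, preds_from_probs, accuracy(preds_from_logits, labels))
```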