
Classification with QNN

In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.

Dataset

We use the Iris dataset, split into training and test sets. The task is to classify iris plants, each described by 4 features, into 3 classes (Iris Setosa, Iris Versicolour, or Iris Virginica). Machine learning models, and neural networks in particular, usually train better on normalized data, so we apply a standard StandardScaler: each feature \(x\) is transformed to \(z = (x - u) / s\), where \(u\) and \(s\) are the mean and standard deviation of that feature over the training samples.
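As a quick check of the formula above, the sketch below compares scikit-learn's StandardScaler with the manual computation \(z = (x - u) / s\). It assumes the same split as the tutorial (test_size=0.33, random_state=42) and relies on StandardScaler using the population standard deviation (ddof=0), which is what NumPy's std computes by default. Note that the test features are scaled with the training statistics.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
X_train, X_test, _, _ = train_test_split(X, y, test_size=0.33, random_state=42)

scaler = StandardScaler().fit(X_train)

# u and s are computed per feature on the training split only
u = X_train.mean(axis=0)
s = X_train.std(axis=0)  # population standard deviation (ddof=0)

z_train_manual = (X_train - u) / s
z_test_manual = (X_test - u) / s  # the test split reuses the training statistics

assert np.allclose(scaler.transform(X_train), z_train_manual)
assert np.allclose(scaler.transform(X_test), z_test_manual)
```

After scaling, every training feature has mean 0 and standard deviation 1; the test features are only approximately standardized, since they were not used to fit the scaler.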

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer

class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler fitted on the training features is applied to both splits.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)

n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1
dataset = IrisDataset()

dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
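The split and batch sizes implied by the settings above can be checked without Qadence. The sketch below reproduces the split and the DataLoader with plain scikit-learn and PyTorch (scaling is omitted because it does not change shapes). With 150 samples and test_size=0.33, scikit-learn rounds the test size up to 50 samples, leaving 100 for training, i.e. 5 mini-batches of 20 per epoch.

```python
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# Same batching as the tutorial; scaling is left out since it does not affect shapes
train_set = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
loader = DataLoader(train_set, batch_size=20, shuffle=True)

xb, yb = next(iter(loader))
print(len(y_train), len(y_test), len(loader))  # → 100 50 5
print(xb.shape, xb.dtype)  # → torch.Size([20, 4]) torch.float64
```

The float64 dtype comes from torch.tensor preserving the dtype of the NumPy arrays returned by scikit-learn.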

Hybrid QNN

We set up the QNN part as a sequence of layers, each made of a feature map followed by a variational layer; the same feature map is repeated in every layer, so the input features are encoded several times. The variational layer is a hardware-efficient ansatz (HEA). See the qml constructors tutorial for how to customize these components. The QNN output is the expectation value of the \(Z\) observable on qubit \(0\). We then add a simple linear layer that serves as a classification head. This amounts to applying a weight matrix \(W\) and a bias vector \(b\) to the QNN output \(o\): \(l = W o + b\), where \(l\) holds one logit per class. To obtain class probabilities, we can apply the softmax function \(p_i = \exp(l_i) / \sum_{j=1}^{3} \exp(l_j)\). Softmax is not applied during training, because PyTorch's cross-entropy loss expects raw logits and applies a log-softmax internally.
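To make the head and the softmax formula concrete, here is a small PyTorch sketch. It uses random values in \([-1, 1]\) as a hypothetical stand-in for the QNN output \(o\) (an expectation value of \(Z\) lies in that range) and checks two things: that nn.Linear(1, 3) computes \(l = W o + b\), and that nn.CrossEntropyLoss on raw logits equals the negative log of the softmax probability of the true class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the QNN output: one value in [-1, 1] per sample, shape (batch, 1)
o = torch.rand(5, 1) * 2 - 1
y = torch.tensor([0, 1, 2, 1, 0])

head = nn.Linear(1, 3)  # W has shape (3, 1), b has shape (3,)
logits = head(o)
assert torch.allclose(logits, o @ head.weight.T + head.bias)  # l = W o + b

# Softmax: p_i = exp(l_i) / sum_j exp(l_j)
probs = torch.exp(logits) / torch.exp(logits).sum(dim=1, keepdim=True)
assert torch.allclose(probs, torch.softmax(logits, dim=1))

# CrossEntropyLoss takes raw logits and applies log-softmax internally
ce = nn.CrossEntropyLoss()(logits, y)
manual = -torch.log(probs[torch.arange(len(y)), y]).mean()
assert torch.allclose(ce, manual)
```

Because \(o\) is a single number, each logit is an affine function of it, so the head can only split the range of \(o\) into intervals, one per predicted class. The QNN therefore has to map the three species to different ranges of the expectation value.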

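# Feature map: encode feature i as the angle of an RX rotation on qubit i
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

# One depth-1 HEA per layer, each with its own parameter prefix
ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]
# Alternate feature map and ansatz: fm, HEA_0, fm, HEA_1, fm, HEA_2
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
# The QNN returns <Z> on qubit 0, one value per sample
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
# Linear classification head mapping the scalar QNN output to 3 class logits
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))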
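To build some intuition for the RX feature map, the NumPy sketch below simulates a single qubit, independently of Qadence. Assuming the standard convention \(RX(\theta) = \exp(-i \theta X / 2)\), applying \(RX(x)\) to \(|0\rangle\) and measuring \(Z\) gives \(\langle Z \rangle = \cos^2(x/2) - \sin^2(x/2) = \cos(x)\), so each scaled feature enters the circuit through trigonometric functions of its value.

```python
import numpy as np


def rx(theta: float) -> np.ndarray:
    """Single-qubit rotation RX(theta) = exp(-i * theta * X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])


Z = np.diag([1.0, -1.0])
ket0 = np.array([1.0, 0.0], dtype=complex)

xs = np.linspace(-2.0, 2.0, 9)
expvals = []
for x in xs:
    psi = rx(x) @ ket0  # encode the feature value as a rotation angle
    expvals.append(np.real(np.conj(psi) @ Z @ psi))
expvals = np.array(expvals)

# For a lone qubit, the RX encoding yields <Z> = cos(x)
assert np.allclose(expvals, np.cos(xs))
```

In the full circuit the variational HEA blocks sit between the encodings, and repeating the same feature map in every layer, as the code above does, is commonly called data re-uploading: it lets the model combine several such trigonometric terms of each feature.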

Below is a visualization of the QNN:


[Figure: QNN circuit on qubits 0 to 3. Three repetitions of an RX(xᵢ) feature map on qubit i, each followed by an HEA block (RX, RY, RX rotations with parameters theta per qubit, then CNOT entangling gates), with the Z observable measured on qubit 0.]

Training

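Next, we set up the training: an Adam optimizer, the cross-entropy loss wrapped in a loss function for the Trainer, and the training configuration: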

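opt = torch.optim.Adam(model.parameters(), lr=lr)
# CrossEntropyLoss takes raw logits and integer class labels
criterion = nn.CrossEntropyLoss()

# Loss function passed to the Trainer: takes the model and a batch, returns (loss, metrics dict)
def cross_entropy(model: nn.Module, data: tuple[Tensor, Tensor]) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    loss = criterion(out, y)
    return loss, {}

# Train for max_iter=n_epochs iterations, logging every 10 iterations
train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)  # gradient-based optimization
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)

res_train = trainer.fit(dataloader)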
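For readers who want to see what a gradient-based training loop involves, here is a minimal plain-PyTorch loop that calls a loss function with the same (model, data) to (loss, metrics) contract as cross_entropy above. It is a sketch, not a description of the Trainer internals, and so that it runs without Qadence it uses hypothetical synthetic data and a small classical network in place of the QNN (a linear layer squeezed to one tanh output, mimicking the single bounded expectation value, followed by the same Linear(1, 3) head).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic 3-class data standing in for the scaled Iris features
X = torch.randn(100, 4)
y = torch.argmax(X[:, :3], dim=1)  # label = index of the largest of the first three features

# Classical stand-in for nn.Sequential(qnn, nn.Linear(1, 3))
model = nn.Sequential(nn.Linear(4, 1), nn.Tanh(), nn.Linear(1, 3))
opt = torch.optim.Adam(model.parameters(), lr=1e-1)
criterion = nn.CrossEntropyLoss()


def cross_entropy(model: nn.Module, data: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, dict]:
    x, y = data
    return criterion(model(x), y), {}


def full_loss() -> float:
    """Loss over the whole synthetic dataset, without tracking gradients."""
    with torch.no_grad():
        return cross_entropy(model, (X, y))[0].item()


loader = DataLoader(TensorDataset(X, y), batch_size=20, shuffle=True)

initial = full_loss()
for epoch in range(50):
    for batch in loader:
        opt.zero_grad()  # clear gradients from the previous step
        loss, _ = cross_entropy(model, batch)
        loss.backward()  # backpropagate through head and stand-in model
        opt.step()
final = full_loss()
print(f"loss: {initial:.3f} -> {final:.3f}")
```

With the real model, the QNN parameters receive gradients in the same way, since the QNN sits inside nn.Sequential and is trained jointly with the linear head.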

Inference

Finally, we evaluate the model on the held-out test set and compute its accuracy.

X_test, y_test = dataset.X_test, dataset.y_test
with torch.no_grad():
    preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
print(f"Test Accuracy: {accuracy_test.item()}")  # should reach higher than 0.9

Test Accuracy: 0.9200000166893005
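Applying softmax before argmax is convenient for reading probabilities, but it does not change the prediction: softmax is strictly increasing within each row, so the largest logit is also the largest probability. The sketch below checks this on random logits, which are hypothetical stand-ins for model(X_test) and y_test, and computes accuracy the same way as above.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(50, 3)          # stand-in for model(X_test)
y_true = torch.randint(0, 3, (50,))  # stand-in for y_test

preds_softmax = torch.argmax(torch.softmax(logits, dim=1), dim=1)
preds_logits = torch.argmax(logits, dim=1)

# Softmax preserves the ordering of each row, so predictions are identical
assert torch.equal(preds_softmax, preds_logits)

accuracy = (preds_logits == y_true).float().mean().item()
print(f"accuracy: {accuracy:.2f}")
```

With the 50-sample test split used here, an accuracy of 0.92 corresponds to 46 correct predictions.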