
Classification with QNN

In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.

Dataset

We will use the Iris dataset, split into training and test sets. The task is to classify iris plants, each described by 4 features, into 3 classes (Iris Setosa, Iris Versicolour or Iris Virginica). When applying machine learning models, and neural networks in particular, it is recommended to normalize the data. We therefore use scikit-learn's StandardScaler, which transforms each feature \(x\) into \(z = (x - u) / s\), where \(u\) and \(s\) are respectively the mean and standard deviation of that feature over the training samples.
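To make the normalization step concrete, here is a minimal pure-Python sketch of what standard scaling does, assuming small hypothetical toy data rather than Iris itself. The key point is that \(u\) and \(s\) are computed on the training set only and then reused for the test set. The helper names `fit_scaler` and `transform` are illustrative, not part of scikit-learn; the use of the population standard deviation (ddof=0) matches what StandardScaler does.

```python
import statistics


def fit_scaler(columns):
    """Return (mean, std) per feature, computed on training data only.

    Uses the population standard deviation (ddof=0), as StandardScaler does.
    This sketch does not handle constant features (std == 0).
    """
    return [(statistics.fmean(col), statistics.pstdev(col)) for col in columns]


def transform(rows, stats):
    """Apply z = (x - u) / s feature-wise to each row."""
    return [[(x - u) / s for x, (u, s) in zip(row, stats)] for row in rows]


# Hypothetical toy data: 2 features, 3 training rows and 1 test row.
X_train = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
X_test = [[4.0, 40.0]]

stats = fit_scaler(list(zip(*X_train)))  # transpose rows -> columns
Z_train = transform(X_train, stats)
Z_test = transform(X_test, stats)  # the test set reuses the training mean/std

print(Z_train)
print(Z_test)
```

After scaling, each training column has mean 0 and standard deviation 1, while the test values are expressed in units of the training statistics (here the test row lands at about 2.45 standard deviations above the training mean).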

import random

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer

class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is fitted on the training set and applied to both splits.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)

n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1
dataset = IrisDataset()

dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
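As a quick sanity check on the sizes involved, the sketch below reproduces the split and batch arithmetic in plain Python. It assumes scikit-learn's documented behavior for a float `test_size` (the test count is rounded up and the remaining samples go to training) and PyTorch's default `drop_last=False` for `DataLoader`; the Iris dataset has 150 samples.

```python
import math

n_samples = 150  # the Iris dataset has 150 samples
test_size = 0.33  # as passed to train_test_split above
batch_size = 20  # as passed to DataLoader above

# For a float test_size, scikit-learn rounds the test count up
# and assigns the remaining samples to the training set.
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test

# With the default drop_last=False, an incomplete final batch is still
# yielded, so the number of batches per epoch is rounded up.
batches_per_epoch = math.ceil(n_train / batch_size)

print(n_train, n_test, batches_per_epoch)  # 100 50 5
```

So each pass over the dataloader yields 5 shuffled batches of 20 training samples, and the model is evaluated on 50 held-out samples.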

Hybrid QNN

We set up the QNN part, composed of multiple feature map layers, each followed by a variational layer. The variational layer is a hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components. The QNN output is the expectation value of a \(Z\) observable on qubit \(0\). We then add a simple linear layer serving as a classification head. This is equivalent to applying a weight matrix \(W\) and a bias vector \(b\) to the QNN output \(o\): \(l = W o + b\). To obtain class probabilities, we can apply the softmax function, defined as \(p_i = \exp(l_i) / \sum_{j=1}^3 \exp(l_j)\). Note that softmax is not applied explicitly during training: the cross-entropy loss takes the logits \(l\) directly and applies the (log-)softmax internally.
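The classification head and the softmax can be sketched in a few lines of plain Python. This is an illustration of the formulas \(l = W o + b\) and \(p_i = \exp(l_i) / \sum_j \exp(l_j)\), not the PyTorch implementation; the values of `o`, `W` and `b` are hypothetical. Subtracting \(\max_j l_j\) before exponentiating is a standard numerical-stability trick that leaves the result unchanged.

```python
import math


def linear_head(o, W, b):
    """Map the scalar QNN output o to logits l = W o + b (W has shape 3x1)."""
    return [w_i[0] * o + b_i for w_i, b_i in zip(W, b)]


def softmax(logits):
    """p_i = exp(l_i) / sum_j exp(l_j), shifted by max(l) for numerical stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical values: o is an expectation value of Z, so it lies in [-1, 1].
o = 0.3
W = [[2.0], [-1.0], [0.5]]
b = [0.1, 0.2, -0.3]

logits = linear_head(o, W, b)  # approximately [0.7, -0.1, -0.15]
probs = softmax(logits)
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(predicted_class)  # 0
```

Because softmax is monotonic, the predicted class (the argmax of the probabilities) is always the argmax of the logits; softmax is only needed when actual probabilities are required.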

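# Feature map: encode feature x_i as the angle of an RX rotation on qubit i.
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

# One depth-1 HEA per layer, each with its own parameter prefix.
ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]
# Interleave the blocks: (feature map -> ansatz), repeated n_layers times.
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
# The QNN returns <Z> on qubit 0: one scalar per sample.
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
# Classical head: map the scalar QNN output to 3 class logits.
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))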
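The structure of this model can be summarized symbolically. The pure-Python sketch below lays out the sequence of blocks and counts trainable parameters, assuming a depth-1 HEA applies an RX, RY, RX rotation on every qubit (as in the circuit visualization) and ignoring the entangling gates, which carry no parameters. The same encoding is re-applied before every ansatz block (a pattern often called data re-uploading), while each HEA gets fresh angles thanks to its own `param_prefix`. The parameter names below are illustrative only; Qadence's internal naming may differ.

```python
n_features = 4  # one qubit per feature
n_layers = 3
rotations_per_qubit = ["RX", "RY", "RX"]  # assumed per-qubit rotations of a depth-1 HEA

layout = []
for layer in range(n_layers):
    # Feature map: the SAME inputs x_0..x_3 are re-encoded in every layer.
    layout.append(("feature_map", [f"RX(x_{q}) on q{q}" for q in range(n_features)]))
    # Variational block: fresh trainable angles for each layer (hypothetical names).
    thetas = [f"theta_{layer}_{k}" for k in range(len(rotations_per_qubit) * n_features)]
    layout.append(("hea", thetas))

n_quantum_params = sum(len(params) for kind, params in layout if kind == "hea")
n_head_params = 1 * 3 + 3  # nn.Linear(1, 3): 3 weights + 3 biases

print(n_quantum_params, n_head_params)  # 36 6
```

Under these assumptions the hybrid model has 36 variational circuit angles plus 6 classical parameters in the linear head.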

Below is a visualization of the QNN:


[Figure: QNN circuit on 4 qubits. Each of the three layers applies the feature map RX(xᵢ) on qubit i, followed by an HEA block (RX, RY, RX rotations on every qubit, then CNOT entangling gates); the circuit ends with the observable Z on qubit 0.]

Training

Next, we set up the optimizer, the loss function and the trainer:

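opt = torch.optim.Adam(model.parameters(), lr=lr)
# CrossEntropyLoss expects raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: Tensor) -> tuple[Tensor, dict]:
    """Loss function in the form used by Trainer: returns (loss, metrics)."""
    x, y = data
    out = model(x)  # logits of shape (batch_size, 3)
    loss = criterion(out, y)
    return loss, {}  # no extra metrics are logged

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)  # use gradient-based optimization
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)
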
res_train = trainer.fit(dataloader)
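To see why no explicit softmax is needed during training, here is a pure-Python sketch of cross-entropy computed directly from logits: for each sample it takes \(-\log p_y\) with \(\log p_i = l_i - \log \sum_j \exp(l_j)\), then averages over the batch, mirroring the default `reduction="mean"` of `nn.CrossEntropyLoss`. The batch values are hypothetical, and this is an illustration of the math rather than PyTorch's implementation.

```python
import math


def log_softmax(logits):
    """log p_i = l_i - log(sum_j exp(l_j)), computed with the max-shift trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_sum_exp for l in logits]


def cross_entropy_from_logits(batch_logits, targets):
    """Mean over the batch of -log p_y; the softmax is applied internally."""
    losses = [-log_softmax(l)[y] for l, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical batch of 2 samples with 3 logits each, plus integer labels.
batch_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
loss = cross_entropy_from_logits(batch_logits, targets)
print(loss)
```

The second sample has uniform logits, so its contribution is \(\log 3\): a model that cannot distinguish the 3 classes starts from a loss of about 1.099.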

Inference

Finally, we apply the trained model to the test set and compute its accuracy.

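X_test, y_test = dataset.X_test, dataset.y_test
with torch.no_grad():  # no gradients are needed at inference time
    # Softmax does not change the argmax; it is kept to expose probabilities.
    preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
# Should reach higher than 0.9
print(f"Test Accuracy: {accuracy_test.item()}")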
Test Accuracy: 0.9200000166893005
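The reported number can be interpreted with a little arithmetic. With `test_size=0.33` on the 150 Iris samples, scikit-learn places 50 samples in the test set, so an accuracy of 0.92 corresponds to 46 correct predictions. The trailing digits come from float32 precision, since the accuracy is computed after `.type(torch.float32)`. The sketch below uses hypothetical predictions and the standard-library `struct` module to round 0.92 to float32 and back, which reproduces the printed value.

```python
import struct


def accuracy(preds, labels):
    """Fraction of predictions equal to the true labels."""
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


# Hypothetical test set of 50 samples with 46 correct predictions.
labels = [0] * 50
preds = [0] * 46 + [1] * 4
acc = accuracy(preds, labels)

# Round-trip through a 4-byte IEEE float to mimic the float32 result.
acc_float32 = struct.unpack("f", struct.pack("f", acc))[0]
print(acc, acc_float32)  # 0.92 0.9200000166893005
```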