Classification with QNN

In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.

Dataset

We use the Iris dataset, split into training and test sets. The task is to classify iris plants, each described by 4 features, into 3 classes (Iris Setosa, Iris Versicolour, or Iris Virginica). When applying machine learning models, and neural networks in particular, it is recommended to normalize the data. We therefore use scikit-learn's StandardScaler, which transforms each feature \(x\) into \(z = (x - u) / s\), where \(u\) and \(s\) are the mean and standard deviation of the training samples. The scaler is fitted on the training set only and then reused to transform the test set.
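As a plain-Python sketch of the transform above, the statistics are computed on the training rows only and then reused for the test rows. The helpers fit_standard_scaler and transform, and the toy numbers, are hypothetical and not part of scikit-learn. The sketch uses the population standard deviation (ddof=0), as StandardScaler does, but unlike StandardScaler it does not guard against constant features with zero standard deviation.

```python
import statistics


def fit_standard_scaler(columns):
    """Per-feature (mean, std) computed from the *training* columns only.

    Uses the population standard deviation (ddof=0), as StandardScaler does.
    Hypothetical helper: no guard for constant features (std == 0).
    """
    return [(statistics.fmean(col), statistics.pstdev(col)) for col in columns]


def transform(rows, stats):
    """Apply z = (x - u) / s to every feature of every row."""
    return [[(x - u) / s for x, (u, s) in zip(row, stats)] for row in rows]


# Hypothetical toy data: 4 training rows and 1 test row, 2 features each.
train_rows = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
test_rows = [[5.0, 50.0]]

stats = fit_standard_scaler(zip(*train_rows))  # transpose rows -> columns
z_train = transform(train_rows, stats)
z_test = transform(test_rows, stats)  # reuse the training statistics
print(z_test)
```

After the transform, each training column has mean 0 and standard deviation 1, while the test row maps to about 2.236 (that is, \(\sqrt{5}\)) in both features because it is scaled with the training statistics, not its own.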

import random

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer

class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is applied prior to applying models.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)

n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1
dataset = IrisDataset()

dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
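The split and batch sizes follow from simple arithmetic. The sketch below assumes scikit-learn's rule for a float test_size (the test set gets \(\lceil 0.33 \times 150 \rceil\) of the 150 Iris samples and the remainder goes to training) and the DataLoader default drop_last=False; split_sizes is a hypothetical helper.

```python
import math


def split_sizes(n_samples, test_size):
    """Train/test sizes for a float test_size.

    Assumes scikit-learn's rule: ceil(test_size * n_samples) test samples,
    the rest for training.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


n_train, n_test = split_sizes(150, 0.33)  # Iris has 150 samples
# batch_size=20 with drop_last=False keeps a final partial batch, if any
batches_per_epoch = math.ceil(n_train / 20)
print(n_train, n_test, batches_per_epoch)  # 100 50 5
```

So each pass over the dataloader yields 5 batches of 20 training samples, and the test set used at inference time holds 50 samples.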

Hybrid QNN

We set up the QNN by stacking n_layers repetitions of the same feature map, each followed by its own variational layer, so the input features are re-encoded before every trainable block. The feature map encodes feature \(x_i\) as an \(RX(x_i)\) rotation on qubit \(i\), and the variational layer is a hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components. The QNN output \(o\) is the expectation value of the \(Z\) observable on qubit \(0\).

We then add a simple linear layer that serves as the classification head. This is equivalent to applying a weight matrix \(W \in \mathbb{R}^{3 \times 1}\) and a bias vector \(b \in \mathbb{R}^{3}\) to the QNN output: \(l = W o + b\). To obtain class probabilities, we can apply the softmax function \(p_i = \exp(l_i) / \sum_{j=1}^{3} \exp(l_j)\). Softmax is not applied during training, because PyTorch's CrossEntropyLoss expects the raw logits \(l\) and applies log-softmax internally.
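To make the classification head concrete, here is a plain-Python sketch of \(l = W o + b\) followed by the softmax for a single sample. The weights, bias and QNN output are made-up numbers, and linear_head and softmax are hypothetical helpers. Subtracting \(\max_j l_j\) before exponentiating is a standard stability trick that leaves \(p_i\) unchanged, because the factor \(e^{-\max_j l_j}\) cancels between numerator and denominator.

```python
import math


def linear_head(o, W, b):
    """l = W o + b for a scalar QNN output o, W of shape (3, 1), b of shape (3,)."""
    return [w[0] * o + b_i for w, b_i in zip(W, b)]


def softmax(logits):
    """p_i = exp(l_i) / sum_j exp(l_j), shifted by max(l) for numerical stability."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical parameters and QNN output, chosen only for illustration.
W = [[1.0], [-2.0], [0.5]]
b = [0.1, 0.0, -0.3]
o = 0.4  # <Z> on qubit 0 always lies in [-1, 1]

logits = linear_head(o, W, b)
probs = softmax(logits)
print(logits, probs)
```

Because softmax is strictly increasing in each logit, the predicted class is the same whether argmax is taken over probs or directly over logits.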

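# One encoding parameter per input feature: x_0, ..., x_3
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]
# Feature map: RX(x_i) applied in parallel on qubit i
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

# One HEA block per layer, each with its own trainable parameters
ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]
# Repeat n_layers times: feature map followed by a variational layer
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
# Output: expectation value of Z on qubit 0, one scalar per sample
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
# Classification head: 1 QNN output -> 3 class logits
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))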
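The idea behind repeating the feature map before each variational layer (often called data re-uploading) can be explored on a single qubit without Qadence. The sketch below is a toy stand-in, not Qadence's HEA: one qubit, an \(RX(x)\) encoding followed by trainable \(RX\), \(RY\), \(RX\) rotations per layer, no entangling gates, and \(\langle Z \rangle\) as the output. All function names are hypothetical.

```python
import math


def rx(theta):
    """RX(theta) = [[cos(t/2), -i sin(t/2)], [-i sin(t/2), cos(t/2)]]."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


def ry(theta):
    """RY(theta) = [[cos(t/2), -sin(t/2)], [sin(t/2), cos(t/2)]]."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


def apply(gate, state):
    """Multiply a 2x2 gate with a 2-component state vector."""
    return [
        gate[0][0] * state[0] + gate[0][1] * state[1],
        gate[1][0] * state[0] + gate[1][1] * state[1],
    ]


def expectation_z(state):
    """<psi|Z|psi> = |a|^2 - |b|^2 for psi = (a, b)."""
    return abs(state[0]) ** 2 - abs(state[1]) ** 2


def reuploading_model(x, thetas):
    """Toy single-qubit data re-uploading circuit.

    thetas holds one (a, b, c) angle triple per layer. Each layer re-encodes x
    with RX(x), then applies trainable RX(a), RY(b), RX(c).
    """
    state = [1 + 0j, 0j]  # start in |0>
    for a, b, c in thetas:
        state = apply(rx(x), state)  # feature map (same x every layer)
        state = apply(rx(a), state)  # variational rotations
        state = apply(ry(b), state)
        state = apply(rx(c), state)
    return expectation_z(state)


x = 0.7
print(reuploading_model(x, [(0.0, 0.0, 0.0)] * 3))  # approximately math.cos(3 * x)
```

With all trainable angles at zero, the three encodings compose into \(RX(3x)\), so the output is \(\cos(3x)\). A single \(RX(x)\) encoding gives only \(\cos x\), so repeating the encoding lets the model access higher frequencies of the input.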

Below is a visualization of the QNN:


[Figure: circuit diagram of the QNN on qubits 0 to 3. The circuit repeats three times an RX(xᵢ) feature-map column (one gate per qubit i) followed by an HEA block with RX(θ), RY(θ), RX(θ) rotations on each qubit and CNOT entanglers; a final "Obs." block measures Z on qubit 0.]
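Counting the trainable parameters is a useful sanity check. Reading the gate labels off the diagram, each HEA layer carries 12 angles (an \(RX\), \(RY\), \(RX\) triple on each of the 4 qubits), while the CNOTs and the \(RX(x_i)\) feature-map gates have no trainable parameters; the linear head nn.Linear(1, 3) adds 3 weights and 3 biases. count_trainable is a hypothetical helper that encodes these assumptions.

```python
def count_trainable(n_qubits, n_layers, rotations_per_qubit, n_in, n_out):
    """Trainable parameter count of the hybrid model.

    Assumes each depth-1 HEA layer has `rotations_per_qubit` trainable angles
    per qubit, and that feature-map gates and CNOTs carry no trainable angles.
    """
    quantum = n_layers * n_qubits * rotations_per_qubit
    classical = n_in * n_out + n_out  # nn.Linear weight matrix + bias vector
    return quantum, classical, quantum + classical


quantum, classical, total = count_trainable(
    n_qubits=4, n_layers=3, rotations_per_qubit=3, n_in=1, n_out=3
)
print(quantum, classical, total)  # 36 6 42
```

Since the tutorial passes model.parameters() to the optimizer, we would expect sum(p.numel() for p in model.parameters()) to return the same total, though this sketch does not verify how Qadence registers its parameters internally.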

Training

Next, we set up the optimizer, the loss function and the trainer:

opt = torch.optim.Adam(model.parameters(), lr=lr)
# CrossEntropyLoss takes raw logits and applies log-softmax internally
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: Tensor) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    # Class-index targets must be int64 (long) for CrossEntropyLoss
    loss = criterion(out, y.long())
    return loss, {}  # the loss and an (empty) dictionary of additional metrics

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)

res_train = trainer.fit(dataloader)
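The point that softmax is not applied during training can be checked by hand. For a class-index target, nn.CrossEntropyLoss takes raw logits and applies log-softmax itself, so the per-sample loss is \(-\log p_y = \log \sum_j e^{l_j} - l_y\), and the default reduction averages over the batch. The plain-Python sketch below (hypothetical helper names) computes that quantity; feeding softmax outputs into nn.CrossEntropyLoss instead would apply softmax twice.

```python
import math


def cross_entropy_from_logits(logits, target):
    """-log softmax(logits)[target], computed as logsumexp(logits) - logits[target]."""
    m = max(logits)  # shift for numerical stability, cancels in the result
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, targets):
    """Batch average, matching nn.CrossEntropyLoss's default reduction='mean'."""
    losses = [cross_entropy_from_logits(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical logits for a batch of two samples with 3 classes.
batch_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 1]
print(mean_cross_entropy(batch_logits, targets))
```

With uniform logits such as [0.0, 0.0, 0.0], every class has probability 1/3 and the loss is \(\log 3 \approx 1.0986\), the value expected from a 3-class head that has not yet learned to separate the classes.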

Inference

Finally, we apply the trained model to the test set and compute its accuracy.

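X_test, y_test = dataset.X_test, dataset.y_test
with torch.no_grad():
    # Softmax is monotonic, so argmax over probabilities equals argmax over logits
    preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
# Should reach higher than 0.9
print(f"Test Accuracy: {accuracy_test.item()}")

Test Accuracy: 0.9200000166893005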
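The test accuracy is the fraction of test samples whose predicted class matches the label. With test_size=0.33 on the 150 Iris samples, the test set holds 50 samples, so accuracy moves in steps of \(1/50 = 0.02\): the printed 0.92 corresponds to 46 correct predictions, and the trailing digits come from the float32 representation of 0.92. A plain-Python sketch of the same computation, with a hypothetical accuracy helper and made-up predictions:

```python
def accuracy(preds, labels):
    """Fraction of samples whose predicted class index equals the true label."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    n_correct = sum(p == y for p, y in zip(preds, labels))
    return n_correct / len(labels)


# Hypothetical predictions: 46 of 50 test samples classified correctly.
labels = [0] * 17 + [1] * 17 + [2] * 16
preds = list(labels)
for i in (3, 20, 21, 40):  # flip four predictions to a wrong class
    preds[i] = (preds[i] + 1) % 3

print(accuracy(preds, labels))  # 0.92
```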