Classification with QNN
In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.
Dataset
We will use the Iris dataset separated into training and testing sets.
The task is to classify iris plants presented as a multivariate dataset of 4 features into 3 labels (Iris Setosa, Iris Versicolour, or Iris Virginica).
When applying machine learning models, and particularly neural networks, it is recommended to normalize the data. As such, we use a common StandardScaler (we transform the data \(x\) to \(z = (x - u) / s\) where \(u, s\) are respectively the mean and standard deviation of the training samples).
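The standardization step can be sketched in plain NumPy (a hypothetical `standardize` helper, purely illustrative of what `StandardScaler` computes): the mean and standard deviation are fitted on the training split only and then applied to both splits.

```python
import numpy as np

def standardize(train: np.ndarray, test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Fit u (mean) and s (std) on the training set only, then apply z = (x - u) / s
    # to both splits, mirroring scaler.fit(X_train) followed by scaler.transform(...).
    u = train.mean(axis=0)
    s = train.std(axis=0)
    return (train - u) / s, (test - u) / s

train = np.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
test = np.array([[3.0, 20.0]])
z_train, z_test = standardize(train, test)
print(z_train.mean(axis=0))  # ~[0, 0]
print(z_train.std(axis=0))   # ~[1, 1]
```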
```python
import random

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer


class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is applied prior to applying models.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)


n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1

dataset = IrisDataset()
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
```
Hybrid QNN
We set up the QNN part composed of multiple feature map layers, each followed by a variational layer.
The type of variational layer we use is the hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components.
The output will be the expectation value of a \(Z\) observable on qubit \(0\).
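As a sanity check (not part of the tutorial code), for a single qubit the expectation of \(Z\) after rotating \(|0\rangle\) by \(RX(\theta)\) is \(\cos\theta\), so such an output is always bounded in \([-1, 1]\):

```python
import numpy as np

def rx(theta: float) -> np.ndarray:
    # Single-qubit RX rotation matrix.
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])

Z = np.diag([1.0, -1.0])
theta = 0.7
state = rx(theta) @ np.array([1.0, 0.0])   # RX(theta)|0>
exp_z = np.real(state.conj() @ Z @ state)  # <psi|Z|psi> = cos(theta)
print(exp_z, np.cos(theta))
```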
Then we add a simple linear layer serving as a classification head. This is equivalent to applying a weight matrix \(W\) and a bias vector \(b\) to the output of the QNN, denoted \(o\): \(l = W o + b\). To obtain probabilities, we can apply the softmax function, defined as \(p_i = \exp(l_i) / \sum_{j=1}^3 \exp(l_j)\).
Note that softmax is not applied during training, as it is already included in the cross-entropy loss.
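This can be verified directly with PyTorch: `nn.CrossEntropyLoss` applies log-softmax internally, so the model should output raw logits during training and softmax is only needed when reporting probabilities.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw linear-head output l = W o + b
target = torch.tensor([0])

# CrossEntropyLoss on logits == NLLLoss on the log-softmax of the logits.
ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)
print(ce.item(), nll.item())  # identical values

probs = torch.softmax(logits, dim=1)  # p_i = exp(l_i) / sum_j exp(l_j)
print(probs.sum().item())  # 1.0
```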
```python
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))
```
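The head's shape flow can be checked with a stand-in for the QNN (a plain `nn.Linear(4, 1)` here, purely illustrative): the quantum part maps a batch of 4 features to a single expectation value, and the linear head expands it to 3 logits, one per class.

```python
import torch
import torch.nn as nn

# Stand-in for the QNN: any module mapping (batch, 4) -> (batch, 1) has the
# same interface as the expectation value of Z on qubit 0.
fake_qnn = nn.Linear(4, 1)
head = nn.Linear(1, 3)
model_sketch = nn.Sequential(fake_qnn, head)

x = torch.randn(20, 4)  # one batch of scaled features
logits = model_sketch(x)
print(logits.shape)     # torch.Size([20, 3])
```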
Below is a visualization of the QNN:
[Circuit diagram: 4 qubits, each encoded with an \(RX(x_i)\) feature rotation followed by a depth-1 HEA block (\(RX\), \(RY\), \(RX\) rotations plus CNOT entanglers), repeated 3 times; a \(Z\) observable is measured on qubit 0.]
Training
Then we can set up the training part:
```python
opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: tuple[Tensor, Tensor]) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    loss = criterion(out, y)
    return loss, {}

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)

Trainer.set_use_grad(True)
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)
res_train = trainer.fit(dataloader)
```
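For intuition, `trainer.fit` is broadly equivalent to a plain PyTorch optimization loop like the sketch below (shown with a toy stand-in model and random data so it runs without a quantum backend; the tutorial itself trains the hybrid `model` on the Iris loader).

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in data and model; in the tutorial these are IrisDataset and the hybrid model.
X = torch.randn(60, 4)
y = torch.randint(0, 3, (60,))
loader = DataLoader(TensorDataset(X, y), batch_size=20, shuffle=True)

toy_model = nn.Sequential(nn.Linear(4, 1), nn.Linear(1, 3))
opt = torch.optim.Adam(toy_model.parameters(), lr=1e-1)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):  # the tutorial runs 1000 iterations
    for xb, yb in loader:
        opt.zero_grad()
        loss = criterion(toy_model(xb), yb)  # same loss_fn logic as cross_entropy above
        loss.backward()
        opt.step()
print(loss.item())  # final batch loss
```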
Inference
Finally, we can apply our model on the test set and check the score.
```python
X_test, y_test = dataset.X_test, dataset.y_test
preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
# Should reach higher than 0.9
```

Test Accuracy: 0.9200000166893005
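Since softmax is strictly monotone within each row, taking `argmax` over the probabilities gives the same predictions as taking it over the raw logits, so the softmax call above is optional when computing accuracy. A small check:

```python
import torch

logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 0.2, 1.5]])
labels = torch.tensor([0, 2])

pred_from_logits = torch.argmax(logits, dim=1)
pred_from_probs = torch.argmax(torch.softmax(logits, dim=1), dim=1)
print(torch.equal(pred_from_logits, pred_from_probs))  # True

accuracy = (pred_from_logits == labels).type(torch.float32).mean()
print(accuracy.item())  # 1.0
```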