Classification with QNN
In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.
Dataset
We will use the Iris dataset separated into training and testing sets.
The task is to classify iris plants, described by a multivariate dataset of 4 features, into one of 3 classes (Iris Setosa, Iris Versicolour, or Iris Virginica).
When applying machine learning models, and particularly neural networks, it is recommended to normalize the data. As such, we use a common StandardScaler (we transform the data \(x\) to \(z = (x - u) / s\) where \(u, s\) are respectively the mean and standard deviation of the training samples).
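As a quick illustration of that formula (a toy sketch with made-up numbers, separate from the tutorial's pipeline), the scaler's output matches the per-feature standardization computed by hand:
import numpy as np
from sklearn.preprocessing import StandardScaler

x_train = np.array([[4.9, 3.0], [5.1, 3.5], [6.2, 2.9]])  # toy 3-sample, 2-feature matrix
scaler = StandardScaler().fit(x_train)

u, s = x_train.mean(axis=0), x_train.std(axis=0)  # per-feature mean and standard deviation
z_manual = (x_train - u) / s                      # z = (x - u) / s
z_scaler = scaler.transform(x_train)

assert np.allclose(z_manual, z_scaler)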
import random
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer
class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is applied prior to applying models.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)
n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1

dataset = IrisDataset()
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
Hybrid QNN
We set up the QNN part composed of multiple feature map layers, each followed by a variational layer.
The type of variational layer we use is the hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components.
The output will be the expectation value with respect to a \(Z\) observable on qubit \(0\).
Then we add a simple linear layer serving as a classification head. This is equivalent to applying a weight matrix \(W\) and a bias vector \(b\) to the QNN output \(o\): \(l = W o + b\). To obtain probabilities, we can apply the softmax function defined as \(p_i = \exp(l_i) / \sum_{j=1}^3 \exp(l_j)\).
Note that softmax is not applied during training, since the cross-entropy loss operates directly on the raw logits.
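As a small standalone illustration of this point (dummy logits, plain PyTorch only), nn.CrossEntropyLoss consumes the raw outputs \(l\) directly, while an explicit softmax is only needed when we want probabilities:
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5]])     # raw head output l for one sample
target = torch.tensor([0])                    # true class index

loss = nn.CrossEntropyLoss()(logits, target)  # softmax is applied internally to the logits
probs = torch.softmax(logits, dim=1)          # explicit softmax, only needed for probabilities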
# One feature parameter per input feature
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]

# Feature map: one RX rotation per qubit, each encoding one feature
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

# One HEA variational layer per repetition, with distinct parameter prefixes
ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]

# Alternate feature map and variational layers
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))
Below is a visualization of the QNN:
[Circuit diagram: three repetitions of the RX feature map (RX(x₀)–RX(x₃), one rotation per qubit), each followed by an HEA block of RX, RY, RX rotations and CNOT entanglers on the 4 qubits, with the Z observable measured on qubit 0.]
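As a quick sanity check (an optional sketch reusing the dataset and model defined above), a forward pass on a few samples should return one row of 3 logits per sample:
sample_batch = dataset.X[:5]   # 5 standardized training samples
logits = model(sample_batch)   # QNN expectation values mapped to 3 logits
print(logits.shape)            # expected: torch.Size([5, 3])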
Training
Then we can set up the training part:
opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: Tensor) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    loss = criterion(out, y)
    return loss, {}

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)
res_train = trainer.fit(dataloader)
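For intuition, the Trainer is conceptually performing something like the following plain-PyTorch loop (a simplified sketch, not Qadence's actual Trainer internals, which also handle logging, checkpointing and configuration):
# Simplified gradient-descent loop, for illustration only
for _ in range(n_epochs):
    for x, y in dataloader:
        opt.zero_grad()
        loss, _ = cross_entropy(model, (x, y))
        loss.backward()
        opt.step()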
Inference
Finally, we can apply our model to the test set and check its accuracy.
X_test, y_test = dataset.X_test, dataset.y_test
preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
## Should reach higher than 0.9
print(f"Test Accuracy: {accuracy_test.item()}")
Test Accuracy: 0.9200000166893005
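To classify a new, unseen measurement (a hypothetical example reusing the scaler fitted inside IrisDataset), the same standardization must be applied before calling the model:
# Hypothetical flower: sepal length/width and petal length/width in cm
new_sample = torch.tensor(dataset.scaler.transform([[5.1, 3.5, 1.4, 0.2]]))
probabilities = torch.softmax(model(new_sample), dim=1)
predicted_class = torch.argmax(probabilities, dim=1)  # 0, 1 or 2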