Classification with QNN
In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.
Dataset
We will use the Iris dataset separated into training and testing sets.
The task is to classify iris plants, presented as a multivariate dataset of 4 features, into one of 3 classes (Iris Setosa, Iris Versicolour, or Iris Virginica).
When applying machine learning models, and particularly neural networks, it is recommended to normalize the data. As such, we use a common StandardScaler (we transform the data \(x\) to \(z = (x - u) / s\) where \(u, s\) are respectively the mean and standard deviation of the training samples).
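As a quick check of the formula above, here is a minimal, self-contained sketch on toy data (not part of the tutorial pipeline) showing that StandardScaler reproduces \(z = (x - u) / s\):
import numpy as np
from sklearn.preprocessing import StandardScaler

toy = np.array([[1.0], [2.0], [3.0]])
u, s = toy.mean(axis=0), toy.std(axis=0)  # StandardScaler uses the population standard deviation
assert np.allclose(StandardScaler().fit_transform(toy), (toy - u) / s)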
import random

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer
class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is applied prior to applying models.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)
n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1

dataset = IrisDataset()
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
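As an optional sanity check, one can pull a single batch from the dataloader and inspect its shapes (assuming the batch size of 20 set above):
x_batch, y_batch = next(iter(dataloader))
# x_batch: shape (20, 4), standardized features; y_batch: shape (20,), integer labels in {0, 1, 2}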
Hybrid QNN
We set up the QNN part composed of multiple feature map layers, each followed by a variational layer.
The type of variational layer we use is the hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components.
The output will be the expectation value of a \(Z\) observable on qubit \(0\).
Then we add a simple linear layer serving as a classification head. This is equivalent to applying a weight matrix \(W\) and a bias vector \(b\) to the output of the QNN, denoted \(o\): \(l = W o + b\). To obtain probabilities, we can apply the softmax function defined as \(p_i = \exp(l_i) / \sum_{j=1}^3 \exp(l_j)\).
Note that the softmax is not applied explicitly during training, since it is already included in the cross-entropy loss.
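As an aside, the classification head can be checked in isolation; this is a minimal sketch with a made-up QNN output, not part of the model defined below:
head = nn.Linear(1, 3)                                    # W has shape (3, 1), b has shape (3,)
o = torch.tensor([[0.5]])                                  # stand-in for the QNN output of one sample
l = head(o)                                                # logits, shape (1, 3)
assert torch.allclose(l, o @ head.weight.T + head.bias)    # l = W o + b
p = torch.softmax(l, dim=1)                                # probabilities over the 3 classes, each row sums to 1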
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]
fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]
blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))
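Before training, it can be useful to run a forward pass on a dummy batch to confirm the output shape (a quick sanity check; the batch size of 5 is arbitrary):
with torch.no_grad():
    dummy = torch.rand(5, n_features)  # 5 random samples with 4 features each
    print(model(dummy).shape)          # expected: torch.Size([5, 3])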
Below is a visualization of the QNN:
[Circuit diagram: each of the 3 layers applies RX(xᵢ) feature-map rotations on all 4 qubits, followed by a HEA block of RX, RY, RX rotations and entangling gates; the circuit ends with a Z observable on qubit 0.]
Training
Then we can set up the training part:
opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: tuple[Tensor, Tensor]) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    loss = criterion(out, y)
    return loss, {}

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)
res_train = trainer.fit(dataloader)
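For reference, the Trainer call above is roughly equivalent to the following manual loop (a sketch reusing the model, dataloader, optimizer, and criterion defined earlier; it omits the logging and checkpointing handled by Trainer):
for epoch in range(n_epochs):
    for x, y in dataloader:
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()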
Inference
Finally, we can apply our model to the test set and check the score.
X_test, y_test = dataset.X_test, dataset.y_test
preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()

# Should reach higher than 0.9
print(f"Test Accuracy: {accuracy_test.item()}")
Test Accuracy: 0.9200000166893005
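Beyond the overall score, a per-class breakdown can show whether a particular species is harder to classify (a small optional sketch using the tensors above):
for c in range(3):
    mask = y_test == c
    class_acc = (preds_test[mask] == y_test[mask]).type(torch.float32).mean()
    print(f"Class {c} accuracy: {class_acc.item():.3f}")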