Classification with QNN
In this tutorial we will show how to use Qadence to solve a basic classification task using a hybrid quantum-classical model composed of a QNN and classical layers.
Dataset
We will use the Iris dataset separated into training and testing sets.
The task is to classify iris plants, each described by 4 features, into one of 3 classes (Iris Setosa, Iris Versicolour, or Iris Virginica).
When applying machine learning models, and particularly neural networks, it is recommended to normalize the data. Here we use a standard StandardScaler: each input \(x\) is transformed to \(z = (x - u) / s\), where \(u\) and \(s\) are respectively the mean and standard deviation of the training samples.
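To make the transformation concrete, here is a minimal standalone sketch (not part of the tutorial code, using toy data) checking that StandardScaler reproduces \(z = (x - u) / s\):

import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy data: 3 samples, 2 features
x = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

scaler = StandardScaler().fit(x)
z_sklearn = scaler.transform(x)

# Manual standardization with the same formula
u, s = x.mean(axis=0), x.std(axis=0)
z_manual = (x - u) / s

assert np.allclose(z_sklearn, z_manual)

With that in mind, we import the required packages and define the dataset: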
import random
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from qadence import QNN, RX, FeatureParameter, QuantumCircuit, Z, chain, hea, kron
from qadence.ml_tools import TrainConfig, Trainer
class IrisDataset(Dataset):
    """The Iris dataset split into a training set and a test set.

    A StandardScaler is applied prior to applying models.
    """

    def __init__(self):
        X, y = load_iris(return_X_y=True)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

        self.scaler = StandardScaler()
        self.scaler.fit(X_train)
        self.X = torch.tensor(self.scaler.transform(X_train), requires_grad=False)
        self.y = torch.tensor(y_train, requires_grad=False)

        self.X_test = torch.tensor(self.scaler.transform(X_test), requires_grad=False)
        self.y_test = torch.tensor(y_test, requires_grad=False)

    def __getitem__(self, index) -> tuple[Tensor, Tensor]:
        return self.X[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)
n_features = 4  # sepal length, sepal width, petal length, petal width
n_layers = 3
n_neurons_final_linear_layer = 3
n_epochs = 1000
lr = 1e-1

dataset = IrisDataset()
dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
Hybrid QNN
We set up the QNN part composed of multiple feature map layers, each followed by a variational layer.
The type of variational layer we use is the hardware-efficient ansatz (HEA). You can check the qml constructors tutorial to see how to customize these components.
The output will be the expectation value with respect to a \(Z\) observable on qubit \(0\).
Then we add a simple linear layer serving as a classification head. This is equivalent to applying a weight matrix \(W\) and a bias vector \(b\) to the output of the QNN, denoted \(o\): \(l = W o + b\). To obtain probabilities, we can apply the softmax function defined as \(p_i = \exp(l_i) / \sum_{j=1}^3 \exp(l_j)\).
Note that softmax is not applied during training, since the cross-entropy loss operates directly on the unnormalized logits.
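As a quick sanity check of this remark, here is a minimal sketch with dummy logits (independent of the model defined below) showing that PyTorch's CrossEntropyLoss takes raw logits and applies the log-softmax internally:

import torch
import torch.nn as nn

logits = torch.randn(5, 3)            # dummy raw scores for 5 samples, 3 classes
targets = torch.randint(0, 3, (5,))   # dummy class labels

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), targets)

assert torch.allclose(ce, nll)  # cross-entropy == NLL of the log-softmax outputs

Back to the tutorial, we build the feature map, the ansatz layers, and the full hybrid model: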
feature_parameters = [FeatureParameter(f"x_{i}") for i in range(n_features)]

fm_layer = RX(0, feature_parameters[0])
for q in range(1, n_features):
    fm_layer = kron(fm_layer, RX(q, feature_parameters[q]))

ansatz_layers = [
    hea(n_qubits=n_features, depth=1, param_prefix=f"theta_{layer}")
    for layer in range(n_layers)
]

blocks = chain(fm_layer, ansatz_layers[0])
for layer in range(1, n_layers):
    blocks = chain(blocks, fm_layer, ansatz_layers[layer])

qc = QuantumCircuit(n_features, blocks)
qnn = QNN(circuit=qc, observable=Z(0), inputs=[f"x_{i}" for i in range(n_features)])
model = nn.Sequential(qnn, nn.Linear(1, n_neurons_final_linear_layer))
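Before training, a quick forward pass on a dummy batch is a useful shape check: the hybrid model should output one logit per class, and softmax should turn each row into a probability distribution. This minimal sketch uses the model defined above:

# Dummy batch of 5 standardized samples with n_features = 4 features each
x_dummy = torch.rand(5, n_features)

logits = model(x_dummy)
print(logits.shape)        # torch.Size([5, 3]): one logit per class
probs = torch.softmax(logits, dim=1)
print(probs.sum(dim=1))    # each row sums to 1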
Below is a visualization of the QNN:
[Circuit diagram: each of the 4 qubits starts with an RX feature-map rotation RX(xᵢ), followed by a HEA block of RX, RY, RX rotations (theta parameters) with CNOT entanglers; this feature-map/HEA pattern is repeated 3 times, and the \(Z\) observable is measured on qubit 0.]
Training
Then we can set up the training part:
opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

def cross_entropy(model: nn.Module, data: Tensor) -> tuple[Tensor, dict]:
    x, y = data
    out = model(x)
    loss = criterion(out, y)
    return loss, {}

train_config = TrainConfig(max_iter=n_epochs, print_every=10, create_subfolder_per_run=True)
Trainer.set_use_grad(True)
trainer = Trainer(model=model, optimizer=opt, config=train_config, loss_fn=cross_entropy)
res_train = trainer.fit(dataloader)
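If you prefer not to use the Trainer from qadence.ml_tools, the optimization can also be written as a plain PyTorch loop. The sketch below is a roughly equivalent alternative using the same model, optimizer, and loss, not the tutorial's own training code:

# Manual training loop (sketch); iteration counting may differ from TrainConfig's max_iter
for epoch in range(n_epochs):
    for x_batch, y_batch in dataloader:
        opt.zero_grad()
        loss = criterion(model(x_batch), y_batch)
        loss.backward()
        opt.step()
    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch + 1}: loss = {loss.item():.4f}")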
Inference
Finally, we can apply our model to the test set and check the accuracy.
X_test, y_test = dataset.X_test, dataset.y_test
preds_test = torch.argmax(torch.softmax(model(X_test), dim=1), dim=1)
accuracy_test = (preds_test == y_test).type(torch.float32).mean()
# Should reach higher than 0.9
Test Accuracy: 0.9200000166893005
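Beyond overall accuracy, a per-class breakdown can be informative. The sketch below (using the y_test and preds_test tensors defined above) builds a simple 3x3 confusion matrix with plain PyTorch:

# Rows: true class, columns: predicted class
confusion = torch.zeros(3, 3, dtype=torch.int64)
for true_label, pred_label in zip(y_test, preds_test):
    confusion[true_label, pred_label] += 1
print(confusion)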