import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
We first implement a linear autoencoder.
The LinearAE class constructor receives two values: the input dimension d_in and the encoding dimension d_enc.
Each autoencoder stores two inner members: a linear encoder layer and a linear decoder layer.
class LinearAE(nn.Module):
    def __init__(self, d_in, d_enc):
        super().__init__()
        # A single linear layer in each direction: input -> code and code -> input
        self.encoder = nn.Linear(d_in, d_enc)
        self.decoder = nn.Linear(d_enc, d_in)

    def forward(self, x):
        # Reconstruction: encode, then decode
        return self.decoder(self.encoder(x))
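As a quick usage sketch (the dimensions and the dummy batch here are illustrative, not part of the exercise), the encoder maps a batch of shape (n, d_in) to (n, d_enc), and forward returns a reconstruction with the original shape:

ae = LinearAE(d_in=2, d_enc=1)
x = torch.randn(8, 2)      # a dummy batch of 8 two-dimensional points
z = ae.encoder(x)          # the 1-D codes, shape (8, 1)
x_hat = ae(x)              # the reconstructions, shape (8, 2)
print(z.shape, x_hat.shape)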
Next, we construct a NN-based autoencoder, using the following architecture:

All trainable layers are fully connected (nn.Linear); note the dimensions of each layer, as well as the activation functions (Tanh/ReLU) between them.
The NetworkAE class constructor receives four values: the input dimension d_in, the encoding dimension d_enc, and the two hidden-layer widths d1 and d2.
Each autoencoder stores two inner objects: an encoder network and a decoder network, each built as an nn.Sequential stack.
class NetworkAE(nn.Module):
    def __init__(self, d_in, d_enc, d1, d2):
        super().__init__()
        # Encoder: d_in -> d1 -> d2 -> d1 -> d_enc
        self.encoder = nn.Sequential(nn.Linear(d_in, d1),
                                     nn.Tanh(),
                                     nn.Linear(d1, d2),
                                     nn.ReLU(),
                                     nn.Linear(d2, d1),
                                     nn.ReLU(),
                                     nn.Linear(d1, d_enc))
        # Decoder: d_enc -> d1 -> d2 -> d1 -> d_in (mirrored structure)
        self.decoder = nn.Sequential(nn.Linear(d_enc, d1),
                                     nn.ReLU(),
                                     nn.Linear(d1, d2),
                                     nn.ReLU(),
                                     nn.Linear(d2, d1),
                                     nn.Tanh(),
                                     nn.Linear(d1, d_in))

    def forward(self, x):
        return self.decoder(self.encoder(x))
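To see the layer structure at a glance, it can help to print the submodules (a small sketch, instantiated with the same dimensions used in the experiment further down):

net = NetworkAE(d_in=2, d_enc=1, d1=10, d2=25)
print(net.encoder)   # Sequential: 2 -> 10 -> 25 -> 10 -> 1, with Tanh/ReLU activations in between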
Next, we implement a function that performs the training for a given network.
The input parameters are: the network NN, the training data X, the learning rate, the number of iterations max_iter, and the printing interval print_iter (100 by default).
def train_nn(NN, X, learning_rate, max_iter, print_iter=100):
    optimiser = torch.optim.SGD(NN.parameters(), lr=learning_rate)
    # Summed (not averaged) squared reconstruction error over the whole dataset
    loss_fn = torch.nn.MSELoss(reduction='sum')
    for epoch in range(max_iter):
        X_pred = NN(X)                 # forward pass: reconstruct the input
        loss = loss_fn(X_pred, X)
        if (epoch + 1) % print_iter == 0:
            print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
        optimiser.zero_grad()          # clear gradients from the previous step
        loss.backward()                # backpropagate
        optimiser.step()               # gradient-descent update
    return NN
To test these encoders, we use the two-moon dataset.
SEED = 210711042
N = 500
SIG = 0.1
X, _ = make_moons(n_samples=N, noise=SIG, random_state=SEED)
plt.scatter(X[:,0],X[:,1])
plt.axis('equal');
XT = torch.tensor(X, dtype=torch.float32)
Our goal is to find a suitable 1-dimensional representation of this 2-dimensional data.
linenc = LinearAE(2, 1)
linenc = train_nn(linenc, XT, learning_rate=5e-04, max_iter=1000)
Epoch: 100, Loss: 96.56216430664062
Epoch: 200, Loss: 96.56212615966797
Epoch: 300, Loss: 96.5621337890625
Epoch: 400, Loss: 96.5621337890625
Epoch: 500, Loss: 96.5621337890625
Epoch: 600, Loss: 96.5621337890625
Epoch: 700, Loss: 96.5621337890625
Epoch: 800, Loss: 96.5621337890625
Epoch: 900, Loss: 96.5621337890625
Epoch: 1000, Loss: 96.5621337890625
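The loss plateaus almost immediately: the composition of two linear layers is an affine map, so with a one-dimensional code all reconstructions lie on a single line in the plane, and further training cannot capture the curvature of the moons. As an illustrative check (a sketch assuming the trained linenc above), the learned projection direction can be read off the encoder weights:

with torch.no_grad():
    w = linenc.encoder.weight      # shape (1, 2): the direction the encoder projects onto
    print(w / w.norm())            # normalised for easier interpretation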
netenc = NetworkAE(d_in=2, d_enc=1, d1=10, d2=25)
netenc = train_nn(netenc, XT, learning_rate=5e-04, max_iter=10000)
Epoch: 100, Loss: 145.1793975830078
Epoch: 200, Loss: 109.06774139404297
Epoch: 300, Loss: 54.61616516113281
Epoch: 400, Loss: 57.13446807861328
Epoch: 500, Loss: 44.31135559082031
Epoch: 600, Loss: 50.07269287109375
Epoch: 700, Loss: 48.18919372558594
Epoch: 800, Loss: 35.7377815246582
Epoch: 900, Loss: 35.59989547729492
Epoch: 1000, Loss: 38.22804260253906
Epoch: 1100, Loss: 33.766075134277344
Epoch: 1200, Loss: 35.52913284301758
Epoch: 1300, Loss: 31.04424285888672
Epoch: 1400, Loss: 29.961517333984375
Epoch: 1500, Loss: 35.26736831665039
Epoch: 1600, Loss: 28.458171844482422
Epoch: 1700, Loss: 28.232315063476562
Epoch: 1800, Loss: 25.74853515625
Epoch: 1900, Loss: 24.69481086730957
Epoch: 2000, Loss: 23.363853454589844
Epoch: 2100, Loss: 23.7608585357666
Epoch: 2200, Loss: 26.50672149658203
Epoch: 2300, Loss: 21.805360794067383
Epoch: 2400, Loss: 19.22904396057129
Epoch: 2500, Loss: 22.11136245727539
Epoch: 2600, Loss: 17.781848907470703
Epoch: 2700, Loss: 17.536663055419922
Epoch: 2800, Loss: 19.00665283203125
Epoch: 2900, Loss: 15.53482723236084
Epoch: 3000, Loss: 16.21259117126465
Epoch: 3100, Loss: 14.71164321899414
Epoch: 3200, Loss: 14.423664093017578
Epoch: 3300, Loss: 15.690651893615723
Epoch: 3400, Loss: 14.001856803894043
Epoch: 3500, Loss: 14.527196884155273
Epoch: 3600, Loss: 13.062973976135254
Epoch: 3700, Loss: 20.513336181640625
Epoch: 3800, Loss: 12.508989334106445
Epoch: 3900, Loss: 12.707592964172363
Epoch: 4000, Loss: 13.028851509094238
Epoch: 4100, Loss: 12.302831649780273
Epoch: 4200, Loss: 10.925149917602539
Epoch: 4300, Loss: 11.553437232971191
Epoch: 4400, Loss: 12.0352783203125
Epoch: 4500, Loss: 12.72936725616455
Epoch: 4600, Loss: 14.173344612121582
Epoch: 4700, Loss: 9.747336387634277
Epoch: 4800, Loss: 16.544492721557617
Epoch: 4900, Loss: 9.980114936828613
Epoch: 5000, Loss: 10.21574592590332
Epoch: 5100, Loss: 8.65145206451416
Epoch: 5200, Loss: 12.442646026611328
Epoch: 5300, Loss: 9.094295501708984
Epoch: 5400, Loss: 11.957836151123047
Epoch: 5500, Loss: 8.390283584594727
Epoch: 5600, Loss: 9.559822082519531
Epoch: 5700, Loss: 13.693891525268555
Epoch: 5800, Loss: 8.577988624572754
Epoch: 5900, Loss: 9.467506408691406
Epoch: 6000, Loss: 10.427685737609863
Epoch: 6100, Loss: 9.876237869262695
Epoch: 6200, Loss: 10.809561729431152
Epoch: 6300, Loss: 9.129270553588867
Epoch: 6400, Loss: 9.09569263458252
Epoch: 6500, Loss: 8.263495445251465
Epoch: 6600, Loss: 10.207550048828125
Epoch: 6700, Loss: 10.324889183044434
Epoch: 6800, Loss: 9.295036315917969
Epoch: 6900, Loss: 9.337799072265625
Epoch: 7000, Loss: 9.53169059753418
Epoch: 7100, Loss: 10.010257720947266
Epoch: 7200, Loss: 8.871477127075195
Epoch: 7300, Loss: 8.178281784057617
Epoch: 7400, Loss: 8.216232299804688
Epoch: 7500, Loss: 7.573520660400391
Epoch: 7600, Loss: 8.55801773071289
Epoch: 7700, Loss: 8.71977424621582
Epoch: 7800, Loss: 8.085667610168457
Epoch: 7900, Loss: 8.584272384643555
Epoch: 8000, Loss: 9.627182006835938
Epoch: 8100, Loss: 8.729002952575684
Epoch: 8200, Loss: 8.552417755126953
Epoch: 8300, Loss: 8.67764663696289
Epoch: 8400, Loss: 7.191255569458008
Epoch: 8500, Loss: 7.951998233795166
Epoch: 8600, Loss: 9.69924545288086
Epoch: 8700, Loss: 14.038476943969727
Epoch: 8800, Loss: 9.188523292541504
Epoch: 8900, Loss: 7.946074485778809
Epoch: 9000, Loss: 9.36130142211914
Epoch: 9100, Loss: 7.463373184204102
Epoch: 9200, Loss: 8.972496032714844
Epoch: 9300, Loss: 8.121110916137695
Epoch: 9400, Loss: 7.438579559326172
Epoch: 9500, Loss: 7.733771324157715
Epoch: 9600, Loss: 9.84853744506836
Epoch: 9700, Loss: 7.881923198699951
Epoch: 9800, Loss: 11.179054260253906
Epoch: 9900, Loss: 7.129364490509033
Epoch: 10000, Loss: 8.220257759094238
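Because train_nn uses MSELoss(reduction='sum'), the printed values are totals over all 500 training points: roughly 96.6/500 ≈ 0.19 mean squared error per point for the linear autoencoder, versus about 8/500 ≈ 0.016 for the network. A small sketch (assuming the two trained models above) makes the comparison explicit:

with torch.no_grad():
    lin_err = nn.functional.mse_loss(linenc(XT), XT, reduction='sum') / len(XT)
    net_err = nn.functional.mse_loss(netenc(XT), XT, reduction='sum') / len(XT)
print(lin_err.item(), net_err.item())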
For comparison, we also implement dimension reduction and reconstruction using principal component analysis (PCA).
from sklearn.preprocessing import StandardScaler

def pc_transform(X, PCS):
    # X arrives with one data point per column; PCS lists the principal components to keep
    X = X.T
    scaler = StandardScaler()
    scaler.fit(X.T)
    X = scaler.transform(X.T)
    # SVD of the scaled data: the columns of U are the principal directions
    U, _, _ = np.linalg.svd(X, full_matrices=False)
    # Project onto the selected components, then map back to the original coordinates
    result = np.dot(U[:, PCS], np.dot(U[:, PCS].T, X))
    return scaler.inverse_transform(result).T
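For reference, a similar reduce-and-reconstruct step could be sketched with scikit-learn's PCA class (this variant only centres the data rather than standardising it, and is not used in the comparison below):

from sklearn.decomposition import PCA

def pca_reconstruct(X, n_components):
    # Fit PCA, project onto the leading components, then map back to the original space
    pca = PCA(n_components=n_components)
    scores = pca.fit_transform(X)
    return pca.inverse_transform(scores)

# e.g. X_PCA_alt = pca_reconstruct(X, 2)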
X_PCA = pc_transform(XT.T, [0, 1])
X_LIN = linenc.forward(XT).detach().numpy()
X_NET = netenc.forward(XT).detach().numpy()
The plots below compare the reconstructions produced by the three methods against the original data.
plt.figure(figsize=(18,4))
plt.subplot(1,3,1)
plt.scatter(X[:,0], X[:,1])
plt.scatter(X_PCA[:,0], X_PCA[:,1])
plt.axis('equal');
plt.subplot(1,3,2)
plt.scatter(X[:,0], X[:,1])
plt.scatter(X_LIN[:,0], X_LIN[:,1])
plt.axis('equal');
plt.subplot(1,3,3)
plt.scatter(X[:,0], X[:,1])
plt.scatter(X_NET[:,0], X_NET[:,1])
plt.axis('equal');
We test the performance of the NN autoencoder on new data drawn from make_moons.
SEED = 9987
N = 100
X_new, L_new = make_moons(n_samples=N, noise=SIG, random_state=SEED)
X_newT = torch.tensor(X_new, dtype=torch.float32)
XR_new = netenc.forward(X_newT).detach().numpy()
XE_new = netenc.encoder(X_newT).detach().numpy()
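To quantify how well the trained network generalises, we can also compute the summed reconstruction error on the new points, analogous to the training loss (a quick sketch using the objects defined above):

with torch.no_grad():
    new_loss = nn.functional.mse_loss(netenc(X_newT), X_newT, reduction='sum')
print(new_loss.item())   # total squared error over the 100 new points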
The cell below shows the 1-dimensional encoded data on the left (coloured, and offset vertically, by moon label), and the decoded reconstruction overlaid on the new data on the right.
plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.scatter(XE_new, 2*L_new, c=L_new)
plt.axis('equal');
plt.subplot(1,2,2)
plt.scatter(X_new[:,0],X_new[:,1])
plt.scatter(XR_new[:,0],XR_new[:,1])
plt.axis('equal');
We can see what the decoder is doing by applying it to a range of values in the latent (encoded) space.
linspace = torch.linspace(-20, 20, 1000)[:, None]
XR_gen = netenc.decoder(linspace).detach().numpy()
plt.scatter(XR_gen[:,0],XR_gen[:,1])
plt.axis('equal');
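The decoder traces out a one-dimensional curve in the plane. To see which part of that curve the data actually occupies, one can restrict the latent range to the encodings observed above (an illustrative sketch using XE_new):

lo, hi = float(XE_new.min()), float(XE_new.max())
z = torch.linspace(lo, hi, 1000)[:, None]        # latent values spanning the observed encodings
XR_obs = netenc.decoder(z).detach().numpy()
plt.scatter(XR_obs[:, 0], XR_obs[:, 1])
plt.axis('equal');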