Logistic Regression
Imports
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
# Reproducible hold-out split followed by feature standardization.
# NOTE(review): ``x`` (features) and ``y`` (labels) are not defined in this
# file view; presumably they are loaded in an earlier cell -- confirm.
np.random.seed(123)  # seed the global NumPy RNG so the shuffle is repeatable
idx = np.arange(y.shape[0])
np.random.shuffle(idx)  # permute the row indices in place

# The first 25 shuffled rows become the test set; the remainder is training.
test_rows = idx[:25]
train_rows = idx[25:]
X_test, y_test = x[test_rows], y[test_rows]
X_train, y_train = x[train_rows], y[train_rows]

# Per-feature mean/std come from the training split only and are then applied
# unchanged to the test split, so no test statistics leak into preprocessing.
mu = np.mean(X_train, axis=0)
std = np.std(X_train, axis=0)
X_train = (X_train - mu) / std
X_test = (X_test - mu) / std
class LogisticRegression1:
    """Logistic regression model whose parameters are plain tensors.

    The weights and bias are created with ``torch.zeros`` (they are not
    ``torch.nn.Parameter`` objects) and start at zero.

    NOTE(review): code elsewhere in this file calls ``.train(...)`` on an
    instance, but no such method is defined here -- the remaining methods
    appear to have been lost; confirm against the original source.
    """

    def __init__(self, num_features):
        """Create zero-initialized parameters.

        Args:
            num_features: Number of input features. The weight tensor has
                shape ``(1, num_features)`` and the bias has shape ``(1,)``.

        Note:
            Reads the module-level name ``device`` to decide where the
            tensors are allocated, so it must be defined before this class
            is instantiated.
        """
        self.num_features = num_features
        self.weights = torch.zeros(1, num_features,
                                   dtype=torch.float32, device=device)
        self.bias = torch.zeros(1, dtype=torch.float32, device=device)
# Fit the hand-written model; the returned history is plotted below as the
# per-epoch loss curve.
model1 = LogisticRegression1(num_features=2)
epoch_cost = model1.train(
    X_train_tensor,
    y_train_tensor,
    num_epochs=30,
    learning_rate=0.1,
)

# Report the learned parameters.
print('\nModel parameters:')
print(' Weights: %s' % (model1.weights,))
print(' Bias: %s' % (model1.bias,))

# Learning curve.
plt.plot(epoch_cost)
plt.xlabel('Epoch')
plt.ylabel('Neg. Log Likelihood Loss')
plt.show()
##########################
### 2D Decision Boundary
##########################
w, b = model1.weights.view(-1), model1.bias
x_min = -2
y_min = ( (-(w[0] * x_min) - b[0])
/ w[1] )
x_max = 2
y_max = ( (-(w[0] * x_max) - b[0])
/ w[1] )
ax[1].legend(loc='upper left')
plt.show()
# NOTE(review): these statements use ``self`` at module level, so as written
# they raise NameError. They appear to be the tail of
# ``LogisticRegression2.__init__``; the class header and the creation of
# ``self.linear`` are not present in this file. Confirm and restore the class.
# Zero-initialize the linear layer's weight and bias. ``.detach()`` gives a
# tensor that shares storage with the parameter but is outside autograd, so
# the in-place ``zero_()`` changes the parameter without being recorded.
self.linear.weight.detach().zero_()
self.linear.bias.detach().zero_()
# Note: the trailing underscore
# means "in-place operation" in the context
# of PyTorch
# Second variant of the model, built on torch.nn and moved to ``device``.
model2 = LogisticRegression2(num_features=2).to(device)
optimizer = torch.optim.SGD(model2.parameters(), lr=0.1)
num_epochs = 30

# NOTE(review): ``optimizer`` and ``num_epochs`` are set up here, but no
# training loop is visible in this file. Without one, the parameters printed
# below are still their initial values. Confirm the loop was not lost.

print('\nModel parameters:')
print(' Weights: %s' % model2.linear.weight)
print(' Bias: %s' % model2.linear.bias)

# Evaluate on the held-out test set. Under ``torch.no_grad()`` autograd does
# not record a graph for these predictions. That work is wasted at inference
# time, and skipping it keeps ``pred_probas`` detached from the model's
# parameters.
with torch.no_grad():
    pred_probas = model2(X_test_tensor)
test_acc = comp_accuracy(y_test_tensor, pred_probas)
##########################
### 2D Decision Boundary
##########################

# The decision boundary is the line where the logit is zero:
#     w[0] * x1 + w[1] * x2 + b = 0   =>   x2 = (-(w[0] * x1) - b) / w[1]
# ``.detach()`` drops the autograd link of the nn parameters, and ``.item()``
# turns each endpoint into a plain Python float so matplotlib can plot it
# even when the model lives on a non-CPU device. If w[1] is zero, the tensor
# division yields inf/nan instead of raising.
w, b = model2.linear.weight.detach().view(-1), model2.linear.bias.detach()

x_min = -2
y_min = ((-(w[0] * x_min) - b[0]) / w[1]).item()

x_max = 2
y_max = ((-(w[0] * x_max) - b[0]) / w[1]).item()

# Left panel: training split; right panel: test split. Both show the same
# boundary line. NOTE(review): assumes ``X_*`` / ``y_*`` are NumPy arrays
# (labels usable in boolean masks) for a binary problem -- confirm.
fig, ax = plt.subplots(1, 2, sharex=True, figsize=(7, 3))
for axis, X_part, y_part in ((ax[0], X_train, y_train),
                             (ax[1], X_test, y_test)):
    axis.plot([x_min, x_max], [y_min, y_max], label='decision boundary')
    for label, marker in zip(np.unique(y_part), ('o', 's')):
        mask = y_part == label
        axis.scatter(X_part[mask, 0], X_part[mask, 1],
                     marker=marker, label=f'class {label}')

ax[1].legend(loc='upper left')
plt.show()
Loading [MathJax]/jax/output/CommonHTML/fonts/TeX/fontdata.js