VGG for cifar10¶

original code is =>¶

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

In [31]:
import torch
import torch.nn as nn

import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
In [32]:
import visdom

vis = visdom.Visdom()
vis.close(env="main")
Setting up a new session...
Out[32]:
''

define loss tracker¶

In [33]:
def loss_tracker(loss_plot, loss_value, num):
    '''num, loss_value, are Tensor'''
    vis.line(X=num,
             Y=loss_value,
             win = loss_plot,
             update='append'
             )
In [34]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device =='cuda':
    torch.cuda.manual_seed_all(777)
In [35]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./cifar10', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False,
                                       download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified
In [36]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
vis.images(images/2 + 0.5)

# show images
#imshow(torchvision.utils.make_grid(images))

# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
truck   car   car   cat

make VGG16 using vgg.py¶

In [17]:
import torchvision.models.vgg as vgg
In [18]:
cfg = [32,32,'M', 64,64,128,128,128,'M',256,256,256,512,512,512,'M'] #13 + 3 =vgg16
In [19]:
class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        #self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        #x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
In [20]:
vgg16= VGG(vgg.make_layers(cfg),10,True).to(device)
In [21]:
a=torch.Tensor(1,3,32,32).to(device)
out = vgg16(a)
print(out)
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
In [22]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(vgg16.parameters(), lr = 0.005,momentum=0.9)

lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

make plot¶

In [23]:
loss_plt = vis.line(Y=torch.Tensor(1).zero_(),opts=dict(title='loss_tracker', legend=['loss'], showlegend=True))

training¶

In [24]:
print(len(trainloader))
epochs = 50

for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    lr_sche.step()
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = vgg16(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 30 == 29:    # print every 30 mini-batches
            loss_tracker(loss_plt, torch.Tensor([running_loss/30]), torch.Tensor([i + epoch*len(trainloader) ]))
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 30))
            running_loss = 0.0
        

print('Finished Training')
98
/opt/conda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:136: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
[1,    30] loss: 2.302
[1,    60] loss: 2.299
[1,    90] loss: 2.293
[2,    30] loss: 2.245
[2,    60] loss: 2.137
[2,    90] loss: 2.018
[3,    30] loss: 1.894
[3,    60] loss: 1.815
[3,    90] loss: 1.739
[4,    30] loss: 1.680
[4,    60] loss: 1.628
[4,    90] loss: 1.605
[5,    30] loss: 1.565
[5,    60] loss: 1.510
[5,    90] loss: 1.476
[6,    30] loss: 1.426
[6,    60] loss: 1.398
[6,    90] loss: 1.394
[7,    30] loss: 1.331
[7,    60] loss: 1.287
[7,    90] loss: 1.269
[8,    30] loss: 1.249
[8,    60] loss: 1.213
[8,    90] loss: 1.184
[9,    30] loss: 1.159
[9,    60] loss: 1.124
[9,    90] loss: 1.129
[10,    30] loss: 1.074
[10,    60] loss: 1.047
[10,    90] loss: 1.057
[11,    30] loss: 1.008
[11,    60] loss: 0.997
[11,    90] loss: 0.986
[12,    30] loss: 0.933
[12,    60] loss: 0.965
[12,    90] loss: 0.923
[13,    30] loss: 0.913
[13,    60] loss: 0.866
[13,    90] loss: 0.899
[14,    30] loss: 0.865
[14,    60] loss: 0.843
[14,    90] loss: 0.853
[15,    30] loss: 0.790
[15,    60] loss: 0.792
[15,    90] loss: 0.825
[16,    30] loss: 0.762
[16,    60] loss: 0.758
[16,    90] loss: 0.742
[17,    30] loss: 0.727
[17,    60] loss: 0.706
[17,    90] loss: 0.726
[18,    30] loss: 0.709
[18,    60] loss: 0.692
[18,    90] loss: 0.696
[19,    30] loss: 0.679
[19,    60] loss: 0.657
[19,    90] loss: 0.665
[20,    30] loss: 0.601
[20,    60] loss: 0.621
[20,    90] loss: 0.635
[21,    30] loss: 0.569
[21,    60] loss: 0.579
[21,    90] loss: 0.561
[22,    30] loss: 0.533
[22,    60] loss: 0.533
[22,    90] loss: 0.559
[23,    30] loss: 0.515
[23,    60] loss: 0.520
[23,    90] loss: 0.521
[24,    30] loss: 0.462
[24,    60] loss: 0.483
[24,    90] loss: 0.469
[25,    30] loss: 0.420
[25,    60] loss: 0.426
[25,    90] loss: 0.456
[26,    30] loss: 0.397
[26,    60] loss: 0.391
[26,    90] loss: 0.401
[27,    30] loss: 0.365
[27,    60] loss: 0.356
[27,    90] loss: 0.373
[28,    30] loss: 0.330
[28,    60] loss: 0.350
[28,    90] loss: 0.348
[29,    30] loss: 0.309
[29,    60] loss: 0.296
[29,    90] loss: 0.322
[30,    30] loss: 0.268
[30,    60] loss: 0.255
[30,    90] loss: 0.276
[31,    30] loss: 0.223
[31,    60] loss: 0.224
[31,    90] loss: 0.239
[32,    30] loss: 0.194
[32,    60] loss: 0.213
[32,    90] loss: 0.218
[33,    30] loss: 0.179
[33,    60] loss: 0.180
[33,    90] loss: 0.202
[34,    30] loss: 0.166
[34,    60] loss: 0.169
[34,    90] loss: 0.172
[35,    30] loss: 0.147
[35,    60] loss: 0.139
[35,    90] loss: 0.149
[36,    30] loss: 0.117
[36,    60] loss: 0.129
[36,    90] loss: 0.124
[37,    30] loss: 0.105
[37,    60] loss: 0.110
[37,    90] loss: 0.110
[38,    30] loss: 0.104
[38,    60] loss: 0.094
[38,    90] loss: 0.108
[39,    30] loss: 0.092
[39,    60] loss: 0.086
[39,    90] loss: 0.091
[40,    30] loss: 0.076
[40,    60] loss: 0.076
[40,    90] loss: 0.065
[41,    30] loss: 0.062
[41,    60] loss: 0.065
[41,    90] loss: 0.069
[42,    30] loss: 0.065
[42,    60] loss: 0.057
[42,    90] loss: 0.067
[43,    30] loss: 0.055
[43,    60] loss: 0.056
[43,    90] loss: 0.051
[44,    30] loss: 0.048
[44,    60] loss: 0.048
[44,    90] loss: 0.053
[45,    30] loss: 0.042
[45,    60] loss: 0.037
[45,    90] loss: 0.044
[46,    30] loss: 0.033
[46,    60] loss: 0.039
[46,    90] loss: 0.042
[47,    30] loss: 0.031
[47,    60] loss: 0.034
[47,    90] loss: 0.040
[48,    30] loss: 0.034
[48,    60] loss: 0.030
[48,    90] loss: 0.030
[49,    30] loss: 0.025
[49,    60] loss: 0.030
[49,    90] loss: 0.029
[50,    30] loss: 0.026
[50,    60] loss: 0.025
[50,    90] loss: 0.029
Finished Training
In [25]:
dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
GroundTruth:    cat  ship  ship plane
In [26]:
outputs = vgg16(images.to(device))
In [27]:
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
Predicted:    cat  ship  ship plane
In [28]:
correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = vgg16(images)
        
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
Accuracy of the network on the 10000 test images: 76 %