24.md 2.4 KB
Newer Older
W
wizardforcel 已提交
1
# 5.2 – GPU 加速运算
W
wizardforcel 已提交
2 3 4 5 6 7 8 9 10 11

在 GPU 训练可以大幅提升运算速度. 而且 Torch 也有一套很好的 GPU 运算体系. 但是要强调的是:

*   你的电脑里有合适的 GPU 显卡(NVIDIA), 且支持 CUDA 模块. [请在NVIDIA官网查询](https://www.pytorchtutorial.com/goto/https://developer.nvidia.com/cuda-gpus)
*   必须安装 GPU 版的 Torch, [点击这里查看如何安装](https://www.pytorchtutorial.com/1-2-install-pytorch/)

## 用 GPU 训练 CNN

这份 GPU 的代码是依据[之前这份CNN](https://www.pytorchtutorial.com/goto/https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py)的代码修改的. 大概修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 CNN 也变成 GPU 能读的形式. 做法就是在后面加上 .cuda() , 很简单.

W
wizardforcel 已提交
12
```py
W
wizardforcel 已提交
13 14 15 16 17 18 19 20 21 22 23
...

test_data = torchvision.datasets.MNIST(root=\'./mnist/\', train=False)

# !!!!!!!! 修改 test data 形式 !!!!!!!!! #
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1)).type(torch.FloatTensor)[:2000].cuda()/255\.   # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
```

再来把我们的 CNN 参数也变成 GPU 兼容形式.

W
wizardforcel 已提交
24
```py
W
wizardforcel 已提交
25 26 27 28 29 30 31 32 33 34 35
class CNN(nn.Module):
    ...

cnn = CNN()

# !!!!!!!! 转换 cnn 去 CUDA !!!!!!!!! #
cnn.cuda()      # Moves all model parameters and buffers to the GPU.
```

然后就是在 train 的时候, 将每次的training data 变成 GPU 形式. .cuda()

W
wizardforcel 已提交
36
```py
W
wizardforcel 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
for epoch ..:
    for step, ...:
        # !!!!!!!! 这里有修改 !!!!!!!!! #
        b_x = Variable(x).cuda()    # Tensor on GPU
        b_y = Variable(y).cuda()    # Tensor on GPU

        ...

        if step % 50 == 0:
            test_output = cnn(test_x)

            # !!!!!!!! 这里有修改  !!!!!!!!! #
            pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze()  # 将操作放去 GPU

            accuracy = torch.sum(pred_y == test_y) / test_y.size(0)
            ...

test_output = cnn(test_x[:10])

# !!!!!!!! 这里有修改 !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze()  # 将操作放去 GPU
...
print(test_y[:10], \'real number\')
```

大功告成~

所以这也就是在我 [github 代码](https://www.pytorchtutorial.com/goto/https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/502_GPU.py) 中的每一步的意义啦.

文章来源:[莫烦](https://www.pytorchtutorial.com/goto/https://morvanzhou.github.io/)