diff --git a/doc/fluid/user_guides/howto/dygraph/DyGraph.md b/doc/fluid/user_guides/howto/dygraph/DyGraph.md
index 5fad3b442d4df11ad38b1f512b4ba1429d81b203..f3bc247c1cb7e2da40e26b4a6a15751376105d97 100644
--- a/doc/fluid/user_guides/howto/dygraph/DyGraph.md
+++ b/doc/fluid/user_guides/howto/dygraph/DyGraph.md
@@ -394,6 +394,142 @@ Dygraph works very well together with NumPy; use `fluid.dygraph.to_variable(x)`
 When using `fluid.dygraph.guard()`, you can choose the device on which DyGraph runs by passing in `fluid.CUDAPlace(0)` or `fluid.CPUPlace()`; if you do nothing, a suitable device is usually selected automatically.
+
+## Training a Model on Multiple GPU Cards
+
+PaddlePaddle currently supports multi-card training through multiple processes, one process per GPU card. During training, the first time a forward operation that needs parameters is executed, the parameters on card 0 are broadcast to the other cards so that every card holds identical parameters; after the backward pass, the parameter gradients are aggregated across all cards; finally, the parameters are updated on each GPU card separately.
+
+    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+    with fluid.dygraph.guard(place):
+
+        strategy = fluid.dygraph.parallel.prepare_context()
+        mnist = MNIST("mnist")
+        adam = AdamOptimizer(learning_rate=0.001)
+        mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
+
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
+        train_reader = fluid.contrib.reader.distributed_batch_reader(
+            train_reader)
+
+        for epoch in range(epoch_num):
+            for batch_id, data in enumerate(train_reader()):
+                dy_x_data = np.array([x[0].reshape(1, 28, 28)
+                                      for x in data]).astype('float32')
+                y_data = np.array(
+                    [x[1] for x in data]).astype('int64').reshape(-1, 1)
+
+                img = to_variable(dy_x_data)
+                label = to_variable(y_data)
+                label.stop_gradient = True
+
+                cost, acc = mnist(img, label)
+
+                loss = fluid.layers.cross_entropy(cost, label)
+                avg_loss = fluid.layers.mean(loss)
+
+                avg_loss = mnist.scale_loss(avg_loss)
+                avg_loss.backward()
+                mnist.apply_collective_grads()
+
+                adam.minimize(avg_loss)
+                mnist.clear_gradients()
+                if batch_id % 100 == 0 and batch_id != 0:
+                    print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))
+
+Compared with single-card training, the changes are in four places (a sketch that ties them together behind a command-line switch follows the list):
+1. Get the device ID from the environment variables:
+
+    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+
+2. Preprocess the original model by wrapping it with `DataParallel`:
+
+    strategy = fluid.dygraph.parallel.prepare_context()
+    mnist = MNIST("mnist")
+    adam = AdamOptimizer(learning_rate=0.001)
+    mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
+
+3. Data loading: make sure every process reads different data, i.e. the data read by any two processes is disjoint and the union over all processes is the complete dataset:
+
+    train_reader = paddle.batch(
+        paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
+    train_reader = fluid.contrib.reader.distributed_batch_reader(
+        train_reader)
+
+4. Scale the loss and aggregate the parameter gradients:
+
+    avg_loss = mnist.scale_loss(avg_loss)
+    avg_loss.backward()
+    mnist.apply_collective_grads()
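+
+The four changes above can be kept behind a command-line switch so that the same `train.py` also runs on a single card. The sketch below is only an illustration of that pattern, not the official example script: the `--use_data_parallel` flag is an assumption made here (it mirrors the `training_script_args` shown in the launcher output further down), `MNIST` is the network defined earlier in this document, and `BATCH_SIZE`/`epoch_num` are illustrative values.
+
+    # Illustrative sketch: guard the multi-card changes with a flag so the
+    # same script also runs on a single card.
+    import argparse
+
+    import numpy as np
+    import paddle
+    import paddle.fluid as fluid
+    from paddle.fluid.optimizer import AdamOptimizer
+    from paddle.fluid.dygraph import to_variable
+
+    parser = argparse.ArgumentParser()
+    # Hypothetical flag; the launcher simply forwards it to train.py.
+    parser.add_argument("--use_data_parallel", type=int, default=0)
+    args = parser.parse_args()
+
+    BATCH_SIZE = 64
+    epoch_num = 5
+
+    # Change 1: pick the device from the environment when running multi-card.
+    place = (fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+             if args.use_data_parallel else fluid.CUDAPlace(0))
+
+    with fluid.dygraph.guard(place):
+        mnist = MNIST("mnist")  # the network defined in the sections above
+        adam = AdamOptimizer(learning_rate=0.001)
+
+        if args.use_data_parallel:
+            # Change 2: wrap the model for data-parallel training.
+            strategy = fluid.dygraph.parallel.prepare_context()
+            mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
+
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
+        if args.use_data_parallel:
+            # Change 3: give every process its own shard of the data.
+            train_reader = fluid.contrib.reader.distributed_batch_reader(
+                train_reader)
+
+        for epoch in range(epoch_num):
+            for batch_id, data in enumerate(train_reader()):
+                dy_x_data = np.array([x[0].reshape(1, 28, 28)
+                                      for x in data]).astype('float32')
+                y_data = np.array(
+                    [x[1] for x in data]).astype('int64').reshape(-1, 1)
+
+                img = to_variable(dy_x_data)
+                label = to_variable(y_data)
+                label.stop_gradient = True
+
+                cost, acc = mnist(img, label)
+                loss = fluid.layers.cross_entropy(cost, label)
+                avg_loss = fluid.layers.mean(loss)
+
+                if args.use_data_parallel:
+                    # Change 4: scale the loss, then aggregate gradients across cards.
+                    avg_loss = mnist.scale_loss(avg_loss)
+                    avg_loss.backward()
+                    mnist.apply_collective_grads()
+                else:
+                    avg_loss.backward()
+
+                adam.minimize(avg_loss)
+                mnist.clear_gradients()
+
+With a switch like this, the same script can be started either directly with `python train.py` for single-card training or through `paddle.distributed.launch` for multi-card training, as described next.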
+
+When launching multi-process multi-card training in Paddle dygraph mode, you must specify which GPUs to use. For example, to train on cards `0,1,2,3`, launch as follows:
+
+    python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train.py
+
+The output looks like this:
+
+    ----------- Configuration Arguments -----------
+    cluster_node_ips: 127.0.0.1
+    log_dir: ./mylog
+    node_ip: 127.0.0.1
+    print_config: True
+    selected_gpus: 0,1,2,3
+    started_port: 6170
+    training_script: train.py
+    training_script_args: ['--use_data_parallel', '1']
+    use_paddlecloud: True
+    ------------------------------------------------
+    trainers_endpoints: 127.0.0.1:6170,127.0.0.1:6171,127.0.0.1:6172,127.0.0.1:6173 , node_id: 0 , current_node_ip: 127.0.0.1 , num_nodes: 1 , node_ips: ['127.0.0.1'] , nranks: 4
+
+In this case, the program writes the log of each process to the `./mylog` directory:
+
+    .
+    ├── mylog
+    │   ├── workerlog.0
+    │   ├── workerlog.1
+    │   ├── workerlog.2
+    │   └── workerlog.3
+    └── train.py
+
+If `--log_dir` is not specified, the program prints the output of every process directly:
+
+    ----------- Configuration Arguments -----------
+    cluster_node_ips: 127.0.0.1
+    log_dir: None
+    node_ip: 127.0.0.1
+    print_config: True
+    selected_gpus: 0,1,2,3
+    started_port: 6170
+    training_script: train.py
+    training_script_args: ['--use_data_parallel', '1']
+    use_paddlecloud: True
+    ------------------------------------------------
+    trainers_endpoints: 127.0.0.1:6170,127.0.0.1:6171,127.0.0.1:6172,127.0.0.1:6173 , node_id: 0 , current_node_ip: 127.0.0.1 , num_nodes: 1 , node_ips: ['127.0.0.1'] , nranks: 4
+    grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
+    grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
+    grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
+    grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
+    I0923 09:32:36.423513 56410 nccl_context.cc:120] init nccl context nranks: 4 local rank: 1 gpu id: 1
+    I0923 09:32:36.425287 56411 nccl_context.cc:120] init nccl context nranks: 4 local rank: 2 gpu id: 2
+    I0923 09:32:36.429337 56409 nccl_context.cc:120] init nccl context nranks: 4 local rank: 0 gpu id: 0
+    I0923 09:32:36.429440 56412 nccl_context.cc:120] init nccl context nranks: 4 local rank: 3 gpu id: 3
+    W0923 09:32:42.594097 56412 device_context.cc:198] Please NOTE: device: 3, CUDA Capability: 70, Driver API Version: 9.0, Runtime API Version: 9.0
+    W0923 09:32:42.605836 56412 device_context.cc:206] device: 3, cuDNN Version: 7.5.
+    W0923 09:32:42.632463 56410 device_context.cc:198] Please NOTE: device: 1, CUDA Capability: 70, Driver API Version: 9.0, Runtime API Version: 9.0
+    W0923 09:32:42.637948 56410 device_context.cc:206] device: 1, cuDNN Version: 7.5.
+    W0923 09:32:42.648674 56411 device_context.cc:198] Please NOTE: device: 2, CUDA Capability: 70, Driver API Version: 9.0, Runtime API Version: 9.0
+    W0923 09:32:42.654021 56411 device_context.cc:206] device: 2, cuDNN Version: 7.5.
+    W0923 09:32:43.048696 56409 device_context.cc:198] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.0, Runtime API Version: 9.0
+    W0923 09:32:43.053236 56409 device_context.cc:206] device: 0, cuDNN Version: 7.5.
+    start data reader (trainers_num: 4, trainer_id: 2)
+    start data reader (trainers_num: 4, trainer_id: 3)
+    start data reader (trainers_num: 4, trainer_id: 1)
+    start data reader (trainers_num: 4, trainer_id: 0)
+    Loss at epoch 0 step 0: [0.57390565]
+    Loss at epoch 0 step 0: [0.57523954]
+    Loss at epoch 0 step 0: [0.575606]
+    Loss at epoch 0 step 0: [0.5767452]
+
 
 ## Saving Model Parameters
 
@@ -465,7 +601,10 @@ Dygraph works very well together with NumPy; use `fluid.dygraph.to_variable(x)
             success = False
     print("model save and load success? {}".format(success))
-
+Note that with multi-card training, only one process needs to save the model parameters; when saving, specify which process performs the save, for example:
+
+    if fluid.dygraph.parallel.Env().local_rank == 0:
+        fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
 
 ## Model Evaluation