From 02175555306e3aabdfac39bca49b4d596aaf85f3 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Fri, 17 May 2019 14:43:07 +0800 Subject: [PATCH] polish parallel dygraph code (#17164) * add var grad hook test=develop --- paddle/fluid/imperative/layer.cc | 14 +- paddle/fluid/imperative/layer.h | 2 +- .../operators/distributed_ops/allreduce_op.h | 3 +- paddle/fluid/pybind/pybind.cc | 8 +- python/paddle/fluid/dygraph/parallel.py | 50 +-- .../tests/unittests/parallel_dygraph_mnist.py | 15 +- .../unittests/parallel_dygraph_se_resnext.py | 314 ++++++++++++++++++ .../fluid/tests/unittests/test_dist_base.py | 27 +- .../unittests/test_parallel_dygraph_mnist.py | 5 +- .../test_parallel_dygraph_se_resnext.py | 35 ++ 10 files changed, 411 insertions(+), 62 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index ad459a76a19..1dd20021e98 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -150,9 +150,9 @@ class Autograd { const std::vector& ingrads = it->second; for (size_t i = 0; i < ingrads.size(); ++i) { if (!ingrads[i]) continue; - if (ready_op->input_vars_[it->first][i]->IsStopGradient()) { - continue; - } + auto p = ready_op->input_vars_[it->first][i]; + + if (p->IsStopGradient()) continue; OpBase* pre_op = ready_op->pre_ops_[it->first][i]; if (!pre_op) continue; @@ -415,15 +415,11 @@ void OpBase::InvokeBackwardHooks() { } } -void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) { +void OpBase::RegisterBackwardHooks(const py::object& callable) { VLOG(3) << "Register backward hooks " << trace_id_; // TODO(minqiyang): check the callable format - if (front) { - backward_hooks_.insert(backward_hooks_.begin(), callable); - } else { - backward_hooks_.push_back(callable); - } + backward_hooks_.push_back(callable); } void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) { diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index e7e0a692e31..cfd43ef6ce3 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase { return grad_op_descs_[index]->Type(); } - void RegisterBackwardHooks(const py::object& callable, bool front = false); + void RegisterBackwardHooks(const py::object& callable); void InvokeBackwardHooks(); diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.h b/paddle/fluid/operators/distributed_ops/allreduce_op.h index 8c143867618..0275f6a9cf3 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.h +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.h @@ -39,6 +39,7 @@ class AllReduceOpKernel : public framework::OpKernel { auto& dev_ctx = ctx.template device_context(); auto in = ctx.Input("X"); auto out = ctx.Output("Out"); + int dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); auto* sendbuff = in->data(); @@ -66,12 +67,10 @@ class AllReduceOpKernel : public framework::OpKernel { red_type = ncclMin; break; } - VLOG(0) << "call allreduce with type: " << reduce_type; PADDLE_ENFORCE(platform::dynload::ncclAllReduce( sendbuff, recvbuff, numel, static_cast(dtype), red_type, comm, stream)); if (ctx.Attr("sync_mode")) { - VLOG(0) << "sync allreduce..."; cudaError_t e_sync = cudaStreamSynchronize(stream); if (e_sync != 0) { LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a792b24be3c..6795621781f 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -252,11 +252,9 @@ PYBIND11_MODULE(core, m) { py::class_(m, "OpBase", R"DOC()DOC") .def(py::init()) .def("register_backward_hooks", - [](imperative::OpBase &self, const py::object &callable, - bool front = false) { - self.RegisterBackwardHooks(callable, front); - }, - py::arg("callable"), py::arg("front") = false) + [](imperative::OpBase &self, const py::object &callable) { + self.RegisterBackwardHooks(callable); + }) .def_property("_trace_id", [](const imperative::OpBase &self) { pybind11::gil_scoped_release release; diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 44c20166b89..1378f914028 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -13,12 +13,14 @@ # limitations under the License. import os import six +import numpy as np from .. import core from . import layers from .. import framework from ..layers import collective +from . import to_variable __all__ = ["prepare_context"] @@ -75,31 +77,33 @@ class Env(object): class DataParallel(layers.Layer): - def __init__(self, layers): + def __init__(self, layers, strategy): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") self._layers = layers - - def build_once(self, *inputs, **kwargs): - #TODO(Yancey1989): broadcast all the paramters - pass + self._strategy = strategy def forward(self, *inputs, **kwargs): - def _collective_hook(iop): - op = framework._dygraph_tracer()._ops[iop._trace_id] - for k, v in six.iteritems(op.inputs): - for ivar in v: - g = ivar._grad_ivar() - if g: - g_var = framework.Variable( - block=self._helper.main_program.current_block(), - name=ivar._grad_name(), - stop_gradient=True, - ivar=g) - collective._allreduce(g_var, g_var, sync_mode=True) - - outs = self._layers(*inputs, **kwargs) - for _, op in six.iteritems(framework._dygraph_tracer()._ops): - # hook collective ops - op.iop.register_backward_hooks(_collective_hook, front=True) - return outs + return self._layers(*inputs, **kwargs) + + def scale_loss(self, loss): + if self._strategy.nranks < 2: + return loss + loss_scale = to_variable( + np.array([self._strategy.nranks]).astype("float32")) + loss_scale.stop_gradient = True + loss = loss / loss_scale + return loss + + def apply_collective_grads(self): + if self._strategy.nranks < 2: + return + + for param in self._layers.parameters(): + if param.trainable and param._ivar._grad_ivar(): + g_var = framework.Variable( + block=self._helper.main_program.current_block(), + name=param._ivar._grad_name(), + stop_gradient=True, + ivar=param._ivar._grad_ivar()) + collective._allreduce(g_var, g_var, sync_mode=True) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py index 8b9e2997ec7..d2ce14e92ad 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py @@ -101,11 +101,13 @@ class MNIST(fluid.dygraph.Layer): loc=0.0, scale=scale)), act="softmax") - def forward(self, inputs): + def forward(self, inputs, label): x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_2(x) - x = self._fc(x) - return x + cost = self._fc(x) + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + return avg_loss class TestMnist(TestParallelDyGraphRunnerBase): @@ -113,7 +115,7 @@ class TestMnist(TestParallelDyGraphRunnerBase): model = MNIST("mnist") train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=2, drop_last=True) - opt = SGDOptimizer(learning_rate=1e-3) + opt = fluid.optimizer.SGD(learning_rate=1e-3) return model, train_reader, opt def run_one_loop(self, model, opt, data): @@ -126,9 +128,8 @@ class TestMnist(TestParallelDyGraphRunnerBase): label = to_variable(y_data) label.stop_gradient = True - cost = model(img) - loss = fluid.layers.cross_entropy(cost, label) - avg_loss = fluid.layers.mean(loss) + avg_loss = model(img, label) + return avg_loss diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py new file mode 100644 index 00000000000..9eb860cb65f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py @@ -0,0 +1,314 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import contextlib +import unittest +import numpy as np +import six +import pickle +import sys + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, BatchNorm +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.layer_helper import LayerHelper + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + name_scope, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(ConvBNLayer, self).__init__(name_scope) + + self._conv = Conv2D( + self.full_name(), + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=None) + + self._batch_norm = BatchNorm( + self.full_name(), num_filters, act=act, momentum=0.1) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + + return y + + +class SqueezeExcitation(fluid.dygraph.Layer): + def __init__(self, name_scope, num_channels, reduction_ratio): + + super(SqueezeExcitation, self).__init__(name_scope) + self._pool = Pool2D( + self.full_name(), pool_size=0, pool_type='avg', global_pooling=True) + self._squeeze = FC( + self.full_name(), + size=num_channels // reduction_ratio, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.05)), + act='relu') + self._excitation = FC( + self.full_name(), + size=num_channels, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.05)), + act='sigmoid') + + def forward(self, input): + y = self._pool(input) + y = self._squeeze(y) + y = self._excitation(y) + y = fluid.layers.elementwise_mul(x=input, y=y, axis=0) + return y + + +class BottleneckBlock(fluid.dygraph.Layer): + def __init__(self, + name_scope, + num_channels, + num_filters, + stride, + cardinality, + reduction_ratio, + shortcut=True): + super(BottleneckBlock, self).__init__(name_scope) + + self.conv0 = ConvBNLayer( + self.full_name(), + num_channels=num_channels, + num_filters=num_filters, + filter_size=1) + self.conv1 = ConvBNLayer( + self.full_name(), + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=stride, + groups=cardinality) + self.conv2 = ConvBNLayer( + self.full_name(), + num_channels=num_filters, + num_filters=num_filters * 4, + filter_size=1, + act='relu') + + self.scale = SqueezeExcitation( + self.full_name(), + num_channels=num_filters * 4, + reduction_ratio=reduction_ratio) + + if not shortcut: + self.short = ConvBNLayer( + self.full_name(), + num_channels=num_channels, + num_filters=num_filters * 4, + filter_size=1, + stride=stride) + + self.shortcut = shortcut + + self._num_channels_out = num_filters * 4 + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) + scale = self.scale(conv2) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + + y = fluid.layers.elementwise_add(x=short, y=scale) + + layer_helper = LayerHelper(self.full_name(), act='relu') + y = layer_helper.append_activation(y) + return y + + +class SeResNeXt(fluid.dygraph.Layer): + def __init__(self, name_scope, layers=50, class_dim=102): + super(SeResNeXt, self).__init__(name_scope) + + self.layers = layers + supported_layers = [50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 50: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 6, 3] + num_filters = [128, 256, 512, 1024] + self.conv0 = ConvBNLayer( + self.full_name(), + num_channels=3, + num_filters=64, + filter_size=7, + stride=2, + act='relu') + self.pool = Pool2D( + self.full_name(), + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + elif layers == 101: + cardinality = 32 + reduction_ratio = 16 + depth = [3, 4, 23, 3] + num_filters = [128, 256, 512, 1024] + self.conv0 = ConvBNLayer( + self.full_name(), + num_channels=3, + num_filters=3, + filter_size=7, + stride=2, + act='relu') + self.pool = Pool2D( + self.full_name(), + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + elif layers == 152: + cardinality = 64 + reduction_ratio = 16 + depth = [3, 8, 36, 3] + num_filters = [128, 256, 512, 1024] + self.conv0 = ConvBNLayer( + self.full_name(), + num_channels=3, + num_filters=3, + filter_size=7, + stride=2, + act='relu') + self.conv1 = ConvBNLayer( + self.full_name(), + num_channels=64, + num_filters=3, + filter_size=7, + stride=2, + act='relu') + self.conv2 = ConvBNLayer( + self.full_name(), + num_channels=64, + num_filters=3, + filter_size=7, + stride=2, + act='relu') + self.pool = Pool2D( + self.full_name(), + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + self.bottleneck_block_list = [] + num_channels = 64 + for block in range(len(depth)): + shortcut = False + for i in range(depth[block]): + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + self.full_name(), + num_channels=num_channels, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + cardinality=cardinality, + reduction_ratio=reduction_ratio, + shortcut=shortcut)) + num_channels = bottleneck_block._num_channels_out + self.bottleneck_block_list.append(bottleneck_block) + shortcut = True + + self.pool2d_avg = Pool2D( + self.full_name(), pool_size=7, pool_type='avg', global_pooling=True) + import math + stdv = 1.0 / math.sqrt(2048 * 1.0) + + self.fc = FC(self.full_name(), + size=class_dim, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv))) + + def forward(self, inputs, label): + if self.layers == 50 or self.layers == 101: + y = self.conv0(inputs) + y = self.pool(y) + elif self.layers == 152: + y = self.conv0(inputs) + y = self.conv1(inputs) + y = self.conv2(inputs) + y = self.pool(y) + + for bottleneck_block in self.bottleneck_block_list: + y = bottleneck_block(y) + y = self.pool2d_avg(y) + y = fluid.layers.dropout(y, dropout_prob=0.2, seed=1) + cost = self.fc(y) + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + return avg_loss + + +class TestSeResNeXt(TestParallelDyGraphRunnerBase): + def get_model(self): + model = SeResNeXt("se-resnext") + train_reader = paddle.batch( + paddle.dataset.flowers.test(use_xmap=False), + batch_size=2, + drop_last=True) + + opt = fluid.optimizer.SGD(learning_rate=1e-3) + return model, train_reader, opt + + def run_one_loop(self, model, opt, data): + bs = len(data) + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + y_data = np.array([x[1] for x in data]).astype('int64').reshape(bs, 1) + img = to_variable(dy_x_data) + label = to_variable(y_data) + label.stop_gradient = True + + loss = model(img, label) + return loss + + +if __name__ == "__main__": + runtime_main(TestSeResNeXt) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 6c7054e95ef..b479966d1fb 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -31,7 +31,7 @@ import paddle.fluid.dygraph as dygraph from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.parallel import DataParallel -RUN_STEP = 10 +RUN_STEP = 5 DEFAULT_BATCH_SIZE = 2 @@ -200,6 +200,7 @@ class TestParallelDyGraphRunnerBase(object): "train_one_loop should be implemented by the child classes.") def run_trainer(self, args): + seed = 90 device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace(device_id) @@ -217,32 +218,35 @@ class TestParallelDyGraphRunnerBase(object): with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed + np.random.seed(seed) + import random + random.seed = seed model, train_reader, opt = self.get_model() - nranks = len(args.endpoints.split(",")) if args.endpoints else 1 + if args.update_method == "nccl2": - sys.stderr.write("") - model = dygraph.parallel.DataParallel(model) strategy = dygraph.parallel.ParallelStrategy() strategy.nranks = nranks strategy.local_rank = args.trainer_id strategy.trainer_endpoints = args.endpoints.split(",") strategy.current_endpoint = args.current_endpoint dygraph.parallel.prepare_context(strategy) + model = dygraph.parallel.DataParallel(model, strategy) out_losses = [] for step_id, data in enumerate(train_reader()): data = _get_data(data) if step_id == RUN_STEP: break loss = self.run_one_loop(model, opt, data) + out_losses.append(loss.numpy()) - # FIXME(Yancey1989): scale the loss inplace - loss.stop_gradient = True - loss_scale = to_variable(np.array([nranks]).astype("float32")) - loss = loss / loss_scale + # FIXME(Yancey1989): scale the loss inplace + if args.update_method == "nccl2": + loss = model.scale_loss(loss) - out_losses.append(loss.numpy()) loss.backward() + if args.update_method == "nccl2": + model.apply_collective_grads() opt.minimize(loss) model.clear_gradients() @@ -663,9 +667,6 @@ class TestDistBase(unittest.TestCase): local_loss = local_losses[step_id] tr0_loss = tr0_losses[step_id] tr1_loss = tr1_losses[step_id] - dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) - if not self._dygraph: - # Parallel DyGraph already scaled the loss in training - dist_loss = dist_loss / 2 + dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2 print("=======", local_loss, ":", dist_loss[0], "=======") self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py index a08991986a7..19cd1577df4 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest from test_dist_base import TestDistBase +import paddle.fluid as fluid class TestParallelDygraphMnist(TestDistBase): @@ -24,8 +25,8 @@ class TestParallelDygraphMnist(TestDistBase): self._dygraph = True def test_mnist(self): - self.check_with_place( - "parallel_dygraph_mnist.py", delta=1e-5, check_error_log=True) + if fluid.core.is_compiled_with_cuda(): + self.check_with_place("parallel_dygraph_mnist.py", delta=1e-5) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py new file mode 100644 index 00000000000..3c804ee0722 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py @@ -0,0 +1,35 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase +import paddle.fluid as fluid + + +class TestParallelDygraphSeResNeXt(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_se_resnext(self): + # TODO(Yancey1989): BN and Dropout is related with batchsize, so the delta is the 1, + # try to remove the BN and Dropout in the network and using delta = 1e-5 + if fluid.core.is_compiled_with_cuda(): + self.check_with_place("parallel_dygraph_se_resnext.py", delta=1) + + +if __name__ == "__main__": + unittest.main() -- GitLab