diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index bc03285a4c5fe6db2abf2b271d6ddc86e75a9412..aa739a8972ec1bf6806fe0d5a3e5e4fd1d6f807d 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -336,11 +336,15 @@ void OpBase::InvokeBackwardHooks() { } } -void OpBase::RegisterBackwardHooks(const py::object& callable) { +void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) { VLOG(3) << "Register backward hooks " << trace_id_; // TODO(minqiyang): check the callable format - backward_hooks_.push_back(callable); + if (front) { + backward_hooks_.insert(backward_hooks_.begin(), callable); + } else { + backward_hooks_.push_back(callable); + } } void VarBase::RunBackward() { diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 316095b4fb8ed8a574bb25847fa29bd45d9b50f0..37488d381ef2fe15f96a5b55434eca40466a1424 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase { return grad_op_descs_[index]->Type(); } - void RegisterBackwardHooks(const py::object& callable); + void RegisterBackwardHooks(const py::object& callable, bool front = false); void InvokeBackwardHooks(); diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.cc b/paddle/fluid/operators/distributed_ops/allreduce_op.cc index 0fbc27515cec9f7982852954055aa929f678a096..57d68eb931f089e46df07f45186246568bc297c8 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.cc +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.cc @@ -15,91 +15,22 @@ limitations under the License. */ #include // NOLINT #include -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/nccl_helper.h" -#endif +#include "paddle/fluid/operators/distributed_ops/allreduce_op.h" namespace paddle { namespace operators { -struct MutableDataFunctor { - MutableDataFunctor(void** data, framework::LoDTensor* tensor, - const platform::Place& place) - : data_(data), tensor_(tensor), place_(place) {} - - template - void apply() { - *data_ = tensor_->mutable_data(place_); - } +class AllReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; - void** data_; - framework::LoDTensor* tensor_; - platform::Place place_; -}; + void InferShape(framework::InferShapeContext* ctx) const override {} -class AllReduceOp : public framework::OperatorBase { - using OperatorBase::OperatorBase; - - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - PADDLE_ENFORCE(is_gpu_place(place), - "AllReduce op can run on gpu place only for now."); -#ifdef PADDLE_WITH_CUDA - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* ctx = pool.Get(place); - auto in_names = Inputs("X"); - auto out_names = Outputs("Out"); - PADDLE_ENFORCE_EQ(in_names.size(), 1, "Only support one input"); - PADDLE_ENFORCE_EQ(out_names.size(), 1, "Only support one output"); - - auto* in = scope.FindVar(in_names[0]); - auto* out = scope.FindVar(out_names[0]); - - PADDLE_ENFORCE(in->IsType() || - out->IsType(), - "Only support allreduce LoDTensors"); - - int dtype = -1; - auto in_tensor = in->Get(); - dtype = platform::ToNCCLDataType(in_tensor.type()); - - int64_t numel = in_tensor.numel(); - auto* sendbuff = in_tensor.data(); - auto* out_tensor = out->GetMutable(); - out_tensor->Resize(in_tensor.dims()); - void* recvbuff = nullptr; - framework::VisitDataType(in_tensor.type(), - MutableDataFunctor(&recvbuff, out_tensor, place)); - - auto cuda_ctx = static_cast(ctx); - auto* comm = cuda_ctx->nccl_comm(); - // FIXME(typhoonzero): should use nccl stream here. - auto stream = cuda_ctx->stream(); - - int reduce_type = Attr("reduce_type"); - ncclRedOp_t red_type = ncclSum; - switch (reduce_type) { - case 0: - red_type = ncclSum; - break; - case 1: - red_type = ncclProd; - break; - case 2: - red_type = ncclMax; - break; - case 3: - red_type = ncclMin; - break; - } - - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( - sendbuff, recvbuff, numel, static_cast(dtype), red_type, - comm, stream)); -#endif + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; @@ -110,6 +41,10 @@ class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor) the result of allreduced."); AddAttr("reduce_type", "(int) determin the reduce type.") .SetDefault(0); + AddAttr( + "sync_mode", + "(bool) whether to synchronize the CUDA stream after nccl call.") + .SetDefault(false); AddComment(R"DOC( ***AllReduce Operator*** @@ -128,16 +63,18 @@ If input and output are the same variable, in-place allreduce will be used. } }; -class AllReduceOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override {} -}; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(allreduce, ops::AllReduceOp, + ops::AllReduceOpMaker); -REGISTER_OPERATOR(allreduce, ops::AllReduceOp, - paddle::framework::EmptyGradOpMaker, ops::AllReduceOpMaker, - ops::AllReduceOpShapeInference); +REGISTER_OP_CPU_KERNEL( + allreduce, ops::AllReduceOpKernel, + ops::AllReduceOpKernel, + ops::AllReduceOpKernel, + ops::AllReduceOpKernel, + ops::AllReduceOpKernel); diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc b/paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b70f78399026b9f853b8315f0acf6dbad64242a --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/distributed_ops/allreduce_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + allreduce, ops::AllReduceOpKernel, + ops::AllReduceOpKernel, + ops::AllReduceOpKernel, + ops::AllReduceOpKernel, + ops::AllReduceOpKernel); diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.h b/paddle/fluid/operators/distributed_ops/allreduce_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8c143867618577740a29f971ac558c50113dff85 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class AllReduceOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), + "AllReduce op can run on gpu place only for now."); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto& dev_ctx = ctx.template device_context(); + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + int dtype = platform::ToNCCLDataType(in->type()); + int64_t numel = in->numel(); + auto* sendbuff = in->data(); + out->Resize(in->dims()); + void* recvbuff = out->mutable_data(place); + + auto* comm = dev_ctx.nccl_comm(); + // FIXME(typhoonzero): should use nccl stream here. + auto stream = dev_ctx.stream(); + PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly."); + + int reduce_type = ctx.Attr("reduce_type"); + ncclRedOp_t red_type = ncclSum; + switch (reduce_type) { + case 0: + red_type = ncclSum; + break; + case 1: + red_type = ncclProd; + break; + case 2: + red_type = ncclMax; + break; + case 3: + red_type = ncclMin; + break; + } + VLOG(0) << "call allreduce with type: " << reduce_type; + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, static_cast(dtype), red_type, + comm, stream)); + if (ctx.Attr("sync_mode")) { + VLOG(0) << "sync allreduce..."; + cudaError_t e_sync = cudaStreamSynchronize(stream); + if (e_sync != 0) { + LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync); + } + } +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 45542b7d018e0a7e1565624e57f866fdec8bbe0f..0827ef9f4566e30218fbbbe1a4dd7625575c3dfb 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -236,9 +236,11 @@ PYBIND11_MODULE(core, m) { py::class_(m, "OpBase", R"DOC()DOC") .def(py::init()) .def("register_backward_hooks", - [](imperative::OpBase &self, const py::object &callable) { - self.RegisterBackwardHooks(callable); - }) + [](imperative::OpBase &self, const py::object &callable, + bool front = false) { + self.RegisterBackwardHooks(callable, front); + }, + py::arg("callable"), py::arg("front") = false) .def_property("_trace_id", [](const imperative::OpBase &self) { pybind11::gil_scoped_release release; diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index f7decac963f47ba1dcc33e9c8eab7900e745d1df..44c20166b89906093e2211ed141754d8e6d0424a 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import six + from .. import core +from . import layers +from .. import framework + +from ..layers import collective __all__ = ["prepare_context"] @@ -21,9 +27,13 @@ ParallelStrategy = core.ParallelStrategy __parallel_ctx__clz__ = None -def prepare_context(parallel_strategy, place): +def prepare_context(parallel_strategy): global __parallel_ctx__clz__ assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once." + assert framework.in_dygraph_mode( + ) is True, "dygraph.parallel.prepare_context should be used with dygrahp mode." + place = framework._current_expected_place() + assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard." if isinstance(place, core.CUDAPlace): __parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy, @@ -58,3 +68,38 @@ class Env(object): @property def current_endpoint(self): return self._current_endpoint + + @property + def trainer_endpoints(self): + return self._trainer_endpoints + + +class DataParallel(layers.Layer): + def __init__(self, layers): + super(DataParallel, + self).__init__(layers.full_name() + "_data_parallel") + self._layers = layers + + def build_once(self, *inputs, **kwargs): + #TODO(Yancey1989): broadcast all the paramters + pass + + def forward(self, *inputs, **kwargs): + def _collective_hook(iop): + op = framework._dygraph_tracer()._ops[iop._trace_id] + for k, v in six.iteritems(op.inputs): + for ivar in v: + g = ivar._grad_ivar() + if g: + g_var = framework.Variable( + block=self._helper.main_program.current_block(), + name=ivar._grad_name(), + stop_gradient=True, + ivar=g) + collective._allreduce(g_var, g_var, sync_mode=True) + + outs = self._layers(*inputs, **kwargs) + for _, op in six.iteritems(framework._dygraph_tracer()._ops): + # hook collective ops + op.iop.register_backward_hooks(_collective_hook, front=True) + return outs diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py index a9bce77b9d4ae8d5b08c8c4433e5010f20383cc1..97c290f5a99da513740a79dae6a769c8214cae66 100644 --- a/python/paddle/fluid/layers/collective.py +++ b/python/paddle/fluid/layers/collective.py @@ -16,7 +16,7 @@ from __future__ import print_function from ..layer_helper import LayerHelper, unique_name -def _allreduce(x, out=None, reduce_type="sum"): +def _allreduce(x, out=None, reduce_type="sum", sync_mode=False): helper = LayerHelper("allreduce", **locals()) # Convert string reduce type to op int type red_typ_int = 0 @@ -43,5 +43,6 @@ def _allreduce(x, out=None, reduce_type="sum"): type='allreduce', inputs={'X': [x]}, outputs={'Out': [out]}, - attrs={"reduce_type": red_typ_int}) + attrs={"reduce_type": red_typ_int, + "sync_mode": sync_mode}) return out diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 6885bb158c78644af061ba07f892de28c462e150..f4f6bebcc7d2a66643646308716ab27e97d4a410 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -19,6 +19,7 @@ endif(NOT WITH_DISTRIBUTE) if (NOT ${WITH_GPU}) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future elseif(${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) endif() diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9e2997ec702882b0e374cefd47b1c02343b225 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py @@ -0,0 +1,136 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import contextlib +import unittest +import numpy as np +import six +import pickle + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC +from paddle.fluid.dygraph.base import to_variable + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + + +class SimpleImgConvPool(fluid.dygraph.Layer): + def __init__(self, + name_scope, + num_channels, + num_filters, + filter_size, + pool_size, + pool_stride, + pool_padding=0, + pool_type='max', + global_pooling=False, + conv_stride=1, + conv_padding=0, + conv_dilation=1, + conv_groups=1, + act=None, + use_cudnn=False, + param_attr=None, + bias_attr=None): + super(SimpleImgConvPool, self).__init__(name_scope) + + self._conv2d = Conv2D( + self.full_name(), + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=conv_stride, + padding=conv_padding, + dilation=conv_dilation, + groups=conv_groups, + param_attr=None, + bias_attr=None, + use_cudnn=use_cudnn) + + self._pool2d = Pool2D( + self.full_name(), + pool_size=pool_size, + pool_type=pool_type, + pool_stride=pool_stride, + pool_padding=pool_padding, + global_pooling=global_pooling, + use_cudnn=use_cudnn) + + def forward(self, inputs): + x = self._conv2d(inputs) + x = self._pool2d(x) + return x + + +class MNIST(fluid.dygraph.Layer): + def __init__(self, name_scope): + super(MNIST, self).__init__(name_scope) + + self._simple_img_conv_pool_1 = SimpleImgConvPool( + self.full_name(), 1, 20, 5, 2, 2, act="relu") + + self._simple_img_conv_pool_2 = SimpleImgConvPool( + self.full_name(), 20, 50, 5, 2, 2, act="relu") + + pool_2_shape = 50 * 4 * 4 + SIZE = 10 + scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 + self._fc = FC(self.full_name(), + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") + + def forward(self, inputs): + x = self._simple_img_conv_pool_1(inputs) + x = self._simple_img_conv_pool_2(x) + x = self._fc(x) + return x + + +class TestMnist(TestParallelDyGraphRunnerBase): + def get_model(self): + model = MNIST("mnist") + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=2, drop_last=True) + opt = SGDOptimizer(learning_rate=1e-3) + return model, train_reader, opt + + def run_one_loop(self, model, opt, data): + batch_size = len(data) + dy_x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(batch_size, 1) + img = to_variable(dy_x_data) + label = to_variable(y_data) + label.stop_gradient = True + + cost = model(img) + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + return avg_loss + + +if __name__ == "__main__": + runtime_main(TestMnist) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index a64cfe5cbfcc51721da797db619612fda94f4fc9..6c7054e95efa7eefd574bc9025e23908dd4ac7b1 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -27,6 +27,9 @@ import numpy as np import paddle.fluid as fluid from paddle.fluid import compiler +import paddle.fluid.dygraph as dygraph +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.dygraph.parallel import DataParallel RUN_STEP = 10 DEFAULT_BATCH_SIZE = 2 @@ -187,6 +190,68 @@ class TestDistRunnerBase(object): sys.stdout.buffer.write(pickle.dumps(out_losses)) +class TestParallelDyGraphRunnerBase(object): + def get_model(self): + raise NotImplementedError( + "get_model should be implemented by child classes.") + + def run_one_loop(self, model, opt, data): + raise NotImplementedError( + "train_one_loop should be implemented by the child classes.") + + def run_trainer(self, args): + seed = 90 + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + + def _get_data(batch): + if args.update_method != "local": + new_batch = [] + for offset, item in enumerate(batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch + else: + return batch + + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + model, train_reader, opt = self.get_model() + + nranks = len(args.endpoints.split(",")) if args.endpoints else 1 + if args.update_method == "nccl2": + sys.stderr.write("") + model = dygraph.parallel.DataParallel(model) + strategy = dygraph.parallel.ParallelStrategy() + strategy.nranks = nranks + strategy.local_rank = args.trainer_id + strategy.trainer_endpoints = args.endpoints.split(",") + strategy.current_endpoint = args.current_endpoint + dygraph.parallel.prepare_context(strategy) + out_losses = [] + for step_id, data in enumerate(train_reader()): + data = _get_data(data) + if step_id == RUN_STEP: + break + loss = self.run_one_loop(model, opt, data) + + # FIXME(Yancey1989): scale the loss inplace + loss.stop_gradient = True + loss_scale = to_variable(np.array([nranks]).astype("float32")) + loss = loss / loss_scale + + out_losses.append(loss.numpy()) + loss.backward() + + opt.minimize(loss) + model.clear_gradients() + if six.PY2: + print(pickle.dumps(out_losses)) + else: + sys.stdout.buffer.write(pickle.dumps(out_losses)) + + def runtime_main(test_class): parser = argparse.ArgumentParser(description='Run dist test.') parser.add_argument( @@ -275,6 +340,7 @@ class TestDistBase(unittest.TestCase): self._nccl2_reduce_layer = False self._lr = 0.001 self._use_dgc = False + self._dygraph = False self._setup_config() self._after_setup_config() @@ -597,6 +663,9 @@ class TestDistBase(unittest.TestCase): local_loss = local_losses[step_id] tr0_loss = tr0_losses[step_id] tr1_loss = tr1_losses[step_id] - dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2 + dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) + if not self._dygraph: + # Parallel DyGraph already scaled the loss in training + dist_loss = dist_loss / 2 print("=======", local_loss, ":", dist_loss[0], "=======") self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..a08991986a7ccbfc446d4dcab9a88b926ef6eea8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py @@ -0,0 +1,32 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + + +class TestParallelDygraphMnist(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = True + + def test_mnist(self): + self.check_with_place( + "parallel_dygraph_mnist.py", delta=1e-5, check_error_log=True) + + +if __name__ == "__main__": + unittest.main()