未验证 提交 0b07eef1 编写于 作者: Y Yan Xu 提交者: GitHub

ParallelDyGraph with GPU collective mode (#16827)

implement dygraph.parallel.DataParallel to hook reduce op.
上级 1a4a51db
......@@ -336,11 +336,15 @@ void OpBase::InvokeBackwardHooks() {
}
}
void OpBase::RegisterBackwardHooks(const py::object& callable) {
void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) {
VLOG(3) << "Register backward hooks " << trace_id_;
// TODO(minqiyang): check the callable format
backward_hooks_.push_back(callable);
if (front) {
backward_hooks_.insert(backward_hooks_.begin(), callable);
} else {
backward_hooks_.push_back(callable);
}
}
void VarBase::RunBackward() {
......
......@@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase {
return grad_op_descs_[index]->Type();
}
void RegisterBackwardHooks(const py::object& callable);
void RegisterBackwardHooks(const py::object& callable, bool front = false);
void InvokeBackwardHooks();
......
......@@ -15,91 +15,22 @@ limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
namespace paddle {
namespace operators {
struct MutableDataFunctor {
MutableDataFunctor(void** data, framework::LoDTensor* tensor,
const platform::Place& place)
: data_(data), tensor_(tensor), place_(place) {}
template <typename T>
void apply() {
*data_ = tensor_->mutable_data<T>(place_);
}
class AllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void** data_;
framework::LoDTensor* tensor_;
platform::Place place_;
};
void InferShape(framework::InferShapeContext* ctx) const override {}
class AllReduceOp : public framework::OperatorBase {
using OperatorBase::OperatorBase;
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"AllReduce op can run on gpu place only for now.");
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* ctx = pool.Get(place);
auto in_names = Inputs("X");
auto out_names = Outputs("Out");
PADDLE_ENFORCE_EQ(in_names.size(), 1, "Only support one input");
PADDLE_ENFORCE_EQ(out_names.size(), 1, "Only support one output");
auto* in = scope.FindVar(in_names[0]);
auto* out = scope.FindVar(out_names[0]);
PADDLE_ENFORCE(in->IsType<framework::LoDTensor>() ||
out->IsType<framework::LoDTensor>(),
"Only support allreduce LoDTensors");
int dtype = -1;
auto in_tensor = in->Get<framework::LoDTensor>();
dtype = platform::ToNCCLDataType(in_tensor.type());
int64_t numel = in_tensor.numel();
auto* sendbuff = in_tensor.data<void>();
auto* out_tensor = out->GetMutable<framework::LoDTensor>();
out_tensor->Resize(in_tensor.dims());
void* recvbuff = nullptr;
framework::VisitDataType(in_tensor.type(),
MutableDataFunctor(&recvbuff, out_tensor, place));
auto cuda_ctx = static_cast<platform::CUDADeviceContext*>(ctx);
auto* comm = cuda_ctx->nccl_comm();
// FIXME(typhoonzero): should use nccl stream here.
auto stream = cuda_ctx->stream();
int reduce_type = Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream));
#endif
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -110,6 +41,10 @@ class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(Tensor) the result of allreduced.");
AddAttr<int>("reduce_type", "(int) determin the reduce type.")
.SetDefault(0);
AddAttr<bool>(
"sync_mode",
"(bool) whether to synchronize the CUDA stream after nccl call.")
.SetDefault(false);
AddComment(R"DOC(
***AllReduce Operator***
......@@ -128,16 +63,18 @@ If input and output are the same variable, in-place allreduce will be used.
}
};
class AllReduceOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(allreduce, ops::AllReduceOp,
ops::AllReduceOpMaker);
REGISTER_OPERATOR(allreduce, ops::AllReduceOp,
paddle::framework::EmptyGradOpMaker, ops::AllReduceOpMaker,
ops::AllReduceOpShapeInference);
REGISTER_OP_CPU_KERNEL(
allreduce, ops::AllReduceOpKernel<plat::CPUDeviceContext, float>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, double>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, int>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, int64_t>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
allreduce, ops::AllReduceOpKernel<plat::CUDADeviceContext, float>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, double>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, int>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"AllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
auto* sendbuff = in->data<void>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
auto* comm = dev_ctx.nccl_comm();
// FIXME(typhoonzero): should use nccl stream here.
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly.");
int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
VLOG(0) << "call allreduce with type: " << reduce_type;
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream));
if (ctx.Attr<bool>("sync_mode")) {
VLOG(0) << "sync allreduce...";
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
}
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
......@@ -236,9 +236,11 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<const std::string &>())
.def("register_backward_hooks",
[](imperative::OpBase &self, const py::object &callable) {
self.RegisterBackwardHooks(callable);
})
[](imperative::OpBase &self, const py::object &callable,
bool front = false) {
self.RegisterBackwardHooks(callable, front);
},
py::arg("callable"), py::arg("front") = false)
.def_property("_trace_id",
[](const imperative::OpBase &self) {
pybind11::gil_scoped_release release;
......
......@@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
from .. import core
from . import layers
from .. import framework
from ..layers import collective
__all__ = ["prepare_context"]
......@@ -21,9 +27,13 @@ ParallelStrategy = core.ParallelStrategy
__parallel_ctx__clz__ = None
def prepare_context(parallel_strategy, place):
def prepare_context(parallel_strategy):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once."
assert framework.in_dygraph_mode(
) is True, "dygraph.parallel.prepare_context should be used with dygrahp mode."
place = framework._current_expected_place()
assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
if isinstance(place, core.CUDAPlace):
__parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy,
......@@ -58,3 +68,38 @@ class Env(object):
@property
def current_endpoint(self):
return self._current_endpoint
@property
def trainer_endpoints(self):
return self._trainer_endpoints
class DataParallel(layers.Layer):
def __init__(self, layers):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")
self._layers = layers
def build_once(self, *inputs, **kwargs):
#TODO(Yancey1989): broadcast all the paramters
pass
def forward(self, *inputs, **kwargs):
def _collective_hook(iop):
op = framework._dygraph_tracer()._ops[iop._trace_id]
for k, v in six.iteritems(op.inputs):
for ivar in v:
g = ivar._grad_ivar()
if g:
g_var = framework.Variable(
block=self._helper.main_program.current_block(),
name=ivar._grad_name(),
stop_gradient=True,
ivar=g)
collective._allreduce(g_var, g_var, sync_mode=True)
outs = self._layers(*inputs, **kwargs)
for _, op in six.iteritems(framework._dygraph_tracer()._ops):
# hook collective ops
op.iop.register_backward_hooks(_collective_hook, front=True)
return outs
......@@ -16,7 +16,7 @@ from __future__ import print_function
from ..layer_helper import LayerHelper, unique_name
def _allreduce(x, out=None, reduce_type="sum"):
def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
helper = LayerHelper("allreduce", **locals())
# Convert string reduce type to op int type
red_typ_int = 0
......@@ -43,5 +43,6 @@ def _allreduce(x, out=None, reduce_type="sum"):
type='allreduce',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={"reduce_type": red_typ_int})
attrs={"reduce_type": red_typ_int,
"sync_mode": sync_mode})
return out
......@@ -19,6 +19,7 @@ endif(NOT WITH_DISTRIBUTE)
if (NOT ${WITH_GPU})
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
elseif(${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
endif()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import contextlib
import unittest
import numpy as np
import six
import pickle
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
param_attr=None,
bias_attr=None,
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
return x
class TestMnist(TestParallelDyGraphRunnerBase):
def get_model(self):
model = MNIST("mnist")
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=2, drop_last=True)
opt = SGDOptimizer(learning_rate=1e-3)
return model, train_reader, opt
def run_one_loop(self, model, opt, data):
batch_size = len(data)
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
cost = model(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
if __name__ == "__main__":
runtime_main(TestMnist)
......@@ -27,6 +27,9 @@ import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.dygraph as dygraph
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import DataParallel
RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2
......@@ -187,6 +190,68 @@ class TestDistRunnerBase(object):
sys.stdout.buffer.write(pickle.dumps(out_losses))
class TestParallelDyGraphRunnerBase(object):
def get_model(self):
raise NotImplementedError(
"get_model should be implemented by child classes.")
def run_one_loop(self, model, opt, data):
raise NotImplementedError(
"train_one_loop should be implemented by the child classes.")
def run_trainer(self, args):
seed = 90
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
def _get_data(batch):
if args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
if offset % 2 == args.trainer_id:
new_batch.append(item)
return new_batch
else:
return batch
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
model, train_reader, opt = self.get_model()
nranks = len(args.endpoints.split(",")) if args.endpoints else 1
if args.update_method == "nccl2":
sys.stderr.write("")
model = dygraph.parallel.DataParallel(model)
strategy = dygraph.parallel.ParallelStrategy()
strategy.nranks = nranks
strategy.local_rank = args.trainer_id
strategy.trainer_endpoints = args.endpoints.split(",")
strategy.current_endpoint = args.current_endpoint
dygraph.parallel.prepare_context(strategy)
out_losses = []
for step_id, data in enumerate(train_reader()):
data = _get_data(data)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
# FIXME(Yancey1989): scale the loss inplace
loss.stop_gradient = True
loss_scale = to_variable(np.array([nranks]).astype("float32"))
loss = loss / loss_scale
out_losses.append(loss.numpy())
loss.backward()
opt.minimize(loss)
model.clear_gradients()
if six.PY2:
print(pickle.dumps(out_losses))
else:
sys.stdout.buffer.write(pickle.dumps(out_losses))
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run dist test.')
parser.add_argument(
......@@ -275,6 +340,7 @@ class TestDistBase(unittest.TestCase):
self._nccl2_reduce_layer = False
self._lr = 0.001
self._use_dgc = False
self._dygraph = False
self._setup_config()
self._after_setup_config()
......@@ -597,6 +663,9 @@ class TestDistBase(unittest.TestCase):
local_loss = local_losses[step_id]
tr0_loss = tr0_losses[step_id]
tr1_loss = tr1_losses[step_id]
dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss]))
if not self._dygraph:
# Parallel DyGraph already scaled the loss in training
dist_loss = dist_loss / 2
print("=======", local_loss, ":", dist_loss[0], "=======")
self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
class TestParallelDygraphMnist(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_mnist(self):
self.check_with_place(
"parallel_dygraph_mnist.py", delta=1e-5, check_error_log=True)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册