提交 0d57ca46 编写于 作者: Y Yang Yang

nccl pass parallel_do test

上级 0815c0f1
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/framework.pb.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
......@@ -49,6 +50,22 @@ class NCCLInitOp : public framework::OperatorBase {
}
};
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Communicator").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarDesc::NCCL_COM;
out_var.SetType(var_type);
}
};
class NCCLInitOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
......@@ -214,7 +231,9 @@ Bcast the tensors.
namespace ops = paddle::operators;
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker,
ops::NCCLInitOpVarTypeInference,
ops::NCCLInitOpShapeInference);
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker);
......
......@@ -47,8 +47,11 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>("Out");
LOG(INFO) << "------------------";
std::string reduction = ctx.Attr<std::string>("reduction");
LOG(INFO) << "------------------";
ncclRedOp_t reduction_op_ = ncclSum;
LOG(INFO) << "------------------";
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
......@@ -62,14 +65,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
LOG(INFO) << "------------------";
auto* comm = ctx.Input<Communicator>("Communicator");
LOG(INFO) << "------------------";
auto stream = ctx.cuda_device_context().stream();
LOG(INFO) << "------------------";
// device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
LOG(INFO) << "------------------";
int idx = comm->GetCommId(gpu_id);
LOG(INFO) << "------------------";
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << "gpu : "
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
......
......@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs";
static constexpr char kParallelScopes[] = "parallel_scopes";
static constexpr char kParallelBlock[] = "sub_block";
static constexpr char kUseNCCL[] = "use_nccl";
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
......@@ -159,6 +160,7 @@ class ParallelDoOp : public framework::OperatorBase {
}
WaitOnPlaces(places);
// PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size());
std::vector<std::future<void>> workers;
workers.reserve(places.size());
for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) {
......@@ -202,6 +204,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput(kOutputs, "").AsDuplicable();
AddOutput(kParallelScopes, "");
AddAttr<framework::BlockDesc *>(kParallelBlock, "");
AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
.SetDefault(false);
AddComment(R"DOC(
ParallelDo Operator.
)DOC");
......@@ -223,20 +227,22 @@ class ParallelDoGradOp : public framework::OperatorBase {
auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
// PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size());
// feed output@grad
SplitTensorAndMoveTensorToScopes(
scope, const_cast<std::vector<framework::Scope *> *>(&sub_scopes),
places, Inputs(framework::GradVarName(kOutputs)));
WaitOnPlaces(places);
LOG(INFO) << "places " << places.size();
// exe run
std::vector<std::future<void>> workers;
for (size_t i = 0; i < sub_scopes.size(); ++i) {
auto &place = places[i];
auto *cur_scope = sub_scopes[i];
LOG(INFO) << place;
// execute
workers.emplace_back(framework::Async([program, cur_scope, place, block] {
......@@ -245,12 +251,26 @@ class ParallelDoGradOp : public framework::OperatorBase {
false /*create_local_scope*/);
}));
}
LOG(INFO) << "places " << places.size();
for (auto &worker : workers) {
worker.wait();
}
WaitOnPlaces(places);
AccumulateGrad(scope, place, sub_scopes, places);
// NCCL allreduce op will be added by backward,
// so no need to explicitly accumulate grad
if (!(Attr<bool>(kUseNCCL))) {
AccumulateGrad(scope, place, sub_scopes, places);
} else {
for (auto &place : places) {
PADDLE_ENFORCE(platform::is_gpu_place(place),
"NCCL only supports cuda place");
}
}
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
}
WaitOnPlaces(places);
}
void AccumulateGrad(const framework::Scope &scope,
......
......@@ -218,7 +218,7 @@ def _callback_lookup_(op):
:param op:
:return: callback function
"""
if op.type == 'parallel_do':
if op.type == 'parallel_do' and op.attr('use_nccl'):
param_names = set(op.input('parameters'))
param_grad_names = [n + "@GRAD" for n in param_names]
......@@ -229,18 +229,25 @@ def _callback_lookup_(op):
def __call__(self, block, context):
if not self.has_inserted_nccl_init:
global_block = block.program.global_block()
op_desc = global_block.desc.append_op()
var_desc = global_block.desc.var('nccl_com')
var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
self.nccl_com = global_block.create_var(
name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
framework.Operator(
global_block,
type='ncclInit',
desc=op_desc,
inputs={},
outputs={'Communicator': [self.nccl_com]})
# global_block = block.program.global_block()
# op_desc = global_block.desc.append_op()
# var_desc = global_block.desc.var('nccl_com__do_not_change_')
# var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
# self.nccl_com = global_block.create_var(
# name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
# framework.Operator(
# global_block,
# type='ncclInit',
# desc=op_desc,
# inputs={},
# outputs={'Communicator': [self.nccl_com]})
op_desc = _create_op_desc_(
"ncclInit", {},
{"Communicator": ['nccl_com__do_not_change_']}, {})
# block.desc.append_op().copy_from(op_desc)
print(serialize_op_decs(op_desc))
block.program.global_block().desc.append_op().copy_from(
op_desc)
self.has_inserted_nccl_init = True
current_op_desc = context["__current_op_desc__"]
......@@ -263,7 +270,8 @@ def _callback_lookup_(op):
op_desc = _create_op_desc_(
"ncclAllReduce", {
"X": [o_argu],
"Communicator": ['nccl_com_0']
"Communicator":
['nccl_com__do_not_change_']
}, {"Out": [allreduce_out_name]},
{"reduction": "ncclSum"})
block.desc.append_op().copy_from(op_desc)
......@@ -375,10 +383,11 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
continue
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
# infer_shape and infer_type
if op_desc.type() == 'ncclInit':
continue
op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc)
# ncclInit dones't need to set data_type
if op_desc.type() == 'ncclInit':
continue
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_(arg, block)
......
......@@ -237,12 +237,13 @@ class ParallelDo(object):
ParallelDo class is used to create a ParallelDo.
"""
def __init__(self, places, name=None):
def __init__(self, places, use_nccl=False, name=None):
self.helper = LayerHelper("parallel_do", name=name)
self.inputs = []
self.places = places
self.outputs = []
self.status = StaticRNN.BEFORE_RNN_BLOCK
self.use_nccl = use_nccl
def do(self):
return BlockGuardWithCompletion(self)
......@@ -325,7 +326,8 @@ class ParallelDo(object):
},
outputs={'outputs': outputs,
'parallel_scopes': [step_scope]},
attrs={'sub_block': current_block})
attrs={'sub_block': current_block,
'use_nccl': self.use_nccl})
class BlockGuardWithCompletion(BlockGuard):
......
......@@ -67,12 +67,25 @@ class BaseParallelForTest(unittest.TestCase):
fetch=fetch,
place=gpu,
use_parallel=True)
result_gpu_nccl = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=gpu,
use_parallel=True,
use_nccl=True)
self._assert_same_(fetch, result_cpu, result_cpu_parallel,
result_gpu, result_gpu_parallel)
result_gpu, result_gpu_parallel, result_gpu_nccl)
else:
self._assert_same_(fetch, result_cpu, result_cpu_parallel)
def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False):
def _run_test_impl_(self,
callback,
feed,
fetch,
place,
use_parallel=False,
use_nccl=False):
"""
Run a single test, returns the fetch values
Args:
......@@ -96,7 +109,7 @@ class BaseParallelForTest(unittest.TestCase):
# Automatically insert parallel do if use_parallel = True
if use_parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
data = next(generator)
if isinstance(data, fluid.Variable):
......@@ -137,7 +150,9 @@ class BaseParallelForTest(unittest.TestCase):
"""
def _impl_(a, b, fetch_id, item_id):
item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU']
item_str = [
'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL'
]
flag = numpy.allclose(a, b, rtol=0.1)
self.assertTrue(flag, "The {0} are different in {1}".format(
fetch[fetch_id], item_str[item_id]))
......@@ -157,18 +172,10 @@ class ParallelOpTest(BaseParallelForTest):
loss = fluid.layers.mean(x=hidden)
yield loss
def test_simple_fc(self):
self.run_test(
callback=self.__network__,
feed={
'img': numpy.random.random(size=(51, 784)).astype('float32')
},
fetch=['fc1.w@GRAD'])
def test_fc_with_tiny_data(self):
self.run_test(
callback=self.__network__,
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
feed={'img': numpy.random.random(size=(8, 784)).astype('float32')},
fetch=['fc1.w@GRAD'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册