Commit aac8303d authored by sandyhouse

update, test=develop

Parent 6c16858f
@@ -138,7 +138,7 @@ void SectionWorker::TrainFiles() {
         }
       }
     }
-  } catch (platform::EOFException&) {
+  } catch (platform::EOFException& e) {
     // std::unique_lock<std::mutex> lk(thread_mutex);
     // threads_completed = true;
     VLOG(3) << "thread completed.";
@@ -146,6 +146,8 @@ void SectionWorker::TrainFiles() {
     // thread_condition.notify_all();
     VLOG(3) << "EOF encountered";
     // throw platform::EOFException();
+    // throw e;
+    PADDLE_THROW_EOF();
     break;
   }
 }
@@ -303,7 +305,7 @@ void SectionWorker::TrainFilesWithProfiler() {
                 << micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
                 << std::endl;
     }
-  } catch (platform::EOFException&) {
+  } catch (platform::EOFException& e) {
     VLOG(3) << "thread completed.";
     VLOG(0) << "EOF encountered";
     VLOG(0) << "============timeline============";
@@ -313,6 +315,7 @@ void SectionWorker::TrainFilesWithProfiler() {
              << ", mean_time: " << op_total_time[i] / op_count[i];
    }
    VLOG(0) << "================================";
+   throw e;
    break;
  }
 }
......
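The two SectionWorker hunks above change the EOF handling from silently breaking out of the loop to binding the caught exception (catch (platform::EOFException& e)) and rethrowing it, via PADDLE_THROW_EOF() in TrainFiles and throw e in the profiler variant, so the caller also observes end-of-data. A minimal, self-contained C++ sketch of this catch-log-rethrow pattern (illustrative names only, not Paddle's classes):

#include <iostream>
#include <stdexcept>

// Stand-in for platform::EOFException (illustrative only).
struct EOFException : std::runtime_error {
  EOFException() : std::runtime_error("end of data") {}
};

// Stand-in for running one micro-batch; throws once the reader is exhausted.
void RunMicroBatch(int step) {
  if (step == 3) throw EOFException();
}

void TrainFiles() {
  try {
    for (int step = 0;; ++step) RunMicroBatch(step);
  } catch (const EOFException&) {
    std::cerr << "EOF encountered\n";
    throw;  // propagate instead of swallowing, mirroring PADDLE_THROW_EOF()
  }
}

int main() {
  try {
    TrainFiles();
  } catch (const EOFException&) {
    std::cerr << "caller observes EOF and can stop the pipeline cleanly\n";
  }
  return 0;
}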
@@ -40,23 +40,24 @@ class CRecvOp : public framework::OperatorWithKernel {
                           "The size of the output shape must be greater than 0 "
                           "but the value given is %d.",
                           out_shape.size()));
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     VLOG(0) << "wow1";
-    std::string dtype = ctx.Attr<std::string>("dtype");
+    int dtype = ctx.Attr<int>("dtype");
     framework::proto::VarType::Type type;
-    if (dtype == "fp32") {
+    if (dtype == framework::proto::VarType::FP32) {
       type = framework::proto::VarType::FP32;
-    } else if (dtype == "fp64") {
+    } else if (dtype == framework::proto::VarType::FP64) {
       type = framework::proto::VarType::FP64;
-    } else if (dtype == "fp16") {
+    } else if (dtype == framework::proto::VarType::FP16) {
       type = framework::proto::VarType::FP16;
-    } else if (dtype == "int32") {
+    } else if (dtype == framework::proto::VarType::INT32) {
       type = framework::proto::VarType::INT32;
-    } else if (dtype == "int64") {
+    } else if (dtype == framework::proto::VarType::INT64) {
       type = framework::proto::VarType::INT64;
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
@@ -75,9 +76,9 @@ class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
         .SetDefault(0);
     AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
-    AddAttr<std::string>("dtype",
-                         "(std::string default fp32) data type of tensor.")
-        .SetDefault("fp32");
+    AddAttr<int>("dtype",
+                 "(std::string default 5(float32)) data type of tensor.")
+        .SetDefault(5);
     AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
         .SetDefault(std::vector<int>());
     AddAttr<bool>(
......
@@ -25,37 +25,72 @@ namespace operators {
 template <typename T>
 class CRecvOpCUDAKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
 #if defined(PADDLE_WITH_NCCL)
     VLOG(0) << "here1";
     auto out = ctx.Output<framework::LoDTensor>("Out");
     VLOG(0) << "here2";
-    auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
-    auto out_dims = paddle::framework::make_ddim(out_shape);
+    // auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
+    // auto out_dims = paddle::framework::make_ddim(out_shape);
+    int data_type = ctx.Attr<int>("dtype");
+    framework::proto::VarType::Type type =
+        framework::proto::VarType::Type(data_type);
+    // if (data_type == framework::proto::VarType::FP32) {
+    //   type = framework::proto::VarType::FP32;
+    //} else if (data_type == framework::proto::VarType::FP64) {
+    //   type = framework::proto::VarType::FP64;
+    //} else if (data_type == framework::proto::VarType::FP16) {
+    //   type = framework::proto::VarType::FP16;
+    //} else if (data_type == framework::proto::VarType::INT32) {
+    //   type = framework::proto::VarType::INT32;
+    //} else if (data_type == framework::proto::VarType::INT64) {
+    //   type = framework::proto::VarType::INT64;
+    //} else {
+    //   PADDLE_THROW(platform::errors::InvalidArgument(
+    //       "Unknown data type %s for c_recv op.", data_type));
+    //}
+    ncclDataType_t dtype = platform::ToNCCLDataType(type);
+    auto out_dims = out->dims();
+    int numel = 0;
+    int *numel_ptr = nullptr;
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
     int rid = ctx.Attr<int>("ring_id");
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    out->mutable_data<T>(out_dims, place);
-    VLOG(0) << "out_dims:" << out_dims;
-    ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
-    int numel = out->numel();
-    VLOG(0) << "numel:" << numel;
+    int peer = ctx.Attr<int>("peer");
+    PADDLE_ENFORCE_LT(
+        peer, comm->nranks(),
+        platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
+                                          peer, comm->nranks()));
     cudaStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+      stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
     } else {
       stream = comm->stream();
     }
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost));
+    VLOG(0) << "numel:" << numel;
+    VLOG(0) << "out_dims:" << out_dims;
+    int rest_numel = 1;
+    for (size_t i = 1; i < out_dims.size(); ++i) {
+      rest_numel = rest_numel * out_dims[i];
+    }
+    out_dims[0] = numel / rest_numel;
+    VLOG(0) << "out_dims:" << out_dims;
+    out->mutable_data<T>(out_dims, place);
+    // ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
+    // numel = out->numel();
+    // VLOG(0) << "numel:" << numel;
-    int peer = ctx.Attr<int>("peer");
-    PADDLE_ENFORCE_LT(
-        peer, comm->nranks(),
-        platform::errors::InvalidArgument("The value of peer (%d) you set must "
-                                          "be less than comm->nranks (%d).",
-                                          peer, comm->nranks()));
     VLOG(0) << "here3";
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
         out->data<T>(), numel, dtype, peer, comm->comm(), stream));
......
@@ -49,9 +49,22 @@ class CSendOpCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
+    int* numel_ptr = nullptr;
+    VLOG(0) << "numel: " << numel;
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
+    // PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
+    VLOG(0) << "wawa1";
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+        numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
+    VLOG(0) << "wawa2";
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
         x->data<T>(), numel, dtype, peer, comm->comm(), stream));
+    VLOG(0) << "wawa3";
+    // PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
+    VLOG(0) << "wawa4";
     VLOG(3) << "rank " << comm->rank() << " send "
             << framework::product(x->dims()) << " to " << peer;
 #else
......
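Read together, the c_send and c_recv kernel hunks above add a count-then-payload handshake: the sender first transmits the element count as a single int, and the receiver uses it to recompute the leading dimension of the output before receiving the tensor data. A rough sketch of that protocol using raw NCCL point-to-point calls (an illustration under stated assumptions, not Paddle code: the platform::dynload / PADDLE_ENFORCE_CUDA_SUCCESS wrappers and error handling are omitted, an initialized ncclComm_t is assumed, and the receive buffer is assumed pre-allocated with enough capacity):

#include <cuda_runtime.h>
#include <nccl.h>

// Sender: ship the element count first, then the payload.
void SendWithNumel(const float* data, int numel, int peer, ncclComm_t comm,
                   cudaStream_t stream) {
  int* numel_dev = nullptr;
  cudaMalloc(&numel_dev, sizeof(int));
  cudaMemcpy(numel_dev, &numel, sizeof(int), cudaMemcpyHostToDevice);
  ncclSend(numel_dev, 1, ncclInt, peer, comm, stream);   // 1) element count
  ncclSend(data, numel, ncclFloat, peer, comm, stream);  // 2) tensor payload
  cudaStreamSynchronize(stream);
  cudaFree(numel_dev);
}

// Receiver: read the count, derive the leading dimension, then read the data.
// rest_numel is the product of the trailing dimensions (known from out_shape).
int RecvWithNumel(float* out, int rest_numel, int peer, ncclComm_t comm,
                  cudaStream_t stream) {
  int numel = 0;
  int* numel_dev = nullptr;
  cudaMalloc(&numel_dev, sizeof(int));
  ncclRecv(numel_dev, 1, ncclInt, peer, comm, stream);   // 1) element count
  cudaStreamSynchronize(stream);
  cudaMemcpy(&numel, numel_dev, sizeof(int), cudaMemcpyDeviceToHost);
  int dim0 = numel / rest_numel;  // recomputed batch dimension
  (void)dim0;                     // a real kernel would resize the output here
  ncclRecv(out, numel, ncclFloat, peer, comm, stream);   // 2) tensor payload
  cudaStreamSynchronize(stream);
  cudaFree(numel_dev);
  return numel;
}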
@@ -3983,6 +3983,7 @@ class PipelineOptimizer(object):
             outputs={'Out': [new_var]},
             attrs={
                 'out_shape': new_var.shape,
+                'dtype': new_var.dtype,
                 self._op_device_key: device,
                 self._op_role_key: self._op_role.Forward,
                 'peer': first_dev_index
@@ -4137,7 +4138,7 @@ class PipelineOptimizer(object):
             attrs={
                 self._op_device_key: prev_device_spec,
                 self._op_role_key: op_role,
-                'peer': prev_device_index
+                'peer': cur_device_index
             })
         extra_index += 1
         block._insert_op(
@@ -4146,9 +4147,10 @@ class PipelineOptimizer(object):
             outputs={'Out': [var]},
             attrs={
                 'out_shape': var.shape,
+                'dtype': var.dtype,
                 self._op_device_key: cur_device_spec,
                 self._op_role_key: op_role,
-                'peer': cur_device_index
+                'peer': prev_device_index
             })
         extra_index += 1
@@ -4324,6 +4326,7 @@ class PipelineOptimizer(object):
             outputs={'Out': [read_block.var(var_name)]},
             attrs={
                 'out_shape': read_block.var(var_name).shape,
+                'dtype': read_block.var(var_name).dtype,
                 self._op_device_key: read_device,
                 # A trick to make the role LRSched to avoid copy every
                 # microbatch
......