Commit aac8303d authored by: S sandyhouse

update, test=develop

Parent 6c16858f
@@ -138,7 +138,7 @@ void SectionWorker::TrainFiles() {
}
}
}
} catch (platform::EOFException&) {
} catch (platform::EOFException& e) {
// std::unique_lock<std::mutex> lk(thread_mutex);
// threads_completed = true;
VLOG(3) << "thread completed.";
@@ -146,6 +146,8 @@ void SectionWorker::TrainFiles() {
// thread_condition.notify_all();
VLOG(3) << "EOF encountered";
// throw platform::EOFException();
// throw e;
PADDLE_THROW_EOF();
break;
}
}
@@ -303,7 +305,7 @@ void SectionWorker::TrainFilesWithProfiler() {
<< micro_end.tv_sec * 1e6 + micro_end.tv_usec << "]"
<< std::endl;
}
} catch (platform::EOFException&) {
} catch (platform::EOFException& e) {
VLOG(3) << "thread completed.";
VLOG(0) << "EOF encountered";
VLOG(0) << "============timeline============";
@@ -313,6 +315,7 @@ void SectionWorker::TrainFilesWithProfiler() {
<< ", mean_time: " << op_total_time[i] / op_count[i];
}
VLOG(0) << "================================";
throw e;
break;
}
}
......
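The two hunks above stop swallowing the end-of-data exception: the section worker now re-raises it (via PADDLE_THROW_EOF / throw e) instead of only breaking out of its loop, so the caller can tell that the reader is exhausted. A minimal, self-contained C++ sketch of that catch-and-rethrow pattern (not Paddle code; names are illustrative):

#include <iostream>
#include <stdexcept>

// Stand-in for platform::EOFException.
struct EOFException : std::runtime_error {
  EOFException() : std::runtime_error("EOF") {}
};

void RunOneMicroBatch(int step) {
  if (step == 3) throw EOFException();  // data source exhausted
  std::cout << "ran micro-batch " << step << "\n";
}

int main() {
  try {
    for (int step = 0;; ++step) {
      try {
        RunOneMicroBatch(step);
      } catch (EOFException&) {
        std::cout << "EOF encountered, re-raising\n";
        throw;  // propagate to the caller instead of just breaking the loop
      }
    }
  } catch (EOFException&) {
    std::cout << "caller observed EOF and can finish the epoch\n";
  }
  return 0;
}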
@@ -40,23 +40,24 @@ class CRecvOp : public framework::OperatorWithKernel {
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
VLOG(0) << "wow1";
std::string dtype = ctx.Attr<std::string>("dtype");
int dtype = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type;
if (dtype == "fp32") {
if (dtype == framework::proto::VarType::FP32) {
type = framework::proto::VarType::FP32;
} else if (dtype == "fp64") {
} else if (dtype == framework::proto::VarType::FP64) {
type = framework::proto::VarType::FP64;
} else if (dtype == "fp16") {
} else if (dtype == framework::proto::VarType::FP16) {
type = framework::proto::VarType::FP16;
} else if (dtype == "int32") {
} else if (dtype == framework::proto::VarType::INT32) {
type = framework::proto::VarType::INT32;
} else if (dtype == "int64") {
} else if (dtype == framework::proto::VarType::INT64) {
type = framework::proto::VarType::INT64;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
@@ -75,9 +76,9 @@ class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
AddAttr<std::string>("dtype",
"(std::string default fp32) data type of tensor.")
.SetDefault("fp32");
AddAttr<int>("dtype",
"(std::string default 5(float32)) data type of tensor.")
.SetDefault(5);
AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
.SetDefault(std::vector<int>());
AddAttr<bool>(
......
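With this hunk the c_recv op's "dtype" attribute becomes an integer carrying framework::proto::VarType values (default 5, i.e. FP32) instead of a string, so the kernel-type logic compares enum ids rather than strings. A small sketch of that mapping, under the assumption that the usual VarType ids apply (FP32 = 5 matches the new SetDefault(5); the other ids are mirrored here only for illustration):

#include <cstdio>
#include <stdexcept>

// Mirrors the framework::proto::VarType ids referenced in the hunk above
// (assumed values; FP32 = 5 is confirmed by the new default).
enum VarType { INT32 = 2, INT64 = 3, FP16 = 4, FP32 = 5, FP64 = 6 };

VarType DtypeAttrToVarType(int dtype_attr) {
  switch (dtype_attr) {
    case FP32:
    case FP64:
    case FP16:
    case INT32:
    case INT64:
      return static_cast<VarType>(dtype_attr);
    default:
      throw std::invalid_argument("unknown dtype id for c_recv op");
  }
}

int main() {
  std::printf("default attr 5 maps to FP32: %s\n",
              DtypeAttrToVarType(5) == FP32 ? "yes" : "no");
  return 0;
}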
@@ -25,37 +25,72 @@ namespace operators {
template <typename T>
class CRecvOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
#if defined(PADDLE_WITH_NCCL)
VLOG(0) << "here1";
auto out = ctx.Output<framework::LoDTensor>("Out");
VLOG(0) << "here2";
auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
auto out_dims = paddle::framework::make_ddim(out_shape);
// auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
// auto out_dims = paddle::framework::make_ddim(out_shape);
int data_type = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);
// if (data_type == framework::proto::VarType::FP32) {
// type = framework::proto::VarType::FP32;
//} else if (data_type == framework::proto::VarType::FP64) {
// type = framework::proto::VarType::FP64;
//} else if (data_type == framework::proto::VarType::FP16) {
// type = framework::proto::VarType::FP16;
//} else if (data_type == framework::proto::VarType::INT32) {
// type = framework::proto::VarType::INT32;
//} else if (data_type == framework::proto::VarType::INT64) {
// type = framework::proto::VarType::INT64;
//} else {
// PADDLE_THROW(platform::errors::InvalidArgument(
// "Unknown data type %s for c_recv op.", data_type));
//}
ncclDataType_t dtype = platform::ToNCCLDataType(type);
auto out_dims = out->dims();
int numel = 0;
int *numel_ptr = nullptr;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
out->mutable_data<T>(out_dims, place);
VLOG(0) << "out_dims:" << out_dims;
ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
int numel = out->numel();
VLOG(0) << "numel:" << numel;
int peer = ctx.Attr<int>("peer");
PADDLE_ENFORCE_LT(
peer, comm->nranks(),
platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer, comm->nranks()));
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
stream = static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost));
VLOG(0) << "numel:" << numel;
VLOG(0) << "out_dims:" << out_dims;
int rest_numel = 1;
for (size_t i = 1; i < out_dims.size(); ++i) {
rest_numel = rest_numel * out_dims[i];
}
out_dims[0] = numel / rest_numel;
VLOG(0) << "out_dims:" << out_dims;
out->mutable_data<T>(out_dims, place);
// ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
// numel = out->numel();
// VLOG(0) << "numel:" << numel;
int peer = ctx.Attr<int>("peer");
PADDLE_ENFORCE_LT(
peer, comm->nranks(),
platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer, comm->nranks()));
VLOG(0) << "here3";
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
......
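The receive kernel above now follows a two-message protocol: it first receives the element count as a single int, derives the dynamic first dimension from it (numel / product of the remaining dims), resizes the output, and only then receives the payload. A hedged sketch of that flow with plain CUDA/NCCL host calls (illustrative names, not Paddle's API; it assumes an already initialized ncclComm_t and an NCCL version with point-to-point send/recv):

#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// out_dims starts as the attribute shape; dims after the first are trusted,
// the first dimension is recomputed from the received numel.
float* RecvDynamicShapedTensor(std::vector<int>* out_dims, int peer,
                               ncclComm_t comm, cudaStream_t stream) {
  // 1) Receive the element count into device memory (NCCL needs GPU buffers).
  int* numel_dev = nullptr;
  cudaMalloc(&numel_dev, sizeof(int));
  ncclRecv(numel_dev, 1, ncclInt, peer, comm, stream);
  cudaStreamSynchronize(stream);
  int numel = 0;
  cudaMemcpy(&numel, numel_dev, sizeof(int), cudaMemcpyDeviceToHost);
  cudaFree(numel_dev);

  // 2) Infer the dynamic first dimension, as the kernel does with
  //    out_dims[0] = numel / rest_numel.
  int rest_numel = 1;
  for (size_t i = 1; i < out_dims->size(); ++i) rest_numel *= (*out_dims)[i];
  (*out_dims)[0] = numel / rest_numel;

  // 3) Allocate the output now that the real size is known, then receive it.
  float* out = nullptr;
  cudaMalloc(&out, numel * sizeof(float));
  ncclRecv(out, numel, ncclFloat, peer, comm, stream);
  return out;
}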
@@ -49,9 +49,22 @@ class CSendOpCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer, comm->nranks()));
int* numel_ptr = nullptr;
VLOG(0) << "numel: " << numel;
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
VLOG(0) << "wawa1";
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
VLOG(0) << "wawa2";
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
x->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(0) << "wawa3";
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
VLOG(0) << "wawa4";
VLOG(3) << "rank " << comm->rank() << " send "
<< framework::product(x->dims()) << " to " << peer;
#else
......
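The matching send side stages the scalar element count in device memory (NCCL only transfers GPU buffers) and sends it ahead of the tensor data, so the receiver above can size its output before the payload arrives. A hedged sketch under the same assumptions (initialized ncclComm_t, illustrative names):

#include <cuda_runtime.h>
#include <nccl.h>

void SendDynamicShapedTensor(const float* x_dev, int numel, int peer,
                             ncclComm_t comm, cudaStream_t stream) {
  // 1) Stage numel on the device and send it first.
  int* numel_dev = nullptr;
  cudaMalloc(&numel_dev, sizeof(int));
  cudaMemcpy(numel_dev, &numel, sizeof(int), cudaMemcpyHostToDevice);
  ncclSend(numel_dev, 1, ncclInt, peer, comm, stream);

  // 2) Send the payload; the receiver uses numel to infer its first dim.
  ncclSend(x_dev, numel, ncclFloat, peer, comm, stream);

  // Wait for both sends to be consumed before freeing the staging buffer.
  cudaStreamSynchronize(stream);
  cudaFree(numel_dev);
}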
@@ -3983,6 +3983,7 @@ class PipelineOptimizer(object):
outputs={'Out': [new_var]},
attrs={
'out_shape': new_var.shape,
'dtype': new_var.dtype,
self._op_device_key: device,
self._op_role_key: self._op_role.Forward,
'peer': first_dev_index
@@ -4137,7 +4138,7 @@ class PipelineOptimizer(object):
attrs={
self._op_device_key: prev_device_spec,
self._op_role_key: op_role,
'peer': prev_device_index
'peer': cur_device_index
})
extra_index += 1
block._insert_op(
@@ -4146,9 +4147,10 @@ class PipelineOptimizer(object):
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device_spec,
self._op_role_key: op_role,
'peer': cur_device_index
'peer': prev_device_index
})
extra_index += 1
@@ -4324,6 +4326,7 @@ class PipelineOptimizer(object):
outputs={'Out': [read_block.var(var_name)]},
attrs={
'out_shape': read_block.var(var_name).shape,
'dtype': read_block.var(var_name).dtype,
self._op_device_key: read_device,
# A trick to make the role LRSched to avoid copy every
# microbatch
......