提交 f4456ffe 编写于 作者: S sandyhouse

update code, test=develop

上级 a566c12e
......@@ -31,7 +31,7 @@ class CRecvOp : public framework::OperatorWithKernel {
}
};
class CSendOpMaker : public framework::OpProtoAndCheckerMaker {
class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("Out", "(Tensor) tensor to receive.");
......
......@@ -23,11 +23,11 @@ namespace paddle {
namespace operators {
template <typename T>
class CSendOpCUDAKernel : public framework::OpKernel<T> {
class CRecvOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
auto out = ctx.Output<framework::LoDTensor>("Out");
auto out = ctx.Input<framework::LoDTensor>("Out");
int numel = out->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
......@@ -44,8 +44,9 @@ class CSendOpCUDAKernel : public framework::OpKernel<T> {
}
int peer = ctx.Attr<int>("peer");
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
out->mutable_data<T>(place), numel, dtype, peer, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclRecv(const_cast<T*>(out->data<T>()), numel,
dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv "
<< framework::product(out->dims()) << " from " << peer;
#else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册