未验证 提交 b2f7ab66 编写于 作者: L lilong12 提交者: GitHub

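Bug fix: move the `#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703` guard to the start of `Compute` in RecvOpV2CUDAKernel and SendOpV2CUDAKernel, test=develop (#28648)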

上级 8f2656ef
......@@ -26,6 +26,7 @@ template <typename T>
class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
int rid = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
rid, 0,
......@@ -44,7 +45,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
cudaStream_t stream = nullptr;
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
......
......@@ -26,6 +26,7 @@ template <typename T>
class SendOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
int numel = x->numel();
......@@ -42,7 +43,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
"The peer (%d) for send_v2 op must be non-negative.", peer));
cudaStream_t stream = nullptr;
auto place = ctx.GetPlace();
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册