提交 27f245cd 编写于 作者: S sandyhouse

fix alltoall ut, test=develop

上级 47f51e07
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_alltoall_op.h"
namespace paddle {
namespace operators {
class CAllToAllOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CAllToAll");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CAllToAll");
int ring_id = ctx->Attrs().Get<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_scatter_op must be non-negative.",
ring_id));
framework::DDim dim = ctx->GetInputDim("X");
if (dim[0] < 0) dim[0] = -1;
ctx->SetOutputDim("Out", dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class CAllToAllOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor send.");
AddOutput("Out", "(Tensor) the result of alltoall.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
CAllToAll Operator
Gather tensors from all participators to all participators.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_alltoall, ops::CAllToAllOp,
ops::CAllToAllOpMaker);
REGISTER_OP_CPU_KERNEL(c_alltoall, ops::CAllToAllOpCPUKernel<float>,
ops::CAllToAllOpCPUKernel<double>,
ops::CAllToAllOpCPUKernel<int>,
ops::CAllToAllOpCPUKernel<int64_t>,
ops::CAllToAllOpCPUKernel<plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_alltoall_op.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CAllToAllOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int send_numel = x->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for c_scatter_op must be non-negative.",
ring_id));
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
framework::DDim x_dims = x->dims();
framework::DDim out_dims(x_dims);
auto send_buf = x->data<T>();
auto recv_buf = out->mutable_data<T>(out_dims, place);
size_t offset = 0;
send_numel /= nranks;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < nranks; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
send_buf + offset, send_numel, dtype, i, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
offset += send_numel;
}
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
#else
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(c_alltoall, ops::CAllToAllOpCUDAKernel<float>,
ops::CAllToAllOpCUDAKernel<double>,
ops::CAllToAllOpCUDAKernel<int>,
ops::CAllToAllOpCUDAKernel<int64_t>,
ops::CAllToAllOpCUDAKernel<plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include <gloo/gather.h>
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CAllToAllOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto root_id = ctx.Attr<int>("root");
auto gloo = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"You must initialize the gloo environment first to use it."));
int64_t send_numel = in->numel();
int64_t recv_numel = out->numel();
auto nranks = gloo->Size();
auto rank = gloo->Rank();
T* recv_buff = out->data<T>();
T* send_buff = in->data<T>();
gloo::GatherOptions opts(gloo->GetContext());
opts.setOutput(recv_buff, recv_numel);
opts.setInput(send_buff, send_numel);
opts.setRoot(root_id);
gloo::alltoall(opts);
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}
};
} // namespace operators
} // namespace paddle
...@@ -267,16 +267,10 @@ class TestDistBase(unittest.TestCase): ...@@ -267,16 +267,10 @@ class TestDistBase(unittest.TestCase):
elif col_type == "alltoall": elif col_type == "alltoall":
temp11, temp12 = np.split(input1, 2) temp11, temp12 = np.split(input1, 2)
temp21, temp22 = np.split(input2, 2) temp21, temp22 = np.split(input2, 2)
need_result1 = np.hstack((temp11, temp21)) need_result1 = np.vstack((temp11, temp21))
need_result2 = np.hstack((temp12, temp22)) need_result2 = np.vstack((temp12, temp22))
print("input1:", input1) self.assertTrue(np.allclose(tr0_out, need_result1))
print("input2:", input2) self.assertTrue(np.allclose(tr1_out, need_result2))
print("need_result1:", need_result1)
print("need_result2:", need_result2)
print("tr0_out:", tr0_out)
print("tr1_out:", tr1_out)
self.assertTrue(np.allclose(tr1_out, need_result1))
self.assertTrue(np.allclose(tr2_out, need_result2))
elif col_type == "reduce_scatter": elif col_type == "reduce_scatter":
tmp = input1 + input2 tmp = input1 + input2
need_result1 = tmp[0:tmp.shape[0] // 2] need_result1 = tmp[0:tmp.shape[0] // 2]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册