Commit 73883bde authored by Dong Zhihong

"fix error"

Parent d4d215a5
@@ -7,6 +7,8 @@
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 template <typename Type>
 class NCCLTypeWrapper;
@@ -21,7 +23,7 @@ class NCCLTypeWrapper<double> {
 };
 
 template <typename T>
-class NCCLAllReduceKernel : public framework::OpKernel {
+class NCCLAllReduceKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<Tensor>("X");
@@ -35,13 +37,14 @@ class NCCLAllReduceKernel : public framework::OpKernel {
       op_type = ncclProd;
     } else if (reduction == "ncclMin") {
       op_type = ncclMin;
-    } else
-      (reduction == "ncclMax") { op_type = ncclMax; }
+    } else if (reduction == "ncclMax") {
+      op_type = ncclMax;
+    }
 
     auto dev_ctx =
         static_cast<const platform::CUDADeviceContext>(ctx.device_context());
-    NCCLManager* m = NCCLManager::Get();
+    platform::NCCLManager* m = platform::NCCLManager::Get();
 
     auto* comm = m->GetCommunicator(gpus);
     comm->wg_.Add(1);
...
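The third hunk carries the substantive fix: the old `} else (reduction == "ncclMax") { op_type = ncclMax; }` drops the `if` keyword and is not valid C++, and `NCCLManager` needed its `platform::` namespace qualifier. The second hunk switches the kernel to the templated `framework::OpKernel<T>` base class. For reference only, below is a minimal standalone sketch (not the actual Paddle source) of the mapping the corrected else-if chain implements; the `ncclSum` branch is assumed from the elided lines above the hunk, and the helper name `ToNcclRedOp` is hypothetical.

// Minimal sketch, not the Paddle implementation: map the "reduction"
// attribute string to an NCCL reduction op, mirroring the corrected
// else-if chain in the diff above. The ncclSum branch is assumed from
// the elided context; the function name is made up for illustration.
#include <nccl.h>
#include <stdexcept>
#include <string>

ncclRedOp_t ToNcclRedOp(const std::string& reduction) {
  if (reduction == "ncclSum") {
    return ncclSum;
  } else if (reduction == "ncclProd") {
    return ncclProd;
  } else if (reduction == "ncclMin") {
    return ncclMin;
  } else if (reduction == "ncclMax") {
    return ncclMax;
  }
  // Reject unknown attribute values instead of silently falling through.
  throw std::invalid_argument("unknown reduction: " + reduction);
}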