提交 73883bde，作者：Dong Zhihong

"fix error"

上级 d4d215a5
......@@ -7,6 +7,8 @@
namespace paddle {
namespace operators {
using framework::Tensor;
// Forward declaration of the NCCLTypeWrapper class template, parameterized on
// a type. Explicit specializations (e.g. NCCLTypeWrapper<double>) appear
// further down in this file.
template <typename Type>
class NCCLTypeWrapper;
......@@ -21,7 +23,7 @@ class NCCLTypeWrapper<double> {
};
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel {
class NCCLAllReduceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<Tensor>("X");
......@@ -35,13 +37,14 @@ class NCCLAllReduceKernel : public framework::OpKernel {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else
(reduction == "ncclMax") { op_type = ncclMax; }
} else if (reduction == "ncclMax") {
op_type = ncclMax;
}
auto dev_ctx =
static_cast<const platform::CUDADeviceContext>(ctx.device_context());
NCCLManager* m = NCCLManager::Get();
platform::NCCLManager* m = platform::NCCLManager::Get();
auto* comm = m->GetCommunicator(gpus);
comm->wg_.Add(1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册