Commit 63fb41b3 authored by Dong Zhihong

"redefine the initop from kernel to OpBase"

Parent 026c61c0
@@ -125,7 +125,7 @@ class OperatorBase {
  protected:
   std::string type_;
   // NOTE: in case of OpGrad, inputs_ contains:
-  // I (Inputs)opear
+  // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
   VariableNameMap inputs_;
...
@@ -9,26 +9,30 @@
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/operators/nccl_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/nccl/nccl_gpu_common.h"

 namespace paddle {
 namespace operators {

 // NCCLinitOp
-class NCCLInitOp : public framework::OperatorWithKernel {
+class NCCLInitOp : public framework::OperatorBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
-                   " Output(Communicator) of ncclInitOp should not be NULL");
-  }
-
- protected:
-  framework::DataType IndicateDataType(
-      const framework::ExecutionContext &ctx) const override {
-    return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
+  NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
+             const framework::VariableNameMap &outputs,
+             const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    const auto &name = Output("Communicator");
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
+                            "Can not find variable '%s' in the scope.", name);
+    std::vector<int> gpus = Attr<std::vector<int>>("gpus");
+    PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
+    platform::Communicator *comm =
+        scope.FindVar(name)->GetMutable<platform::Communicator>();
+    comm->InitAll(gpus);
   }
 };
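The registration block below references ops::NCCLInitOpMaker, which this diff does not show. A minimal sketch of what such a maker plausibly looks like, with output name and attribute matching what NCCLInitOp::Run reads above (the doc strings are illustrative, not taken from the commit):

class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLInitOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // The op's only output: the Communicator variable that Run() fills in.
    AddOutput("Communicator", "the communicator shared by the NCCL ops");
    // Matches Attr<std::vector<int>>("gpus") read in NCCLInitOp::Run.
    AddAttr<std::vector<int>>("gpus", "list of gpu ids to initialize");
    AddComment("Create an NCCL communicator for the given gpu ids.");
  }
};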
@@ -188,13 +192,14 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle

 namespace ops = paddle::operators;
+REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
+                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
+
 REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
                              ops::NCCLAllReduceOpMaker);
-REGISTER_OP_WITHOUT_GRADIENT(ncclInit, ops::NCCLInitOp, ops::NCCLInitOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclBcastSend, ops::NCCLBcastSendOp,
                              ops::NCCLBcastSendOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclBcastRecv, ops::NCCLBcastRecvOp,
                              ops::NCCLBcastRecvOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
                              ops::NCCLReduceOpMaker);
-REGISTER_OP_CPU_KERNEL(ncclInit, ops::NCCLInitKernel<float>);
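With ncclInit now an OperatorBase, it runs without kernel selection, so the old CPU-kernel registration is dropped and REGISTER_OPERATOR with EmptyGradOpMaker declares the op gradient-free. A minimal sketch (not part of the commit) of driving the new interface by hand, assuming VariableNameMap and AttributeMap are the framework's usual map typedefs:

paddle::framework::VariableNameMap no_inputs;  // ncclInit takes no inputs
paddle::framework::VariableNameMap outputs{{"Communicator", {"comm_var"}}};
paddle::framework::AttributeMap attrs;
attrs["gpus"] = std::vector<int>{0, 1};  // hypothetical two-GPU setup

paddle::operators::NCCLInitOp init_op("ncclInit", no_inputs, outputs, attrs);

// The output variable must exist in the scope before Run() is called,
// or the PADDLE_ENFORCE_NOT_NULL in Run() fires.
paddle::framework::Scope scope;
scope.Var("comm_var")->GetMutable<paddle::platform::Communicator>();
paddle::platform::CPUDeviceContext cpu_ctx(paddle::platform::CPUPlace());
init_op.Run(scope, cpu_ctx);  // ends up in Communicator::InitAll({0, 1})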
@@ -12,11 +12,30 @@ limitations under the License. */

 #define EIGEN_USE_GPU
 #include <functional>

-#include "paddle/operators/nccl_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/nccl/nccl_gpu_common.h"

 namespace paddle {
 namespace operators {

+using framework::Tensor;
+using platform::Communicator;
+
+template <typename Type>
+class NCCLTypeWrapper;
+
+template <>
+class NCCLTypeWrapper<float> {
+ public:
+  static const ncclDataType_t type = ncclFloat;
+};
+
+template <>
+class NCCLTypeWrapper<double> {
+ public:
+  static const ncclDataType_t type = ncclDouble;
+};
+
 template <typename T>
 class NCCLAllReduceKernel : public framework::OpKernel<T> {
  public:
...
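NCCLTypeWrapper lets the templated kernels pick the ncclDataType_t matching their element type at compile time. A sketch of the call-site pattern (buffers, communicator, and stream are placeholders; ncclAllReduce is the stock NCCL C API):

template <typename T>
void AllReduceSum(const T *send, T *recv, size_t count, ncclComm_t comm,
                  cudaStream_t stream) {
  // NCCLTypeWrapper<float>::type == ncclFloat, <double> == ncclDouble;
  // any other T fails at compile time because the primary template is
  // only declared, never defined.
  ncclAllReduce(send, recv, count, NCCLTypeWrapper<T>::type, ncclSum, comm,
                stream);
}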
@@ -11,7 +11,6 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/operators/nccl_op.h"
 #include <glog/logging.h>
 #include <gtest/gtest.h>

@@ -65,11 +64,11 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
 TEST(NCCL, ncclInitOp) {
   f::ProgramDescBind program;
   f::BlockDescBind *block = program.Block(0);
-  f::OpDescBind *op1 = block->AppendOp();
-  op1->SetType("ncclInit");
-  op1->SetOutput("Communicator", {"x1"});
-  op1->SetAttr("gpus", {gpu_list});
+  f::OpDescBind *op_desc = block->AppendOp();
+  op_desc->SetType("ncclInit");
+  op_desc->SetOutput("Communicator", {"x1"});
+  op_desc->SetAttr("gpus", {gpu_list});

   f::Scope g_scope;
   paddle::platform::DeviceContext *ctx =
       new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());

@@ -77,7 +76,30 @@ TEST(NCCL, ncclInitOp) {
   auto *var = g_scope.Var("x1");
   var->GetMutable<paddle::platform::Communicator>();

-  auto op = f::OpRegistry::CreateOp(*op1);
+  auto op = f::OpRegistry::CreateOp(*op_desc);
+  VLOG(1) << "invoke NCCLInitOp.";
+  op->Run(g_scope, *ctx);
+  VLOG(1) << "NCCLInitOp finished.";
+}
+
+// ncclAllReduceOp with desc
+TEST(NCCL, ncclAllReduceOp) {
+  f::ProgramDescBind program;
+  f::BlockDescBind *block = program.Block(0);
+  f::OpDescBind *op_desc = block->AppendOp();
+
+  op_desc->SetType("ncclAllReduce");
+  op_desc->SetOutput("Communicator", {"x1"});
+  op_desc->SetAttr("gpus", {gpu_list});
+
+  f::Scope g_scope;
+  paddle::platform::DeviceContext *ctx =
+      new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
+
+  auto *var = g_scope.Var("x1");
+  var->GetMutable<paddle::platform::Communicator>();
+
+  auto op = f::OpRegistry::CreateOp(*op_desc);
   VLOG(1) << "invoke NCCLInitOp.";
   op->Run(g_scope, *ctx);
   VLOG(1) << "NCCLInitOp finished.";
...
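Both tests read a global gpu_list that is defined outside this excerpt; a typical driver for such a multi-GPU test would populate it in main(), roughly as below (the cudaGetDeviceCount call is the standard CUDA runtime API; the real file may use a framework helper instead):

#include <cuda_runtime.h>

std::vector<int> gpu_list;

int main(int argc, char **argv) {
  int dev_count = 0;
  cudaGetDeviceCount(&dev_count);  // enumerate visible GPUs
  if (dev_count <= 1) {
    LOG(INFO) << "Skipping multi-gpu nccl test, device count = " << dev_count;
    return 0;
  }
  for (int i = 0; i < dev_count; ++i) gpu_list.push_back(i);
  testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}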