提交 63fb41b3 编写于 作者: D Dong Zhihong

"redefine the initop from kernel to OpBase"

上级 026c61c0
......@@ -125,7 +125,7 @@ class OperatorBase {
protected:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)opear
// I (Inputs)
// O (Outputs)
// OG (Output Gradients)
VariableNameMap inputs_;
......
......@@ -9,26 +9,30 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
// NCCLinitOp
class NCCLInitOp : public framework::OperatorWithKernel {
class NCCLInitOp : public framework::OperatorBase {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
" Output(Communicator) of ncclInitOp should not be NULL");
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext &ctx) const override {
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>();
comm->InitAll(gpus);
}
};
......@@ -188,13 +192,14 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclInit, ops::NCCLInitOp, ops::NCCLInitOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclBcastSend, ops::NCCLBcastSendOp,
ops::NCCLBcastSendOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclBcastRecv, ops::NCCLBcastRecvOp,
ops::NCCLBcastRecvOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
ops::NCCLReduceOpMaker);
REGISTER_OP_CPU_KERNEL(ncclInit, ops::NCCLInitKernel<float>);
......@@ -12,11 +12,30 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include <functional>
#include "paddle/operators/nccl_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Communicator;
template <typename Type>
class NCCLTypeWrapper;
template <>
class NCCLTypeWrapper<float> {
public:
static const ncclDataType_t type = ncclFloat;
};
template <>
class NCCLTypeWrapper<double> {
public:
static const ncclDataType_t type = ncclDouble;
};
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
public:
......
......@@ -11,7 +11,6 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl_op.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
......@@ -65,11 +64,11 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
TEST(NCCL, ncclInitOp) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp();
f::OpDescBind *op_desc = block->AppendOp();
op1->SetType("ncclInit");
op1->SetOutput("Communicator", {"x1"});
op1->SetAttr("gpus", {gpu_list});
op_desc->SetType("ncclInit");
op_desc->SetOutput("Communicator", {"x1"});
op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope;
paddle::platform::DeviceContext *ctx =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
......@@ -77,7 +76,30 @@ TEST(NCCL, ncclInitOp) {
auto *var = g_scope.Var("x1");
var->GetMutable<paddle::platform::Communicator>();
auto op = f::OpRegistry::CreateOp(*op1);
auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *ctx);
VLOG(1) << "NCCLInitOp finished.";
}
// ncclAllReduceOp with desc
TEST(NCCL, ncclInitOp) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op_desc = block->AppendOp();
op_desc->SetType("ncclAllReduce");
op_desc->SetOutput("Communicator", {"x1"});
op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope;
paddle::platform::DeviceContext *ctx =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
auto *var = g_scope.Var("x1");
var->GetMutable<paddle::platform::Communicator>();
auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, *ctx);
VLOG(1) << "NCCLInitOp finished.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册