提交 0b1052fc 编写于 作者: Y Yu Yang

Get `DEFINE_OPERATOR_CTOR` Back to code

上级 509d3209
...@@ -30,6 +30,7 @@ using DeviceContext = platform::DeviceContext; ...@@ -30,6 +30,7 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
}; };
...@@ -78,6 +79,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { ...@@ -78,6 +79,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class FcOp : public operators::NetOp { class FcOp : public operators::NetOp {
public: public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
void Init() override { void Init() override {
AddOp(OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}}, {{"X", {Input("X")}}, {"Y", {Input("W")}}},
......
...@@ -10,6 +10,7 @@ namespace framework { ...@@ -10,6 +10,7 @@ namespace framework {
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {} const platform::DeviceContext &dev_ctx) const override {}
......
...@@ -7,6 +7,7 @@ namespace paddle { ...@@ -7,6 +7,7 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
...@@ -27,6 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -27,6 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
......
...@@ -64,6 +64,17 @@ class ExecutionContext; ...@@ -64,6 +64,17 @@ class ExecutionContext;
*/ */
class OperatorBase { class OperatorBase {
public: public:
using VarNameMap = std::map<std::string, std::vector<std::string>>;
OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
OperatorBase(OperatorBase&& o) = delete;
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
...@@ -151,6 +162,15 @@ class OperatorBase { ...@@ -151,6 +162,15 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class InferShapeContext { class InferShapeContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
...@@ -290,6 +310,8 @@ class OpKernel { ...@@ -290,6 +310,8 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
......
...@@ -22,6 +22,8 @@ namespace framework { ...@@ -22,6 +22,8 @@ namespace framework {
static int op_run_num = 0; static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
...@@ -102,6 +104,7 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -102,6 +104,7 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0; static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override {} void InferShape(const framework::InferShapeContext& ctx) const override {}
}; };
......
...@@ -18,6 +18,8 @@ namespace paddle { ...@@ -18,6 +18,8 @@ namespace paddle {
namespace operators { namespace operators {
class AddOp : public framework::OperatorWithKernel { class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(), PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
...@@ -43,6 +45,7 @@ The equation is: Out = X + Y ...@@ -43,6 +45,7 @@ The equation is: Out = X + Y
}; };
class AddOpGrad : public framework::OperatorWithKernel { class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
}; };
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel { class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X"); auto *X = ctx.Input<Tensor>("X");
...@@ -31,6 +32,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -31,6 +32,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
}; };
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
...@@ -18,6 +18,8 @@ namespace paddle { ...@@ -18,6 +18,8 @@ namespace paddle {
namespace operators { namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Dst")->Resize( ctx.Output<framework::Tensor>("Dst")->Resize(
......
...@@ -43,6 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel { ...@@ -43,6 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
}; };
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>(0);
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class MeanOp : public framework::OperatorWithKernel { class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
...@@ -37,6 +38,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -37,6 +38,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class MeanGradOp : public framework::OperatorWithKernel { class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X")) ctx.Output<Tensor>(framework::GradVarName("X"))
......
...@@ -18,6 +18,8 @@ namespace paddle { ...@@ -18,6 +18,8 @@ namespace paddle {
namespace operators { namespace operators {
class MulOp : public framework::OperatorWithKernel { class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel);
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims(); auto dim0 = ctx.Input<Tensor>("X")->dims();
...@@ -51,6 +53,7 @@ The equation is: Out = X * Y ...@@ -51,6 +53,7 @@ The equation is: Out = X * Y
}; };
class MulOpGrad : public framework::OperatorWithKernel { class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override { std::string DebugString() const override {
......
...@@ -37,6 +37,7 @@ namespace operators { ...@@ -37,6 +37,7 @@ namespace operators {
class NetOp : public framework::OperatorBase { class NetOp : public framework::OperatorBase {
public: public:
static const char kAll[]; static const char kAll[];
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase);
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
......
...@@ -12,6 +12,7 @@ static int run_cnt = 0; ...@@ -12,6 +12,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase { class TestOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase);
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
...@@ -21,6 +22,7 @@ class TestOp : public framework::OperatorBase { ...@@ -21,6 +22,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase);
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
}; };
......
...@@ -101,6 +101,8 @@ class RecurrentGradientAlgorithm { ...@@ -101,6 +101,8 @@ class RecurrentGradientAlgorithm {
class RecurrentOp final : public framework::OperatorBase { class RecurrentOp final : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase);
void Init() override; void Init() override;
/** /**
...@@ -123,6 +125,7 @@ class RecurrentOp final : public framework::OperatorBase { ...@@ -123,6 +125,7 @@ class RecurrentOp final : public framework::OperatorBase {
class RecurrentGradientOp final : public framework::OperatorBase { class RecurrentGradientOp final : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(RecurrentGradientOp, framework::OperatorBase)
void Init() override; void Init() override;
/** /**
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel { class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims(); auto dim0 = ctx.Input<Tensor>("X")->dims();
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SGDOp : public framework::OperatorWithKernel { class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE( PADDLE_ENFORCE(
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public framework::OperatorWithKernel { class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims()); ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
...@@ -36,6 +37,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -36,6 +37,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class SigmoidOpGrad : public framework::OperatorWithKernel { class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel { class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL, PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
...@@ -38,6 +39,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -38,6 +39,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class SoftmaxOpGrad : public framework::OperatorWithKernel { class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
......
...@@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { ...@@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel {
}; };
class UniformRandomOp : public framework::OperatorWithKernel { class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"), PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册