提交 0b1052fc 编写于 作者: Y Yu Yang

Get `DEFINE_OPERATOR_CTOR` Back to code

上级 509d3209
......@@ -30,6 +30,7 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase);
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
......@@ -78,6 +79,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class FcOp : public operators::NetOp {
public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp)
void Init() override {
AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
......
......@@ -10,6 +10,7 @@ namespace framework {
class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase);
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
......
......@@ -7,6 +7,7 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase);
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
......@@ -27,6 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase);
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
......
......@@ -64,6 +64,17 @@ class ExecutionContext;
*/
class OperatorBase {
public:
using VarNameMap = std::map<std::string, std::vector<std::string>>;
OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete;
OperatorBase(OperatorBase&& o) = delete;
virtual ~OperatorBase() {}
template <typename T>
......@@ -151,6 +162,15 @@ class OperatorBase {
AttributeMap attrs_;
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class InferShapeContext {
public:
InferShapeContext(const OperatorBase& op, const Scope& scope)
......@@ -290,6 +310,8 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
struct OpKernelKey {
platform::Place place_;
......
......@@ -22,6 +22,8 @@ namespace framework {
static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
public:
void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {}
......@@ -102,6 +104,7 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
......
......@@ -18,6 +18,8 @@ namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
......@@ -43,6 +45,7 @@ The equation is: Out = X + Y
};
class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
......@@ -31,6 +32,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -18,6 +18,8 @@ namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel);
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Dst")->Resize(
......
......@@ -43,6 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};
class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel);
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
......@@ -37,6 +38,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};
class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
......
......@@ -18,6 +18,8 @@ namespace paddle {
namespace operators {
class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel);
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
......@@ -51,6 +53,7 @@ The equation is: Out = X * Y
};
class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
......
......@@ -37,6 +37,7 @@ namespace operators {
class NetOp : public framework::OperatorBase {
public:
static const char kAll[];
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase);
/**
* Infer all the operators' input and output variables' shapes, will be called
......
......@@ -12,6 +12,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase);
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -21,6 +22,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase);
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
......
......@@ -101,6 +101,8 @@ class RecurrentGradientAlgorithm {
class RecurrentOp final : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase);
void Init() override;
/**
......@@ -123,6 +125,7 @@ class RecurrentOp final : public framework::OperatorBase {
class RecurrentGradientOp final : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(RecurrentGradientOp, framework::OperatorBase)
void Init() override;
/**
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
......@@ -36,6 +37,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
......@@ -38,6 +39,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
......
......@@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel {
};
class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册