提交 38f4b1d5 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #3430 from wangkuiyi/add_operatorbase_constructors

Add constructors to OperatorBase and all sub-classes
......@@ -30,6 +30,8 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase)
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
......
......@@ -10,6 +10,8 @@ namespace framework {
class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase)
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
......
......@@ -7,6 +7,8 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase)
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
......@@ -27,6 +29,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase)
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
......
......@@ -63,6 +63,17 @@ class ExecutionContext;
*/
class OperatorBase {
public:
OperatorBase() {} // TODO(yi): This constructor is to be removed.
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
in_out_idxs_(in_out_idxs) {}
virtual ~OperatorBase() {}
template <typename T>
......@@ -109,6 +120,9 @@ class OperatorBase {
const std::vector<std::string> Inputs() const { return inputs_; }
const std::vector<std::string> Outputs() const { return outputs_; }
const AttributeMap& Attrs() const { return attrs_; }
const std::unordered_map<std::string, int>* InOutIdx() const {
return in_out_idxs_.get();
}
public:
std::string type_;
......@@ -286,6 +300,14 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase {
public:
OperatorWithKernel() {} // TODO(yi): This constructor is to be removed.
OperatorWithKernel(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
struct OpKernelKey {
platform::Place place_;
......@@ -335,5 +357,15 @@ class OperatorWithKernel : public OperatorBase {
virtual void InferShape(const InferShapeContext& ctx) const = 0;
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const std::vector<std::string>& inputs, \
const std::vector<std::string>& outputs, \
const ::paddle::framework::AttributeMap& attrs, \
std::unordered_map<std::string, int>* in_out_idxs) \
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
} // namespace framework
} // namespace paddle
......@@ -23,6 +23,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase)
void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
......@@ -97,6 +99,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel {
public:
DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
......@@ -116,6 +120,8 @@ class CPUKernelTest : public OpKernel {
// multiple inputs test
class OperatorMultiInputsTest : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase)
void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
......@@ -47,6 +48,7 @@ The equation is: Out = X + Y
};
class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
......@@ -38,6 +39,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
......
......@@ -43,6 +43,7 @@ class GaussianRandomKernel : public framework::OpKernel {
};
class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one");
......@@ -39,6 +40,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};
class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
......@@ -53,6 +54,7 @@ The equation is: Out = X * Y
};
class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
......
......@@ -35,6 +35,8 @@ namespace operators {
*/
class NetOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase)
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
......
......@@ -12,6 +12,8 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase)
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -21,6 +23,8 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase)
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
......
......@@ -100,6 +100,7 @@ class RecurrentGradientAlgorithm {
};
class RecurrentOp final : public framework::OperatorBase {
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase)
public:
void Init() override;
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL,
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two");
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
......@@ -38,6 +39,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
......@@ -42,6 +43,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
......
......@@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel {
};
class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册