提交 ffbb0be2 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3444 from reyoung/use_ctor_create_op

Using constructor to create an operator.
...@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext; ...@@ -30,7 +30,7 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase); using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
}; };
...@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { ...@@ -79,8 +79,9 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class FcOp : public operators::NetOp { class FcOp : public operators::NetOp {
public: public:
DEFINE_OPERATOR_CTOR(FcOp, operators::NetOp) FcOp(const std::string &type, const VarNameMap &inputs,
void Init() override { const VarNameMap &outputs, const AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) {
AddOp(OpRegistry::CreateOp("mul", AddOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}}, {{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {})); {{"Out", {Output("mul_result")}}}, {}));
......
...@@ -20,13 +20,12 @@ namespace paddle { ...@@ -20,13 +20,12 @@ namespace paddle {
namespace framework { namespace framework {
enum class OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, static void TransOpArg(const OperatorBase* src_op,
const OpArgType& src_type, const OpArgType& dst_type, OperatorBase::VarNameMap* vars,
bool is_grad) { const OpArgType& src_type, bool is_grad) {
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
auto& dst_inout = auto& dst_inout = *vars;
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
const OpProto& proto = OpProtos().at(src_op->type_); const OpProto& proto = OpProtos().at(src_op->type_);
const auto& src_arg_list = const auto& src_arg_list =
...@@ -44,15 +43,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, ...@@ -44,15 +43,22 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_); auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
grad_op->type_ = grad_op_type; "Operator %s do not register gradient type", op->type_);
grad_op->attrs_ = op->attrs_; auto& grad_op_type = gop_type_it->second;
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false); // I OperatorBase::VarNameMap inputs;
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false); // O OperatorBase::VarNameMap outputs;
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true); // OG TransOpArg(op, &inputs, OpArgType::IN, false); // I
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true); // IG TransOpArg(op, &inputs, OpArgType::OUT, false); // O
return grad_op; TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->type_, grad_op_type);
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
} }
} // namespace framework } // namespace framework
......
...@@ -10,7 +10,7 @@ namespace framework { ...@@ -10,7 +10,7 @@ namespace framework {
class NOP : public OperatorBase { class NOP : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase); using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {} void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {} const platform::DeviceContext &dev_ctx) const override {}
......
...@@ -120,13 +120,19 @@ class OpProtoAndCheckerMaker { ...@@ -120,13 +120,19 @@ class OpProtoAndCheckerMaker {
}; };
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarNameMap = OperatorBase::VarNameMap; using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
op_creators()[op_type] = [] { return new OpType; }; op_creators()[op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = OpProtos()[op_type]; OpProto& op_proto = OpProtos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker); auto maker = ProtoMakerType(&op_proto, &op_checker);
...@@ -141,29 +147,25 @@ class OpRegistry { ...@@ -141,29 +147,25 @@ class OpRegistry {
template <typename GradOpType> template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type, static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) { const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; }; op_creators()[grad_op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new GradOpType(type, inputs, outputs, attrs);
};
grad_ops()[op_type] = grad_op_type; grad_ops()[op_type] = grad_op_type;
} }
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs, const VarNameMap& inputs,
const VarNameMap& outputs, const VarNameMap& outputs,
const AttributeMap& attrs) { AttributeMap attrs) {
auto op_create_it = op_creators().find(type); auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(), PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type); "Operator %s cannot be found.", type);
op_checkers().at(type).Check(attrs);
auto op = op_create_it->second(); auto op = op_create_it->second(type, inputs, outputs, attrs);
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
GenerateTempVariableName(op);
op->Init();
return std::shared_ptr<OperatorBase>(op); return std::shared_ptr<OperatorBase>(op);
} }
...@@ -195,7 +197,6 @@ class OpRegistry { ...@@ -195,7 +197,6 @@ class OpRegistry {
PADDLE_ENFORCE(!op.IsNetOp(), PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops"); "Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op)); std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
grad_op->Init();
return grad_op; return grad_op;
} }
...@@ -214,19 +215,6 @@ class OpRegistry { ...@@ -214,19 +215,6 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_; return op_checkers_;
} }
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += op->type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
}; };
class Registrar { class Registrar {
......
...@@ -7,7 +7,7 @@ namespace paddle { ...@@ -7,7 +7,7 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase); using OperatorBase::OperatorBase;
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
...@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -28,7 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase); using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
......
...@@ -122,6 +122,23 @@ void OperatorBase::Rename(const std::string& old_name, ...@@ -122,6 +122,23 @@ void OperatorBase::Rename(const std::string& old_name,
} }
} }
OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& inputs,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val; std::vector<std::string> ret_val;
if (has_intermediate) { if (has_intermediate) {
......
...@@ -66,10 +66,8 @@ class OperatorBase { ...@@ -66,10 +66,8 @@ class OperatorBase {
public: public:
using VarNameMap = std::map<std::string, std::vector<std::string>>; using VarNameMap = std::map<std::string, std::vector<std::string>>;
OperatorBase() = default;
OperatorBase(const std::string& type, const VarNameMap& inputs, OperatorBase(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) const VarNameMap& outputs, const AttributeMap& attrs);
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
OperatorBase(const OperatorBase& o) = delete; OperatorBase(const OperatorBase& o) = delete;
OperatorBase& operator=(const OperatorBase& o) = delete; OperatorBase& operator=(const OperatorBase& o) = delete;
...@@ -86,10 +84,6 @@ class OperatorBase { ...@@ -86,10 +84,6 @@ class OperatorBase {
virtual std::string DebugString() const; virtual std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const Scope& scope) const = 0; virtual void InferShape(const Scope& scope) const = 0;
...@@ -135,15 +129,6 @@ class OperatorBase { ...@@ -135,15 +129,6 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class InferShapeContext { class InferShapeContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
...@@ -287,8 +272,6 @@ class OpKernel { ...@@ -287,8 +272,6 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
...@@ -312,6 +295,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -312,6 +295,10 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const Scope& scope) const override { void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope)); InferShape(InferShapeContext(*this, scope));
} }
......
...@@ -22,10 +22,10 @@ namespace framework { ...@@ -22,10 +22,10 @@ namespace framework {
static int op_run_num = 0; static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, framework::OperatorBase)
public: public:
void Init() override { x = 1; } OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
...@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase { ...@@ -38,7 +38,7 @@ class OpWithoutKernelTest : public OperatorBase {
} }
public: public:
float x = 0; int x{0};
}; };
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -109,7 +109,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -109,7 +109,9 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0; static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OpWithKernelTest, framework::OperatorWithKernel) public:
using OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override {} void InferShape(const framework::InferShapeContext& ctx) const override {}
}; };
......
...@@ -18,7 +18,8 @@ namespace paddle { ...@@ -18,7 +18,8 @@ namespace paddle {
namespace operators { namespace operators {
class AddOp : public framework::OperatorWithKernel { class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
...@@ -45,7 +46,9 @@ The equation is: Out = X + Y ...@@ -45,7 +46,9 @@ The equation is: Out = X + Y
}; };
class AddOpGrad : public framework::OperatorWithKernel { class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
}; };
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel { class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X"); auto *X = ctx.Input<Tensor>("X");
...@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -32,8 +34,9 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
}; };
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp, public:
framework::OperatorWithKernel) using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
...@@ -18,7 +18,8 @@ namespace paddle { ...@@ -18,7 +18,8 @@ namespace paddle {
namespace operators { namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel); public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
......
...@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel { ...@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
}; };
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel); public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(const framework::InferShapeContext& context) const override {
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
class MeanOp : public framework::OperatorWithKernel { class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
...@@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -38,7 +40,9 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class MeanGradOp : public framework::OperatorWithKernel { class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X")) ctx.Output<Tensor>(framework::GradVarName("X"))
......
...@@ -19,7 +19,8 @@ namespace paddle { ...@@ -19,7 +19,8 @@ namespace paddle {
namespace operators { namespace operators {
class MulOp : public framework::OperatorWithKernel { class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel); public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
...@@ -54,7 +55,9 @@ The equation is: Out = X * Y ...@@ -54,7 +55,9 @@ The equation is: Out = X * Y
}; };
class MulOpGrad : public framework::OperatorWithKernel { class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override { std::string DebugString() const override {
......
...@@ -81,5 +81,11 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const { ...@@ -81,5 +81,11 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
return ret_val; return ret_val;
} }
NetOp::NetOp(const std::string& type,
const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -37,7 +37,9 @@ namespace operators { ...@@ -37,7 +37,9 @@ namespace operators {
class NetOp : public framework::OperatorBase { class NetOp : public framework::OperatorBase {
public: public:
static const char kAll[]; static const char kAll[];
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase); NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {}
NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
......
...@@ -12,7 +12,7 @@ static int run_cnt = 0; ...@@ -12,7 +12,7 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase { class TestOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase); using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
...@@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase { ...@@ -22,7 +22,7 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase { class EmptyOp : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase); using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
}; };
...@@ -44,14 +44,14 @@ TEST(OpKernel, all) { ...@@ -44,14 +44,14 @@ TEST(OpKernel, all) {
auto net = std::make_shared<NetOp>(); auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr); ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>(); auto op1 = std::shared_ptr<TestOp>(
op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}; new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
op1->outputs_ = {{"Out", {"y"}}}; {{"Out", {"y"}}}, {}));
net->AddOp(op1); net->AddOp(op1);
auto op2 = std::make_shared<TestOp>(); auto op2 = std::shared_ptr<TestOp>(
op2->inputs_ = {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}; new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
op2->outputs_ = {{"Out", {"z"}}}; {{"Out", {"z"}}}, {}));
net->AddOp(op2); net->AddOp(op2);
net->CompleteAddOp(); net->CompleteAddOp();
...@@ -67,9 +67,9 @@ TEST(OpKernel, all) { ...@@ -67,9 +67,9 @@ TEST(OpKernel, all) {
TEST(NetOp, insert_op) { TEST(NetOp, insert_op) {
NetOp net; NetOp net;
auto op1 = std::make_shared<EmptyOp>(); auto op1 = std::shared_ptr<EmptyOp>(
op1->inputs_ = {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}; new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
op1->outputs_ = {{"Out", {"y"}}}; {{"Out", {"y"}}}, {}));
net.AddOp(op1); net.AddOp(op1);
net.InsertOp(0, op1); net.InsertOp(0, op1);
ASSERT_EQ(2UL, net.ops_.size()); ASSERT_EQ(2UL, net.ops_.size());
......
...@@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{ ...@@ -135,8 +135,11 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
"inlink@grad", "inlink_alias", "outlink_alias", "inlink@grad", "inlink_alias", "outlink_alias",
"memories", "pre_memories", "boot_memories@grad"}; "memories", "pre_memories", "boot_memories@grad"};
void RecurrentOp::Init() { RecurrentOp::RecurrentOp(const std::string& type,
OperatorBase::Init(); const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument()); std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this); rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg)); alg_.Init(std::move(arg));
...@@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { ...@@ -230,8 +233,11 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
} }
void RecurrentGradientOp::Init() { RecurrentGradientOp::RecurrentGradientOp(
OperatorBase::Init(); const std::string& type, const framework::OperatorBase::VarNameMap& inputs,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument()); std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this); rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg)); alg_.Init(std::move(arg));
......
...@@ -101,10 +101,8 @@ class RecurrentGradientAlgorithm { ...@@ -101,10 +101,8 @@ class RecurrentGradientAlgorithm {
class RecurrentOp final : public framework::OperatorBase { class RecurrentOp final : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase); RecurrentOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
void Init() override;
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
...@@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase { ...@@ -125,8 +123,9 @@ class RecurrentOp final : public framework::OperatorBase {
class RecurrentGradientOp final : public framework::OperatorBase { class RecurrentGradientOp final : public framework::OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(RecurrentGradientOp, framework::OperatorBase) RecurrentGradientOp(const std::string& type, const VarNameMap& inputs,
void Init() override; const VarNameMap& outputs,
const framework::AttributeMap& attrs);
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel { class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims(); auto dim0 = ctx.Input<Tensor>("X")->dims();
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
class SGDOp : public framework::OperatorWithKernel { class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE( PADDLE_ENFORCE(
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
class SigmoidOp : public framework::OperatorWithKernel { class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims()); ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
...@@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -37,7 +39,9 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class SigmoidOpGrad : public framework::OperatorWithKernel { class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace operators { namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel { class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL, PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
...@@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -39,7 +41,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
}; };
class SoftmaxOpGrad : public framework::OperatorWithKernel { class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
......
...@@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel { ...@@ -46,7 +46,9 @@ class CPUUniformRandomKernel : public framework::OpKernel {
}; };
class UniformRandomOp : public framework::OperatorWithKernel { class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"), PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册