提交 0e78cb69 编写于 作者: Y Yu Yang

Clean OpProtoAndCheckerMaker

Do not use ctor

* Reduce line of codes.
* We can use virtual function for Maker now.
* The implementation does not care what maker holds, it is easier to
refactor later.
上级 2a22da6c
...@@ -32,8 +32,7 @@ struct AddFunctor { ...@@ -32,8 +32,7 @@ struct AddFunctor {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op"); AddInput("input", "input1 of test op");
AddOutput("output", "output of test op"); AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false); AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
......
...@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { ...@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new proto::OpProto; info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker(); info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_); T maker;
maker.SetProto(info->proto_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate(); maker.Validate();
info->proto_->set_type(op_type); info->proto_->set_type(op_type);
PADDLE_ENFORCE( PADDLE_ENFORCE(
......
...@@ -23,20 +23,21 @@ namespace framework { ...@@ -23,20 +23,21 @@ namespace framework {
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker { class OpProtoAndCheckerMaker {
public: public:
using OpProto = proto::OpProto; virtual void Make() = 0;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual ~OpProtoAndCheckerMaker() { virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build"); PADDLE_ENFORCE(validated_, "should call Validate after build");
} }
void SetProto(proto::OpProto* proto) { proto_ = proto; }
void SetChecker(OpAttrChecker* attr_checker) { op_checker_ = attr_checker; }
void Validate(); void Validate();
protected: protected:
struct VariableBuilder { struct VariableBuilder {
OpProto::Var* var_; proto::OpProto::Var* var_;
VariableBuilder& AsDuplicable() { VariableBuilder& AsDuplicable() {
var_->set_duplicable(true); var_->set_duplicable(true);
...@@ -76,16 +77,9 @@ class OpProtoAndCheckerMaker { ...@@ -76,16 +77,9 @@ class OpProtoAndCheckerMaker {
private: private:
void CheckNoDuplicatedInOutAttrs(); void CheckNoDuplicatedInOutAttrs();
OpProto* proto_; proto::OpProto* proto_;
OpAttrChecker* op_checker_; OpAttrChecker* op_checker_;
bool validated_{false}; bool validated_{false};
}; };
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -18,9 +18,7 @@ limitations under the License. */ ...@@ -18,9 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public: public:
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto, void Make() {
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op"); AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op"); AddAttr<float>("scale", "scale of test op");
} }
...@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { ...@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedAttr) { TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::proto::OpProto op_proto; paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public: public:
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto, void Make() {
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
AddInput("input", "input of test op"); AddInput("input", "input of test op");
} }
...@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { ...@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedInOut) { TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::proto::OpProto op_proto; paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
...@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase { ...@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op"); AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op"); AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
...@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase { ...@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").AsDuplicable(); AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate(); AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) { auto my_checker = [](int i) {
...@@ -212,10 +210,7 @@ namespace framework { ...@@ -212,10 +210,7 @@ namespace framework {
class OpKernelTestMaker : public OpProtoAndCheckerMaker { class OpKernelTestMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() { AddComment("NoGradOp, same input output. no Grad"); }
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("NoGradOp, same input output. no Grad");
}
}; };
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
...@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) { ...@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
static int op_test_value = 0; static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext; using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext; using paddle::platform::CUDADeviceContext;
using paddle::platform::DeviceContext;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase { ...@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker { class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
AddOutput("output", "output of test op"); AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op"); AddAttr<float>("scale", "scale of cosine op");
...@@ -98,8 +97,7 @@ namespace framework { ...@@ -98,8 +97,7 @@ namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "input of test op"); AddInput("x", "input of test op");
AddOutput("y", "output of test op"); AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
...@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> { ...@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
class OpKernelTestMultiInputsProtoAndCheckerMaker class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker { : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, void Make() {
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("xs", "inputs of test op").AsDuplicable(); AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op"); AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").AsDuplicable(); AddOutput("ys", "outputs of test op").AsDuplicable();
......
...@@ -24,8 +24,7 @@ namespace framework { ...@@ -24,8 +24,7 @@ namespace framework {
class SumOpMaker : public OpProtoAndCheckerMaker { class SumOpMaker : public OpProtoAndCheckerMaker {
public: public:
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "").AsDuplicable(); AddInput("X", "").AsDuplicable();
AddOutput("Out", ""); AddOutput("Out", "");
AddComment(""); AddComment("");
......
...@@ -166,6 +166,8 @@ function(op_library TARGET) ...@@ -166,6 +166,8 @@ function(op_library TARGET)
# NOTE(*): activation use macro to regist the kernels, set use_op manually. # NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation") if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n") file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif() endif()
......
...@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
// TODO(typhoonzero): support both inference value and indices. // TODO(typhoonzero): support both inference value and indices.
AddInput("Out", "The network output of topk (inferences)"); AddInput("Out", "The network output of topk (inferences)");
AddInput("Indices", "The the network output of topk (indices)"); AddInput("Indices", "The the network output of topk (indices)");
......
...@@ -23,8 +23,7 @@ namespace operators { ...@@ -23,8 +23,7 @@ namespace operators {
class OP_NAME##OpMaker \ class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \ : public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \ public: \
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \ void Make() override { \
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
AddInput("X", "Input of " #OP_NAME "operator"); \ AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \ AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \ AddAttr<bool>("use_mkldnn", \
...@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$ ...@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator"); AddInput("X", "Input of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator"); AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f); AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
...@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$ ...@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator"); AddInput("X", "Input of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator"); AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f); AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
...@@ -242,8 +239,7 @@ $$ ...@@ -242,8 +239,7 @@ $$
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator"); AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator"); AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink") AddAttr<float>("threshold", "The value of threshold for HardShrink")
...@@ -265,8 +261,7 @@ $$ ...@@ -265,8 +261,7 @@ $$
class BReluOpMaker : public framework::OpProtoAndCheckerMaker { class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator"); AddInput("X", "Input of BRelu operator");
AddOutput("Out", "Output of BRelu operator"); AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu") AddAttr<float>("t_min", "The min marginal value of BRelu")
...@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$ ...@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator"); AddInput("X", "Input of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator"); AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu") AddAttr<float>("threshold", "The threshold value of SoftRelu")
...@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$ ...@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
class ELUOpMaker : public framework::OpProtoAndCheckerMaker { class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator"); AddInput("X", "Input of ELU operator");
AddOutput("Out", "Output of ELU operator"); AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f); AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
...@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$ ...@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator"); AddInput("X", "Input of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator"); AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6") AddAttr<float>("threshold", "The threshold value of Relu6")
...@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$ ...@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$
class PowOpMaker : public framework::OpProtoAndCheckerMaker { class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator"); AddInput("X", "Input of Pow operator");
AddOutput("Out", "Output of Pow operator"); AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f); AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
...@@ -353,8 +344,7 @@ $out = x^{factor}$ ...@@ -353,8 +344,7 @@ $out = x^{factor}$
class STanhOpMaker : public framework::OpProtoAndCheckerMaker { class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator"); AddInput("X", "Input of STanh operator");
AddOutput("Out", "Output of STanh operator"); AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input") AddAttr<float>("scale_a", "The scale parameter of a for the input")
...@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$ ...@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker { class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator"); AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator"); AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation") AddAttr<float>("threshold", "The threshold location of activation")
...@@ -394,8 +383,7 @@ $$ ...@@ -394,8 +383,7 @@ $$
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator"); AddInput("X", "Input of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator"); AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid") AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
...@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation. ...@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker { class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Swish operator"); AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator"); AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f); AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
......
...@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient"); AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
......
...@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel { ...@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel {
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment"); AddInput("Moment", "(Tensor) Second moment");
......
...@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel {
class AdamOpMaker : public framework::OpProtoAndCheckerMaker { class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate"); AddInput("LearningRate", "(Tensor) Learning rate");
......
...@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel { ...@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate"); AddInput("LearningRate", "(Tensor) Learning rate");
......
...@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to " "(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor."); "be casted to a big LoDTensor.");
......
...@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase { ...@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable " "(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.") "could be LoDTensor, SelectedRows or LoDTensorArray.")
......
...@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel { ...@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AssignValueOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) Output tensor of assign_value operator."); AddOutput("Out", "(Tensor) Output tensor of assign_value operator.");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
"(vector<int>) " "(vector<int>) "
......
...@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker { class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AucOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Out", AddInput("Out",
"A floating point 2D tensor, values are in the range [0, 1]." "A floating point 2D tensor, values are in the range [0, 1]."
"Each row is sorted in descending order. This input should be the" "Each row is sorted in descending order. This input should be the"
......
...@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "(Tensor), The parameter to be accumulated."); AddInput("param", "(Tensor), The parameter to be accumulated.");
AddInput("in_sum_1", AddInput("in_sum_1",
"(Tensor), A tensor used to store the parameter " "(Tensor), A tensor used to store the parameter "
......
...@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<bool>("is_test", "").SetDefault(false); AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9); AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "") AddAttr<float>("epsilon", "")
......
...@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker { class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor) Tensor " "(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size"); "whose input_dim_idx'th dimension specifies the batch_size");
......
...@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker { class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", AddInput("Ids",
"(LodTensorArray)" "(LodTensorArray)"
"score of the candidate words in each step"); "score of the candidate words in each step");
......
...@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) { ...@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
// inputs and outputs stored in proto // inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step"); AddInput("pre_ids", "ids in previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]"); AddInput("ids", "a LoDTensor of shape of [None,k]");
......
...@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel { ...@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of bilinear interpolation, " "(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"); "This is a 4-D tensor with shape of (N x C x h x w)");
......
...@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of bilinear_tensor_product operator."); AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator."); AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight", AddInput("Weight",
......
...@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"DistMat", "DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
......
...@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"PriorBox", "PriorBox",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
......
...@@ -21,8 +21,7 @@ namespace operators { ...@@ -21,8 +21,7 @@ namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker { class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of cast op"); AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op"); AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type"); AddAttr<int>("out_dtype", "output data type");
......
...@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase { ...@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kChannel, AddInput(kChannel,
"The Channel Variable that should be closed by" "The Channel Variable that should be closed by"
" the ChannelClose Op."); " the ChannelClose Op.");
......
...@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase { ...@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput(kOutput, AddOutput(kOutput,
"The object of a Channel type created by ChannelCreate Op."); "The object of a Channel type created by ChannelCreate Op.");
AddAttr<int>("capacity", "The size of the buffer of Channel.") AddAttr<int>("capacity", "The size of the buffer of Channel.")
......
...@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase { ...@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase {
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelRecvOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(Channel, AddInput(Channel,
"(Channel) A variable which \"receives\" the a value sent" "(Channel) A variable which \"receives\" the a value sent"
"to it by a channel_send op.") "to it by a channel_send op.")
......
...@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase { ...@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase {
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelSendOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(Channel, AddInput(Channel,
"(Channel) A variable which \"sends\" the passed in value to " "(Channel) A variable which \"sends\" the passed in value to "
"a listening receiver.") "a listening receiver.")
......
...@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker { class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference", AddInput("Inference",
"(Tensor, default: Tensor<int64_t>). " "(Tensor, default: Tensor<int64_t>). "
"Predictions from the network."); "Predictions from the network.");
......
...@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel { ...@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input of clip_by_norm op." "(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9]."); "The number of dimensions must be between [1, 9].");
......
...@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel { ...@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
template <typename AttrType> template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker { class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor)The input of clip op." "(Tensor)The input of clip op."
"The number of dimensions must be between [1, 9]."); "The number of dimensions must be between [1, 9].");
......
...@@ -21,8 +21,7 @@ namespace operators { ...@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment> template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment; OpComment comment;
AddInput("X", AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator", string::Sprintf("(LoDTensor) the left hand operand of %s operator",
......
...@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel {
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator."); AddOutput("Out", "Output tensor of concat operator.");
AddAttr<int>("axis", AddAttr<int>("axis",
......
...@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The conditional variable of this operator. If X is empty, the " "The conditional variable of this operator. If X is empty, the "
"whole sub-block will not be executed.") "whole sub-block will not be executed.")
......
...@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
library); library);
} }
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Conv2DOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
...@@ -200,8 +199,7 @@ $$ ...@@ -200,8 +199,7 @@ $$
)DOC"); )DOC");
} }
Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Conv3DOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
......
...@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim, ...@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code. // operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class ConvOp : public framework::OperatorWithKernel { class ConvOp : public framework::OperatorWithKernel {
......
...@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel { ...@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker { class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, " "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension."); "where B is the batch size and M is the data dimension.");
......
...@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
layout_, library_); layout_, library_);
} }
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto, void Conv2DTransposeOpMaker::Make() {
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution transpose operator. " "(Tensor) The input tensor of convolution transpose operator. "
...@@ -168,9 +166,7 @@ Example: ...@@ -168,9 +166,7 @@ Example:
)DOC"); )DOC");
} }
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto, void Conv3DTransposeOpMaker::Make() {
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator." "(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is " "The format of input tensor is NCDHW. Where N is batch size, C is "
......
...@@ -30,12 +30,12 @@ using DDim = framework::DDim; ...@@ -30,12 +30,12 @@ using DDim = framework::DDim;
// operator implementations can reuse the code. // operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class ConvTransposeOp : public framework::OperatorWithKernel { class ConvTransposeOp : public framework::OperatorWithKernel {
......
...@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel { ...@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op."); AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op."); AddInput("Y", "The 2nd input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op."); AddOutput("Out", "The output of cos_sim op.");
......
...@@ -18,8 +18,7 @@ namespace paddle { ...@@ -18,8 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker { class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Emission", AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape " "(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
"[N x D] where N is the size of the mini-batch and D is the total " "[N x D] where N is the size of the mini-batch and D is the total "
......
...@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
class CropOpMaker : public framework::OpProtoAndCheckerMaker { class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CropOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input of pad op. " "The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)."); "The input should be a k-D tensor(k > 0 and k < 7).");
......
...@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D]," "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
" where N is the batch size and D is the number of classes. " " where N is the batch size and D is the number of classes. "
......
...@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel { ...@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is " "(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length."); "[Lp, 1], where Lp is the sum of all input sequences' length.");
......
...@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel { ...@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CumsumOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Cumsum operator"); AddInput("X", "Input of Cumsum operator");
AddOutput("Out", "Output of Cumsum operator"); AddOutput("Out", "Output of Cumsum operator");
AddAttr<int>("axis", AddAttr<int>("axis",
......
...@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ...@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment"); AddInput("Moment", "(Tensor) Second moment");
......
...@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase { ...@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase {
class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker { class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of delete op").AsDuplicable(); AddInput("X", "The input of delete op").AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Delete Operator. Delete Operator.
......
...@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("DetectRes", AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: " "detections. Each row has 6 values: "
......
...@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel {
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of dropout op."); AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op."); AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate(); AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
......
...@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel { ...@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyps", AddInput("Hyps",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) " "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings."); "The indices for hypothesis strings.");
......
...@@ -14,26 +14,8 @@ limitations under the License. */ ...@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Add", "Out = X + Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add, elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,26 +14,8 @@ limitations under the License. */ ...@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "Out = X / Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
ops::ElementwiseDivOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,25 +14,8 @@ limitations under the License. */ ...@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = max(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
ops::ElementwiseMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_max, elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,25 +14,8 @@ limitations under the License. */ ...@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = min(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
ops::ElementwiseMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_min, elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,27 +14,8 @@ limitations under the License. */ ...@@ -14,27 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "Out = X \\odot\\ Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y");
ops::ElementwiseMulOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference { ...@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() final {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), The first input tensor of elementwise op."); AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op."); AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op."); AddOutput("Out", "The output of elementwise op.");
...@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
"for broadcasting Y onto X.") "for broadcasting Y onto X.")
.SetDefault(-1) .SetDefault(-1)
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
comment_ = R"DOC( AddComment(string::Sprintf(R"DOC(
Limited Elementwise {name} Operator. Limited Elementwise %s Operator.
The equation is: The equation is:
$${equation}$$ $$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be $X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$. smaller than or equal to the dimensions of $X$.
...@@ -100,26 +99,13 @@ For example ...@@ -100,26 +99,13 @@ For example
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details) Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$. information. However, the output only shares the LoD information with input $X$.
)DOC"; )DOC",
AddComment(comment_); GetName(), GetEquation()));
} }
protected: protected:
std::string comment_; virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0;
void Replace(std::string* src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string equation) {
Replace(&comment_, "{name}", name);
Replace(&comment_, "{equation}", equation);
}
}; };
class ElementwiseOpGrad : public framework::OperatorWithKernel { class ElementwiseOpGrad : public framework::OperatorWithKernel {
...@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
...@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_pow_op.h" #include "paddle/fluid/operators/elementwise_pow_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class ElementwisePowOpMaker : public ElementwiseOpMaker { class ElementwisePowOpMaker : public ElementwiseOpMaker {
public: protected:
ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker) std::string GetName() const override { return "Pow"; }
: ElementwiseOpMaker(proto, op_checker) { std::string GetEquation() const override { return "Out = X ^ Y"; }
SetComment("Pow", "Out = X ^ Y");
AddComment(comment_);
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -14,25 +14,8 @@ limitations under the License. */ ...@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "Out = X - Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_sub, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
ops::ElementwiseSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_sub, elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel {
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]." "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded."); "X is the input to be expanded.");
......
...@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
layout, library); layout, library);
} }
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker) void FCOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op."); AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
......
...@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel { ...@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel {
class FCOpMaker : public framework::OpProtoAndCheckerMaker { class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FCOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
} // namespace operators } // namespace operators
......
...@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase { ...@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker { class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of feed op"); AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op"); AddOutput("Out", "The output of feed op");
AddAttr<int>("col", "(int) The column of feed"); AddAttr<int>("col", "(int) The column of feed");
......
...@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase { ...@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase {
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker { class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fetch op"); AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op"); AddOutput("Out", "The output of fetch op");
AddAttr<int>("col", "(int) The column of fetch"); AddAttr<int>("col", "(int) The column of fetch");
......
...@@ -31,8 +31,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -31,8 +31,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public: public:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
......
...@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase {
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
......
...@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase { ...@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase {
class FillOpMaker : public framework::OpProtoAndCheckerMaker { class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment(R"DOC(Fill operator AddComment(R"DOC(Fill operator
Fill an tensor with `value` and `shape`. The type of the tensor is specify by Fill an tensor with `value` and `shape`. The type of the tensor is specify by
......
...@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { ...@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillZerosLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fill-zeros-like op."); AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with zeros."); AddOutput("Out", "The variable will be filled up with zeros.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -64,8 +64,7 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,7 @@ class FTRLOp : public framework::OperatorWithKernel {
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FTRLOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", AddInput("Param",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated."); "Input parameter value that has to be updated.");
......
...@@ -67,8 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -67,8 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
class GatherOpMaker : public framework::OpProtoAndCheckerMaker { class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GatherOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op"); AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op"); AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of gather op"); AddOutput("Out", "The output of gather op");
......
...@@ -33,8 +33,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -33,8 +33,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public: public:
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<float>("mean", AddAttr<float>("mean",
"(float, default 0.0) " "(float, default 0.0) "
"mean of random tensor.") "mean of random tensor.")
......
...@@ -70,8 +70,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -70,8 +70,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GaussianRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "Output matrix of gaussian random op"); AddOutput("Out", "Output matrix of gaussian random op");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
......
...@@ -78,8 +78,7 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -78,8 +78,7 @@ class GetPlacesOp : public framework::OperatorBase {
class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker { class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "vector of Place"); AddOutput("Out", "vector of Place");
AddAttr<int>("device_count", "device count").SetDefault(0); AddAttr<int>("device_count", "device count").SetDefault(0);
AddAttr<std::string>("device_type", "device type") AddAttr<std::string>("device_type", "device type")
......
...@@ -89,8 +89,7 @@ class GoOp : public framework::OperatorBase { ...@@ -89,8 +89,7 @@ class GoOp : public framework::OperatorBase {
class GoOpMaker : public framework::OpProtoAndCheckerMaker { class GoOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GoOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kX, AddInput(kX,
"A set of variables, which are required by operators inside the " "A set of variables, which are required by operators inside the "
"block of Go Op.") "block of Go Op.")
......
...@@ -71,8 +71,7 @@ class GRUOp : public framework::OperatorWithKernel { ...@@ -71,8 +71,7 @@ class GRUOp : public framework::OperatorWithKernel {
class GRUOpMaker : public framework::OpProtoAndCheckerMaker { class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GRUOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports " "(LoDTensor) The first input is a LodTensor, which supports "
"variable-time length input sequence. The underlying tensor in " "variable-time length input sequence. The underlying tensor in "
......
...@@ -71,8 +71,7 @@ class GRUUnitOp : public framework::OperatorWithKernel { ...@@ -71,8 +71,7 @@ class GRUUnitOp : public framework::OperatorWithKernel {
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GRUUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"input."); "input.");
......
...@@ -46,8 +46,7 @@ class HingeLossOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,7 @@ class HingeLossOp : public framework::OperatorWithKernel {
template <typename AttrType> template <typename AttrType>
class HingeLossOpMaker : public framework::OpProtoAndCheckerMaker { class HingeLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
HingeLossOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits", AddInput("Logits",
"The input value (Logits) of Hinge loss op." "The input value (Logits) of Hinge loss op."
"Logits is a 2-D tensor with shape [batch_size, 1]."); "Logits is a 2-D tensor with shape [batch_size, 1].");
......
...@@ -45,8 +45,7 @@ class HuberLossOp : public framework::OperatorWithKernel { ...@@ -45,8 +45,7 @@ class HuberLossOp : public framework::OperatorWithKernel {
template <typename AttrType> template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker { class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
HuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input value of huber loss op." "The input value of huber loss op."
"X is a 2-D tensor with shape [batch_size, 1]."); "X is a 2-D tensor with shape [batch_size, 1].");
......
...@@ -54,8 +54,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel { ...@@ -54,8 +54,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker { class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Im2SequenceOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor has NCHW format." "(Tensor) The input tensor has NCHW format."
"N: batch size" "N: batch size"
......
...@@ -47,8 +47,7 @@ class IncrementOp : public framework::OperatorWithKernel { ...@@ -47,8 +47,7 @@ class IncrementOp : public framework::OperatorWithKernel {
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker { class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
IncrementOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input tensor of increment operator"); AddInput("X", "(Tensor) The input tensor of increment operator");
AddOutput("Out", "(Tensor) The output tensor of increment operator."); AddOutput("Out", "(Tensor) The output tensor of increment operator.");
AddAttr<float>("step", AddAttr<float>("step",
......
...@@ -42,8 +42,7 @@ class IOUSimilarityOp : public framework::OperatorWithKernel { ...@@ -42,8 +42,7 @@ class IOUSimilarityOp : public framework::OperatorWithKernel {
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, default LoDTensor<float>) " "(LoDTensor, default LoDTensor<float>) "
"Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, " "Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
......
...@@ -48,8 +48,7 @@ class IsEmptyOp : public framework::OperatorBase { ...@@ -48,8 +48,7 @@ class IsEmptyOp : public framework::OperatorBase {
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker { class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
IsEmptyOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInput, "(Tensor) Tensor which is to be checked."); AddInput(kInput, "(Tensor) Tensor which is to be checked.");
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not."); AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -48,8 +48,7 @@ class L1NormGradOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,7 @@ class L1NormGradOp : public framework::OperatorWithKernel {
class L1NormOpMaker : public framework::OpProtoAndCheckerMaker { class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
L1NormOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of l1_norm op."); AddInput("X", "(Tensor) The input of l1_norm op.");
AddOutput("Out", "(Scalar) The output of l1_norm op."); AddOutput("Out", "(Scalar) The output of l1_norm op.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -47,8 +47,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel { ...@@ -47,8 +47,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker { class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LabelSmoothOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor) The input labels of LabelSmooth operator. This " "(LoDTensor) The input labels of LabelSmooth operator. This "
"input can be batched labels in one-hot encoding or output from " "input can be batched labels in one-hot encoding or output from "
......
...@@ -61,8 +61,7 @@ class LayerNormOp : public framework::OperatorWithKernel { ...@@ -61,8 +61,7 @@ class LayerNormOp : public framework::OperatorWithKernel {
class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The input tensor."); AddInput("X", "(LoDTensor) The input tensor.");
AddInput("Scale", AddInput("Scale",
"(Tensor, optional) Scale is a 1-dimensional tensor of size " "(Tensor, optional) Scale is a 1-dimensional tensor of size "
......
...@@ -19,8 +19,7 @@ namespace operators { ...@@ -19,8 +19,7 @@ namespace operators {
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LinearChainCRFOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Emission", AddInput("Emission",
"(LoDTensor, default LoDTensor<float>) " "(LoDTensor, default LoDTensor<float>) "
"A 2-D LoDTensor with shape [N x D], where N is the size of the " "A 2-D LoDTensor with shape [N x D], where N is the size of the "
......
...@@ -343,8 +343,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -343,8 +343,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable(); AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
ListenAndServ operator ListenAndServ operator
......
...@@ -77,8 +77,7 @@ class LoadCombineOp : public framework::OperatorBase { ...@@ -77,8 +77,7 @@ class LoadCombineOp : public framework::OperatorBase {
class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LoadCombineOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput( AddOutput(
"Out", "Out",
"(vector) The output LoDTensors that will be read from the input file.") "(vector) The output LoDTensors that will be read from the input file.")
......
...@@ -64,8 +64,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -64,8 +64,7 @@ class LoadOp : public framework::OperatorBase {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) The tensor need to be loaded"); AddOutput("Out", "(Tensor) The tensor need to be loaded");
AddAttr<std::string>("file_path", AddAttr<std::string>("file_path",
"(string) " "(string) "
......
...@@ -40,8 +40,7 @@ class LoDArrayLengthOp : public framework::OperatorBase { ...@@ -40,8 +40,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker { class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LoDArrayLengthProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensorArray) The input tensor array."); AddInput("X", "(LoDTensorArray) The input tensor array.");
AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t"); AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -38,8 +38,7 @@ class LoDRankTableOp : public framework::OperatorBase { ...@@ -38,8 +38,7 @@ class LoDRankTableOp : public framework::OperatorBase {
class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker { class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LoDRankTableOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor) input lod tensor, must contain lod information."); "(LoDTensor) input lod tensor, must contain lod information.");
AddOutput("Out", "(LoDRankTable) The rank table of specific level."); AddOutput("Out", "(LoDRankTable) The rank table of specific level.");
......
...@@ -47,8 +47,7 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -47,8 +47,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, LoDTensor) Input variable of LoDResetOp which " "(Tensor, LoDTensor) Input variable of LoDResetOp which "
"could be a Tensor or LoDTensor, where the data of output " "could be a Tensor or LoDTensor, where the data of output "
......
...@@ -105,8 +105,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase { ...@@ -105,8 +105,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker { class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LoDTensorToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", ""); AddInput("X", "");
AddInput("RankTable", ""); AddInput("RankTable", "");
AddOutput("Out", ""); AddOutput("Out", "");
......
...@@ -46,8 +46,7 @@ class LogLossOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,7 @@ class LogLossOp : public framework::OperatorWithKernel {
template <typename AttrType> template <typename AttrType>
class LogLossOpMaker : public framework::OpProtoAndCheckerMaker { class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LogLossOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Predicted", AddInput("Predicted",
"The input value (Predicted) of Log loss op." "The input value (Predicted) of Log loss op."
"Predicted is a 2-D tensor with shape [batch_size, 1]."); "Predicted is a 2-D tensor with shape [batch_size, 1].");
......
...@@ -21,8 +21,7 @@ namespace operators { ...@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment> template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BinaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment; OpComment comment;
AddInput("X", AddInput("X",
string::Sprintf("(LoDTensor) Left hand operand of %s operator", string::Sprintf("(LoDTensor) Left hand operand of %s operator",
...@@ -45,8 +44,7 @@ Each element of Out is calculated by %s ...@@ -45,8 +44,7 @@ Each element of Out is calculated by %s
template <typename OpComment> template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
UnaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment; OpComment comment;
AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator", AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
comment.type)); comment.type));
......
...@@ -105,8 +105,7 @@ class LookupSparseTableOp : public framework::OperatorBase { ...@@ -105,8 +105,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker { class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LookupSparseTableOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W", AddInput("W",
"(SelectedRows) The input represents embedding table, " "(SelectedRows) The input represents embedding table, "
"which is a learnable parameter."); "which is a learnable parameter.");
......
...@@ -58,8 +58,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -58,8 +58,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W", AddInput("W",
"(Tensor) The input represents embedding tensors, " "(Tensor) The input represents embedding tensors, "
"which is a learnable parameter."); "which is a learnable parameter.");
......
...@@ -169,8 +169,7 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -169,8 +169,7 @@ class LRNOp : public framework::OperatorWithKernel {
template <typename T> template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker { class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LRNOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input of LRN operator. " "(Tensor) The input of LRN operator. "
"It must be a 4D tenor with NCHW format."); "It must be a 4D tenor with NCHW format.");
......
...@@ -103,8 +103,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -103,8 +103,7 @@ class LSTMOp : public framework::OperatorWithKernel {
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LSTMOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support " "(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in " "variable-time length input sequence. The underlying tensor in "
......
...@@ -48,8 +48,7 @@ class LstmUnitOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LstmUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"Lstm unit only applies non-linear activations, please make sure" "Lstm unit only applies non-linear activations, please make sure"
"that linear tranformation has already been applied to `X`. " "that linear tranformation has already been applied to `X`. "
......
...@@ -120,8 +120,7 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -120,8 +120,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LSTMPOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LoDTensor) the input for sequence data, which supports " "(LoDTensor) the input for sequence data, which supports "
"variable-time length input sequence. The underlying tensor in " "variable-time length input sequence. The underlying tensor in "
......
...@@ -42,8 +42,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel { ...@@ -42,8 +42,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
template <typename T> template <typename T>
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MarginRankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1", AddInput("X1",
"(2-D tensor with shape [batch_size x 1]) The score for " "(2-D tensor with shape [batch_size x 1]) The score for "
"one item X1 to be ranked, from pairwise ranking model."); "one item X1 to be ranked, from pairwise ranking model.");
......
...@@ -159,8 +159,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -159,8 +159,7 @@ class MatMulOp : public framework::OperatorWithKernel {
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of MatMul op"); AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op"); AddInput("Y", "The second input of MatMul op");
AddOutput("Out", "The output of MatMul op"); AddOutput("Out", "The output of MatMul op");
......
...@@ -41,8 +41,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase { ...@@ -41,8 +41,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker { class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MaxSeqenceLenOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("RankTable", "The lod_rank_table."); AddInput("RankTable", "The lod_rank_table.");
AddOutput("Out", "The max sequence length."); AddOutput("Out", "The max sequence length.");
AddComment( AddComment(
......
...@@ -22,8 +22,7 @@ using framework::Tensor; ...@@ -22,8 +22,7 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MaxOutOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of maxout operator. " "(Tensor) The input tensor of maxout operator. "
......
...@@ -32,8 +32,7 @@ class MeanOp : public framework::OperatorWithKernel { ...@@ -32,8 +32,7 @@ class MeanOp : public framework::OperatorWithKernel {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker { class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MeanOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op"); AddOutput("Out", "The output of mean op");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -121,8 +121,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -121,8 +121,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MergeLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input LoDTensor, contains complete lod information to " "The input LoDTensor, contains complete lod information to "
"construct the output"); "construct the output");
......
...@@ -253,8 +253,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -253,8 +253,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MineHardExamplesOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"ClsLoss", "ClsLoss",
"(Tensor, default Tensor<float>), The classification loss with shape " "(Tensor, default Tensor<float>), The classification loss with shape "
......
...@@ -48,8 +48,7 @@ class MinusOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,7 @@ class MinusOp : public framework::OperatorWithKernel {
class MinusOpMaker : public framework::OpProtoAndCheckerMaker { class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left tensor of minus operator."); AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator."); AddInput("Y", "The right tensor of minus operator.");
AddOutput("Out", "The output tensor of minus operator."); AddOutput("Out", "The output tensor of minus operator.");
......
...@@ -39,8 +39,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { ...@@ -39,8 +39,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker { class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ModifiedHuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input tensor of modified huber loss op. " "The input tensor of modified huber loss op. "
"X is 2-D tensor with shape [batch_size, 1]."); "X is 2-D tensor with shape [batch_size, 1].");
......
...@@ -62,8 +62,7 @@ class MomentumOp : public framework::OperatorWithKernel { ...@@ -62,8 +62,7 @@ class MomentumOp : public framework::OperatorWithKernel {
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MomentumOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", AddInput("Param",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input parameter that has to be updated"); "Input parameter that has to be updated");
......
...@@ -96,8 +96,7 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -96,8 +96,7 @@ class MulOp : public framework::OperatorWithKernel {
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), The first input tensor of mul op."); AddInput("X", "(Tensor), The first input tensor of mul op.");
AddInput("Y", "(Tensor), The second input tensor of mul op."); AddInput("Y", "(Tensor), The second input tensor of mul op.");
AddOutput("Out", "(Tensor), The output tensor of mul op."); AddOutput("Out", "(Tensor), The output tensor of mul op.");
......
...@@ -309,8 +309,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -309,8 +309,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("BBoxes", AddInput("BBoxes",
"(Tensor) A 3-D Tensor with shape [N, M, 4] represents the " "(Tensor) A 3-D Tensor with shape [N, M, 4] represents the "
"predicted locations of M bounding bboxes, N is the batch size. " "predicted locations of M bounding bboxes, N is the batch size. "
......
...@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MultiplexOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", "The index tensor of multiplex operator."); AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.") AddInput("X", "The candidate tensors of multiplex operator.")
.AsDuplicable(); .AsDuplicable();
......
...@@ -76,8 +76,7 @@ class NCCLInitOpShapeInference : public framework::InferShapeBase { ...@@ -76,8 +76,7 @@ class NCCLInitOpShapeInference : public framework::InferShapeBase {
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kParallelScopes, "The working place of parallel do."); AddInput(kParallelScopes, "The working place of parallel do.");
AddOutput("Communicator", AddOutput("Communicator",
"Create Communicator for communicating between gpus"); "Create Communicator for communicating between gpus");
...@@ -118,8 +117,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { ...@@ -118,8 +117,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
// AllReduceOp // AllReduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op"); AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op"); AddOutput("Out", "The output of AllReduce op");
...@@ -165,8 +163,7 @@ class NCCLReduceOp : public framework::OperatorWithKernel { ...@@ -165,8 +163,7 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
// ReduceOp // ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Reduce op"); AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op"); AddOutput("Out", "The output of Reduce op");
...@@ -214,8 +211,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel { ...@@ -214,8 +211,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
// BcastOp // BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLBcastOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op"); AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Bcast"); AddOutput("Out", "The output of Bcast");
......
...@@ -75,8 +75,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -75,8 +75,7 @@ class NCEOp : public framework::OperatorWithKernel {
class NCEOpMaker : public framework::OpProtoAndCheckerMaker { class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCEOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim]."); AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
AddInput( AddInput(
"Label", "Label",
......
...@@ -19,8 +19,7 @@ namespace operators { ...@@ -19,8 +19,7 @@ namespace operators {
template <typename AttrType> template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker { class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NormOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of norm operator. " "(Tensor) The input tensor of norm operator. "
......
...@@ -46,8 +46,7 @@ class OneHotOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,7 @@ class OneHotOp : public framework::OperatorWithKernel {
class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. " "(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
"The last dimension of X should be 1. Each value of X is an index " "The last dimension of X should be 1. Each value of X is an index "
......
...@@ -48,8 +48,7 @@ class PadOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,7 @@ class PadOp : public framework::OperatorWithKernel {
class PadOpMaker : public framework::OpProtoAndCheckerMaker { class PadOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PadOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input of pad op. " "The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)"); "The input should be a k-D tensor(k > 0 and k < 7)");
......
...@@ -196,8 +196,7 @@ class ParallelDoOp : public framework::OperatorBase { ...@@ -196,8 +196,7 @@ class ParallelDoOp : public framework::OperatorBase {
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ParallelDoOpProtoMaker(OpProto *proto, framework::OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInputs, "").AsDuplicable(); AddInput(kInputs, "").AsDuplicable();
AddInput(kParameters, "").AsDuplicable(); AddInput(kParameters, "").AsDuplicable();
AddInput(kPlaces, ""); AddInput(kPlaces, "");
......
...@@ -135,8 +135,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( ...@@ -135,8 +135,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
library_); library_);
} }
Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Pool2dOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of pooling operator. " "(Tensor) The input tensor of pooling operator. "
...@@ -229,8 +228,7 @@ Example: ...@@ -229,8 +228,7 @@ Example:
)DOC"); )DOC");
} }
Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Pool3dOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of pooling operator. " "(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is " "The format of input tensor is NCDHW, where N is batch size, C is "
......
...@@ -50,12 +50,12 @@ class PoolOpGrad : public framework::OperatorWithKernel { ...@@ -50,12 +50,12 @@ class PoolOpGrad : public framework::OperatorWithKernel {
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -100,8 +100,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { ...@@ -100,8 +100,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MaxPool2dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of pooling operator. " "(Tensor) The input tensor of pooling operator. "
...@@ -177,8 +176,7 @@ Example: ...@@ -177,8 +176,7 @@ Example:
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MaxPool3dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of pooling operator. " "(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is " "The format of input tensor is NCDHW, where N is batch size, C is "
......
...@@ -95,8 +95,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { ...@@ -95,8 +95,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker { class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PositiveNegativePairOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Score", AddInput("Score",
"(Tensor, float) Model Score on an item (with " "(Tensor, float) Model Score on an item (with "
"respect to QueryID). It's a 2-D tensor with shape [batch_size, " "respect to QueryID). It's a 2-D tensor with shape [batch_size, "
......
...@@ -90,8 +90,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -90,8 +90,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker { class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PrecisionRecallOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("MaxProbs", AddInput("MaxProbs",
"(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, " "(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the max probability " "where N is the batch size. Each row contains the max probability "
......
...@@ -64,8 +64,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -64,8 +64,7 @@ class PrefetchOp : public framework::OperatorBase {
class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PrefetchOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient", AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be" "(RPCClient) The RPC client object which will be"
......
...@@ -38,8 +38,7 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -38,8 +38,7 @@ class PReluOp : public framework::OperatorWithKernel {
class PReluOpMaker : public framework::OpProtoAndCheckerMaker { class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of prelu operator."); AddInput("X", "The input tensor of prelu operator.");
AddInput("Alpha", "The alpha weight of prelu operator."); AddInput("Alpha", "The alpha weight of prelu operator.");
AddOutput("Out", "The output tensor of prelu operator."); AddOutput("Out", "The output tensor of prelu operator.");
......
...@@ -209,8 +209,7 @@ class TensorPrintOp : public framework::OperatorBase { ...@@ -209,8 +209,7 @@ class TensorPrintOp : public framework::OperatorBase {
class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker { class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In", "Input tensor to be displayed."); AddInput("In", "Input tensor to be displayed.");
AddAttr<int>("first_n", "Only log `first_n` number of times."); AddAttr<int>("first_n", "Only log `first_n` number of times.");
AddAttr<std::string>("message", "A string message to print as a prefix."); AddAttr<std::string>("message", "A string message to print as a prefix.");
......
...@@ -79,8 +79,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -79,8 +79,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor, default Tensor<float>), " "(Tensor, default Tensor<float>), "
"the input feature data of PriorBoxOp, The layout is NCHW."); "the input feature data of PriorBoxOp, The layout is NCHW.");
......
...@@ -66,8 +66,7 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ProximalAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", AddInput("Param",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input parameter that has to be updated."); "Input parameter that has to be updated.");
......
...@@ -54,8 +54,7 @@ class ProximalGDOp : public framework::OperatorWithKernel { ...@@ -54,8 +54,7 @@ class ProximalGDOp : public framework::OperatorWithKernel {
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ProximalGDOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", AddInput("Param",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated."); "Input parameter value that has to be updated.");
......
...@@ -46,8 +46,7 @@ class RankLossOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,7 @@ class RankLossOp : public framework::OperatorWithKernel {
class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label", AddInput("Label",
"(2-D Tensor with shape [batch_size x 1]) " "(2-D Tensor with shape [batch_size x 1]) "
"The label indicating A ranked higher than B or not."); "The label indicating A ranked higher than B or not.");
......
...@@ -79,8 +79,7 @@ class ReadOp : public framework::OperatorBase { ...@@ -79,8 +79,7 @@ class ReadOp : public framework::OperatorBase {
class ReadOpMaker : public framework::OpProtoAndCheckerMaker { class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput("Reader", "(ReaderHolder) The executed reader."); AddInput("Reader", "(ReaderHolder) The executed reader.");
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -52,9 +52,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -52,9 +52,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}; };
class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
public: protected:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("batch_size", AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.") "How many instances the batch reader yields each time.")
.GreaterThan(0); .GreaterThan(0);
......
...@@ -113,9 +113,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -113,9 +113,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
}; };
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
public: protected:
CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddComment(R"DOC( AddComment(R"DOC(
CreateDoubleBufferReader Operator CreateDoubleBufferReader Operator
......
...@@ -65,9 +65,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -65,9 +65,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
}; };
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase { class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
public: protected:
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0); AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
CreateMultiPassReader Operator CreateMultiPassReader Operator
......
...@@ -84,9 +84,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -84,9 +84,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
}; };
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase { class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
public: protected:
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<float>("min", "The lower bound of reader's uniform distribution."); AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution."); AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -76,9 +76,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -76,9 +76,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
}; };
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase { class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
public: protected:
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::string>("filename", "The filename of record io reader"); AddAttr<std::string>("filename", "The filename of record io reader");
AddComment(R"DOC( AddComment(R"DOC(
CreateRecordIOReader Operator CreateRecordIOReader Operator
......
...@@ -92,9 +92,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -92,9 +92,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}; };
class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase { class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
public: protected:
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0); AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
CreateShuffleReader Operator CreateShuffleReader Operator
......
...@@ -53,9 +53,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase { ...@@ -53,9 +53,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
}; };
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public: protected:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddComment(R"DOC( AddComment(R"DOC(
CreateThreadedReader Operator CreateThreadedReader Operator
......
...@@ -185,9 +185,8 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -185,9 +185,8 @@ class OpenFilesOp : public framework::OperatorBase {
}; };
class OpenFilesOpMaker : public FileReaderMakerBase { class OpenFilesOpMaker : public FileReaderMakerBase {
public: protected:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read."); AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.") AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0); .GreaterThan(0);
......
...@@ -53,10 +53,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( ...@@ -53,10 +53,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
return std::unique_ptr<framework::ReaderBase>(reader); return std::unique_ptr<framework::ReaderBase>(reader);
} }
FileReaderMakerBase::FileReaderMakerBase( void FileReaderMakerBase::Make() {
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable(); AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable();
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes."); AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
...@@ -68,6 +65,7 @@ FileReaderMakerBase::FileReaderMakerBase( ...@@ -68,6 +65,7 @@ FileReaderMakerBase::FileReaderMakerBase(
"It means the reader will generate two data each time," "It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively."); "whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data."); AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
Apply();
} }
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
...@@ -127,13 +125,11 @@ void DecoratedReaderInferVarType::operator()( ...@@ -127,13 +125,11 @@ void DecoratedReaderInferVarType::operator()(
out_reader->SetDataTypes(in_reader->GetDataTypes()); out_reader->SetDataTypes(in_reader->GetDataTypes());
} }
DecoratedReaderMakerBase::DecoratedReaderMakerBase( void DecoratedReaderMakerBase::Make() {
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput("UnderlyingReader", AddInput("UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a batch reader."); "(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader."); AddOutput("Out", "(ReaderHolder) The created batch reader.");
Apply();
} }
} // namespace reader } // namespace reader
......
...@@ -47,7 +47,10 @@ extern std::vector<framework::DDim> RestoreShapes( ...@@ -47,7 +47,10 @@ extern std::vector<framework::DDim> RestoreShapes(
class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker { class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public: public:
FileReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker); void Make() final;
protected:
virtual void Apply() = 0;
}; };
class FileReaderInferShape : public framework::InferShapeBase { class FileReaderInferShape : public framework::InferShapeBase {
...@@ -76,7 +79,10 @@ class DecoratedReaderInferVarType : public framework::VarTypeInference { ...@@ -76,7 +79,10 @@ class DecoratedReaderInferVarType : public framework::VarTypeInference {
class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker { class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public: public:
DecoratedReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker); void Make() final;
protected:
virtual void Apply() = 0;
}; };
} // namespace reader } // namespace reader
......
...@@ -508,8 +508,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -508,8 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker { class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RecurrentOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInputs, "rnn inputs").AsDuplicable(); AddInput(kInputs, "rnn inputs").AsDuplicable();
AddInput(kInitialStates, "rnn initial states").AsDuplicable(); AddInput(kInitialStates, "rnn initial states").AsDuplicable();
AddInput(kParameters, AddInput(kParameters,
......
...@@ -53,8 +53,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -53,8 +53,7 @@ class RecvOp : public framework::OperatorBase {
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Recv operator Recv operator
......
...@@ -90,8 +90,7 @@ class ReduceGradOp : public framework::OperatorWithKernel { ...@@ -90,8 +90,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() final {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor. Tensors with rank at most 6 are " "(Tensor) The input tensor. Tensors with rank at most 6 are "
"supported."); "supported.");
...@@ -111,78 +110,20 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -111,78 +110,20 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) " "(bool, default false) "
"If true, output a scalar reduced along all dimensions.") "If true, output a scalar reduced along all dimensions.")
.SetDefault(false); .SetDefault(false);
comment_ = R"DOC( AddComment(string::Sprintf(R"DOC(
{ReduceOp} Operator. %s Operator.
This operator computes the {reduce} of input tensor along the given dimension. This operator computes the %s of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true. The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar. If reduce_all is true, just reduce along all dimensions and output a scalar.
)DOC"; )DOC",
AddComment(comment_); GetOpType(), GetName()));
} }
protected: protected:
std::string comment_; virtual std::string GetName() const = 0;
virtual std::string GetOpType() const = 0;
void Replace(std::string *src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string op) {
Replace(&comment_, "{ReduceOp}", name);
Replace(&comment_, "{reduce}", op);
}
};
class ReduceSumOpMaker : public ReduceOpMaker {
public:
ReduceSumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceSum", "sum");
AddComment(comment_);
}
};
class ReduceMeanOpMaker : public ReduceOpMaker {
public:
ReduceMeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMean", "mean");
AddComment(comment_);
}
};
class ReduceMaxOpMaker : public ReduceOpMaker {
public:
ReduceMaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMax", "max");
AddComment(comment_);
}
};
class ReduceMinOpMaker : public ReduceOpMaker {
public:
ReduceMinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMin", "min");
AddComment(comment_);
}
};
class ReduceProdOpMaker : public ReduceOpMaker {
public:
ReduceProdOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceProd", "production");
AddComment(comment_);
}
}; };
} // namespace operators } // namespace operators
...@@ -190,25 +131,21 @@ class ReduceProdOpMaker : public ReduceOpMaker { ...@@ -190,25 +131,21 @@ class ReduceProdOpMaker : public ReduceOpMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, #define REGISTER_REDUCE_OP(op_name) \
paddle::framework::DefaultGradOpDescMaker<true>); class __##op_name##Maker__ : public ops::ReduceOpMaker { \
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp); protected: \
virtual std::string GetName() const { return #op_name; } \
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker, virtual std::string GetOpType() const { return "Reduce " #op_name; } \
paddle::framework::DefaultGradOpDescMaker<true>); }; \
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp); REGISTER_OPERATOR(reduce_##op_name, ops::ReduceOp, __##op_name##Maker__, \
paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, REGISTER_OPERATOR(reduce_##op_name##_grad, ops::ReduceGradOp)
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp); REGISTER_REDUCE_OP(sum);
REGISTER_REDUCE_OP(mean);
REGISTER_OPERATOR(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, REGISTER_REDUCE_OP(max);
paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_REDUCE_OP(min);
REGISTER_OPERATOR(reduce_min_grad, ops::ReduceGradOp); REGISTER_REDUCE_OP(prod);
REGISTER_OPERATOR(reduce_prod, ops::ReduceOp, ops::ReduceProdOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp);
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \ #define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL(reduce_type, \ REGISTER_OP_CPU_KERNEL(reduce_type, \
......
...@@ -23,9 +23,7 @@ namespace operators { ...@@ -23,9 +23,7 @@ namespace operators {
class ReorderLoDTensorByRankTableOpProtoMaker class ReorderLoDTensorByRankTableOpProtoMaker
: public framework::OpProtoAndCheckerMaker { : public framework::OpProtoAndCheckerMaker {
public: public:
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto, void Make() override {
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor), the input lod tensor to be reordered according to " "(LoDTensor), the input lod tensor to be reordered according to "
"Input(RankTable)."); "Input(RankTable).");
......
...@@ -22,8 +22,7 @@ namespace operators { ...@@ -22,8 +22,7 @@ namespace operators {
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor). The input tensor of reshape operator."); AddInput("X", "(Tensor). The input tensor of reshape operator.");
AddInput("Shape", AddInput("Shape",
"(Tensor<int32>, optional). If provided, reshape according to " "(Tensor<int32>, optional). If provided, reshape according to "
......
...@@ -63,8 +63,7 @@ class RmspropOp : public framework::OperatorWithKernel { ...@@ -63,8 +63,7 @@ class RmspropOp : public framework::OperatorWithKernel {
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RmspropOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", AddInput("Param",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated."); "Input parameter value that has to be updated.");
......
...@@ -59,8 +59,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { ...@@ -59,8 +59,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker { class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RNNMemoryHelperOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", ""); AddInput("X", "");
AddOutput("Out", ""); AddOutput("Out", "");
AddAttr<int>("dtype", AddAttr<int>("dtype",
...@@ -117,8 +116,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { ...@@ -117,8 +116,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
class RNNMemoryHelperGradOpInfoMaker class RNNMemoryHelperGradOpInfoMaker
: public framework::OpProtoAndCheckerMaker { : public framework::OpProtoAndCheckerMaker {
public: public:
RNNMemoryHelperGradOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(framework::GradVarName("Out"), ""); AddInput(framework::GradVarName("Out"), "");
AddInput("X", ""); AddInput("X", "");
AddInput("Out", ""); AddInput("Out", "");
......
...@@ -98,8 +98,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { ...@@ -98,8 +98,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ROIPoolOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor), " "(Tensor), "
"the input of ROIPoolOp. " "the input of ROIPoolOp. "
......
...@@ -76,8 +76,7 @@ class RowConvGradOp : public framework::OperatorWithKernel { ...@@ -76,8 +76,7 @@ class RowConvGradOp : public framework::OperatorWithKernel {
class RowConvOpMaker : public framework::OpProtoAndCheckerMaker { class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
RowConvOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor), the input(X) is a LodTensor, which supports " "(LoDTensor), the input(X) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor " "variable time-length input sequences. The underlying tensor "
......
...@@ -109,8 +109,7 @@ class SaveCombineOp : public framework::OperatorBase { ...@@ -109,8 +109,7 @@ class SaveCombineOp : public framework::OperatorBase {
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SaveCombineOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(vector) Input LoDTensors that need to be saved together in a file.") "(vector) Input LoDTensors that need to be saved together in a file.")
......
...@@ -117,8 +117,7 @@ class SaveOp : public framework::OperatorBase { ...@@ -117,8 +117,7 @@ class SaveOp : public framework::OperatorBase {
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor ) Input tensor to be saved"); AddInput("X", "(Tensor ) Input tensor to be saved");
AddComment(R"DOC( AddComment(R"DOC(
Save operator Save operator
......
...@@ -37,8 +37,7 @@ class ScaleOp : public framework::OperatorWithKernel { ...@@ -37,8 +37,7 @@ class ScaleOp : public framework::OperatorWithKernel {
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of scale operator."); AddInput("X", "(Tensor) Input tensor of scale operator.");
AddOutput("Out", "(Tensor) Output tensor of scale operator."); AddOutput("Out", "(Tensor) Output tensor of scale operator.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -78,8 +78,7 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of scatter op"); AddInput("X", "The source input of scatter op");
AddInput("Ids", "The index input of scatter op where X will be updated"); AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of updates op"); AddInput("Updates", "The updated value of updates op");
......
...@@ -380,8 +380,7 @@ class SelectOp : public framework::OperatorBase { ...@@ -380,8 +380,7 @@ class SelectOp : public framework::OperatorBase {
class SelectOpMaker : public framework::OpProtoAndCheckerMaker { class SelectOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SelectOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kX, AddInput(kX,
"A set of variables, which are required by operators inside the " "A set of variables, which are required by operators inside the "
"cases of Select Op") "cases of Select Op")
......
...@@ -57,8 +57,7 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -57,8 +57,7 @@ class SendBarrierOp : public framework::OperatorBase {
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SendBarrierOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("RPCClient", AddOutput("RPCClient",
"(RPCClient) The RPC client object which is" "(RPCClient) The RPC client object which is"
"initialized at most once."); "initialized at most once.");
......
...@@ -92,8 +92,7 @@ class SendOp : public framework::OperatorBase { ...@@ -92,8 +92,7 @@ class SendOp : public framework::OperatorBase {
class SendOpMaker : public framework::OpProtoAndCheckerMaker { class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SendOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server") AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable(); .AsDuplicable();
......
...@@ -66,8 +66,7 @@ class SendVarsOp : public framework::OperatorBase { ...@@ -66,8 +66,7 @@ class SendVarsOp : public framework::OperatorBase {
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker { class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SendVarsOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent") AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient", AddOutput("RPCClient",
......
...@@ -43,8 +43,7 @@ class SequenceConcatOp : public framework::OperatorWithKernel { ...@@ -43,8 +43,7 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceConcatOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LodTensorArray) Input is a vector of LoDTensor, " "(LodTensorArray) Input is a vector of LoDTensor, "
"each of which is a variable-length sequence or nested sequence.") "each of which is a variable-length sequence or nested sequence.")
......
...@@ -102,8 +102,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { ...@@ -102,8 +102,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceConvOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(LoDTensor) the input(X) is a LodTensor, which supports " "(LoDTensor) the input(X) is a LodTensor, which supports "
......
...@@ -37,8 +37,7 @@ class SequenceEraseOp : public framework::OperatorWithKernel { ...@@ -37,8 +37,7 @@ class SequenceEraseOp : public framework::OperatorWithKernel {
class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(2-D LoDTensor with the 2nd dim. equal to 1) " "(2-D LoDTensor with the 2nd dim. equal to 1) "
"Input LoDTensor of SequenceEraseOp."); "Input LoDTensor of SequenceEraseOp.");
......
...@@ -94,8 +94,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ...@@ -94,8 +94,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod " "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"level is at most 1."); "level is at most 1.");
......
...@@ -38,8 +38,7 @@ class SequencePoolOp : public framework::OperatorWithKernel { ...@@ -38,8 +38,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequencePoolOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp"); AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
AddOutput("Out", AddOutput("Out",
"(Tensor) The output of SequencePoolOp does not contain LoD " "(Tensor) The output of SequencePoolOp does not contain LoD "
......
...@@ -42,8 +42,7 @@ class SequenceReshapeOp : public framework::OperatorWithKernel { ...@@ -42,8 +42,7 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with shape " "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with shape "
"being [N, M]."); "being [N, M].");
......
...@@ -79,8 +79,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { ...@@ -79,8 +79,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceSliceOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor), " "(LoDTensor), "
"the input of SequenceSliceOp."); "the input of SequenceSliceOp.");
......
...@@ -57,8 +57,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SequenceSoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension " "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
"of length 1."); "of length 1.");
......
...@@ -68,8 +68,7 @@ class SGDOpInferVarType : public framework::VarTypeInference { ...@@ -68,8 +68,7 @@ class SGDOpInferVarType : public framework::VarTypeInference {
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor or SelectedRows) Input parameter"); AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of SGD"); AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
......
...@@ -69,8 +69,7 @@ class ShrinkRNNMemoryOp : public ArrayOp { ...@@ -69,8 +69,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ShrinkRNNMemoryOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The RNN step memory to be shrinked."); AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN."); AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
AddInput("I", AddInput("I",
......
...@@ -86,9 +86,7 @@ class SigmoidCrossEntropyWithLogitsGradOp ...@@ -86,9 +86,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
class SigmoidCrossEntropyWithLogitsOpMaker class SigmoidCrossEntropyWithLogitsOpMaker
: public framework::OpProtoAndCheckerMaker { : public framework::OpProtoAndCheckerMaker {
public: public:
SigmoidCrossEntropyWithLogitsOpMaker(OpProto* proto, void Make() override {
OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, " "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. " "where N is the batch size and D is the number of classes. "
......
...@@ -34,8 +34,7 @@ class SignOp : public framework::OperatorWithKernel { ...@@ -34,8 +34,7 @@ class SignOp : public framework::OperatorWithKernel {
template <typename AttrType> template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker { class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SignOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of sign operator."); AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator."); AddOutput("Out", "(Tensor) Output tensor of sign operator.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -46,8 +46,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>) A tensor with rank at least 2. " "(Tensor, default Tensor<float>) A tensor with rank at least 2. "
"The input value of smooth l1 loss op with shape " "The input value of smooth l1 loss op with shape "
......
...@@ -77,8 +77,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -77,8 +77,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions]."); "2-D with shape [batch_size, input_feature_dimensions].");
......
...@@ -20,8 +20,7 @@ namespace operators { ...@@ -20,8 +20,7 @@ namespace operators {
class SoftmaxWithCrossEntropyOpMaker class SoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker { : public framework::OpProtoAndCheckerMaker {
public: public:
SoftmaxWithCrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits", AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities " "(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, " "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
......
...@@ -64,8 +64,7 @@ class SplitByrefOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,7 @@ class SplitByrefOp : public framework::OperatorWithKernel {
class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker { class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SplitByrefOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of the split operator."); AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.") AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable(); .AsDuplicable();
......
...@@ -19,8 +19,7 @@ namespace operators { ...@@ -19,8 +19,7 @@ namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker { class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SplitIdsOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddOutput("Out", "(LoDTensor) The outputs of the input Ids.") AddOutput("Out", "(LoDTensor) The outputs of the input Ids.")
.AsDuplicable(); .AsDuplicable();
......
...@@ -125,8 +125,7 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -125,8 +125,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SplitLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input LoDTensor"); AddInput("X", "The input LoDTensor");
AddInput("Mask", "A bool column vector which mask the input"); AddInput("Mask", "A bool column vector which mask the input");
AddOutput("OutTrue", "True branch of input LoDTensor"); AddOutput("OutTrue", "True branch of input LoDTensor");
......
...@@ -70,8 +70,7 @@ class SplitOp : public framework::OperatorWithKernel { ...@@ -70,8 +70,7 @@ class SplitOp : public framework::OperatorWithKernel {
class SplitOpMaker : public framework::OpProtoAndCheckerMaker { class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SplitOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of the split operator."); AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.") AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable(); .AsDuplicable();
......
...@@ -19,8 +19,7 @@ namespace operators { ...@@ -19,8 +19,7 @@ namespace operators {
class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input SelectedRows."); AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable(); AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int>>("height_sections",
......
...@@ -20,8 +20,7 @@ namespace operators { ...@@ -20,8 +20,7 @@ namespace operators {
class SppOpMaker : public framework::OpProtoAndCheckerMaker { class SppOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SppOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of spp operator. " "(Tensor) The input tensor of spp operator. "
......
...@@ -56,8 +56,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker { class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SquaredL2DistanceOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input of SquaredL2DistanceOp."); AddInput("X", "(Tensor) Input of SquaredL2DistanceOp.");
AddInput("Y", "(Tensor) Target of SquaredL2DistanceOp."); AddInput("Y", "(Tensor) Target of SquaredL2DistanceOp.");
AddOutput("sub_result", AddOutput("sub_result",
......
...@@ -48,8 +48,7 @@ class SquaredL2NormGradOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,7 @@ class SquaredL2NormGradOp : public framework::OperatorWithKernel {
class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker { class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SquaredL2NormOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of squared_l2_norm op."); AddInput("X", "(Tensor) The input of squared_l2_norm op.");
AddOutput("Out", "(Scalar) The output of squared_l2_norm op."); AddOutput("Out", "(Scalar) The output of squared_l2_norm op.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -112,8 +112,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -112,8 +112,7 @@ class SumOp : public framework::OperatorWithKernel {
class SumOpMaker : public framework::OpProtoAndCheckerMaker { class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.") AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator."); AddOutput("Out", "(Tensor) The output tensor of sum operator.");
......
...@@ -65,8 +65,7 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. " "(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. "
"Some elements in X will be assigned to Out based on the " "Some elements in X will be assigned to Out based on the "
......
...@@ -57,8 +57,7 @@ class WriteToArrayOp : public ArrayOp { ...@@ -57,8 +57,7 @@ class WriteToArrayOp : public ArrayOp {
class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker { class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
WriteToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the tensor will be written to tensor array"); AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
AddInput( AddInput(
"I", "I",
...@@ -148,8 +147,7 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -148,8 +147,7 @@ class ReadFromArrayOp : public ArrayOp {
class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker { class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ReadFromArrayProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(TensorArray) the array will be read from."); AddInput("X", "(TensorArray) the array will be read from.");
AddInput("I", AddInput("I",
"(Tensor) the subscript index in tensor array. The number of " "(Tensor) the subscript index in tensor array. The number of "
......
...@@ -48,8 +48,7 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,7 @@ class TopkOp : public framework::OperatorWithKernel {
class TopkOpMaker : public framework::OpProtoAndCheckerMaker { class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
TopkOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of Topk op"); AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
......
...@@ -56,8 +56,7 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,7 @@ class TransposeOp : public framework::OperatorWithKernel {
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
TransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor, tensors with rank up to 6 are supported."); "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
......
...@@ -33,8 +33,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -33,8 +33,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public: public:
UniformRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: BatchSizeLikeOpMaker(proto, op_checker) {
AddComment(R"DOC( AddComment(R"DOC(
Uniform random operator Uniform random operator
......
...@@ -85,8 +85,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -85,8 +85,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
UniformRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) The output tensor of uniform random op"); AddOutput("Out", "(Tensor) The output tensor of uniform random op");
AddComment(R"DOC( AddComment(R"DOC(
Uniform random operator. Uniform random operator.
......
...@@ -20,8 +20,7 @@ namespace operators { ...@@ -20,8 +20,7 @@ namespace operators {
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Unpool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of unpool operator. " "(Tensor) The input tensor of unpool operator. "
......
...@@ -53,8 +53,7 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker { class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits", AddInput("Logits",
"(LodTensor, default: LoDTensor<float>), the unscaled " "(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D " "probabilities of variable-length sequences, which is a 2-D "
......
...@@ -68,8 +68,7 @@ class WhileOp : public framework::OperatorBase { ...@@ -68,8 +68,7 @@ class WhileOp : public framework::OperatorBase {
class WhileOpMaker : public framework::OpProtoAndCheckerMaker { class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kX, AddInput(kX,
"A set of variables, which are required by operators inside the " "A set of variables, which are required by operators inside the "
"block of While Op.") "block of While Op.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册