Commit 0e78cb69 authored by Yu Yang

Clean OpProtoAndCheckerMaker

Do not use ctor

* Reduce lines of code.
* We can now use virtual functions for the Maker.
* The implementation does not care what the maker holds, so it is easier to
refactor later.
Parent 2a22da6c
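For orientation, here is a minimal sketch of the pattern this commit moves to (illustrative only: the `FooOpMakerOld`/`FooOpMakerNew` names are made up, includes are omitted, and this is not code from the change itself). Makers stop forwarding `proto`/`op_checker` through a constructor and instead override the new virtual `Make()`, while the registrar injects the pointers via `SetProto`/`SetChecker` and then calls `Make()` and `Validate()`:

```cpp
// Old style: every maker forwarded proto/checker through its constructor.
class FooOpMakerOld : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  FooOpMakerOld(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "input of foo op");
    AddOutput("Out", "output of foo op");
  }
};

// New style: the maker only overrides the virtual Make(); the registrar
// owns the proto/checker pointers and drives the calls.
class FooOpMakerNew : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "input of foo op");
    AddOutput("Out", "output of foo op");
  }
};

// Registrar side, mirroring the OpInfoFiller change in this commit:
//   FooOpMakerNew maker;
//   maker.SetProto(info->proto_);
//   maker.SetChecker(info->checker_);
//   maker.Make();
//   maker.Validate();
```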
......@@ -32,8 +32,7 @@ struct AddFunctor {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
......
......@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
T maker;
maker.SetProto(info->proto_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
......
......@@ -23,20 +23,21 @@ namespace framework {
// this class not only makes the proto but also initializes attribute checkers.
class OpProtoAndCheckerMaker {
public:
using OpProto = proto::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual void Make() = 0;
virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void SetProto(proto::OpProto* proto) { proto_ = proto; }
void SetChecker(OpAttrChecker* attr_checker) { op_checker_ = attr_checker; }
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
proto::OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
......@@ -76,16 +77,9 @@ class OpProtoAndCheckerMaker {
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
proto::OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework
} // namespace paddle
......@@ -18,9 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
......@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
......@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
TestInOutProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
......@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
......@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
......@@ -212,10 +210,7 @@ namespace framework {
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("NoGradOp, same input output. no Grad");
}
void Make() { AddComment("NoGradOp, same input output. no Grad"); }
};
class OpWithKernelTest : public OperatorWithKernel {
......@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::DeviceContext;
namespace paddle {
namespace framework {
......
......@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op");
......@@ -98,8 +97,7 @@ namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("x", "input of test op");
AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
......@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker {
public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").AsDuplicable();
......
......@@ -24,8 +24,7 @@ namespace framework {
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
......
......@@ -166,6 +166,8 @@ function(op_library TARGET)
# NOTE(*): activation uses macros to register the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......
......@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
// TODO(typhoonzero): support both inference value and indices.
AddInput("Out", "The network output of topk (inferences)");
AddInput("Indices", "The the network output of topk (indices)");
......
......@@ -19,19 +19,18 @@ limitations under the License. */
namespace paddle {
namespace operators {
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
}
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
......@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
......@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
......@@ -242,8 +239,7 @@ $$
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
......@@ -265,8 +261,7 @@ $$
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of BRelu operator");
AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
......@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
......@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of ELU operator");
AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
......@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
......@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Pow operator");
AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
......@@ -353,8 +344,7 @@ $out = x^{factor}$
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of STanh operator");
AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
......@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
......@@ -394,8 +383,7 @@ $$
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
......@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
......
......@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
......
......@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel {
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
......
......@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel {
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
......
......@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
......
......@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor.");
......
......@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.")
......
......@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignValueOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) Output tensor of assign_value operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) "
......
......@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AucOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Out",
"A floating point 2D tensor, values are in the range [0, 1]."
"Each row is sorted in descending order. This input should be the"
......
......@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("param", "(Tensor), The parameter to be accumulated.");
AddInput("in_sum_1",
"(Tensor), A tensor used to store the parameter "
......
......@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "")
......
......@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
......
......@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids",
"(LodTensorArray)"
"score of the candidate words in each step");
......
......@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]");
......
......@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)");
......
......@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight",
......
......@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
......
......@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"PriorBox",
"(Tensor, default Tensor<float>) "
......
......@@ -21,8 +21,7 @@ namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type");
......
......@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kChannel,
"The Channel Variable that should be closed by"
" the ChannelClose Op.");
......
......@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput(kOutput,
"The object of a Channel type created by ChannelCreate Op.");
AddAttr<int>("capacity", "The size of the buffer of Channel.")
......
......@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase {
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelRecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"receives\" the a value sent"
"to it by a channel_send op.")
......
......@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase {
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelSendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"sends\" the passed in value to "
"a listening receiver.")
......
......@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Inference",
"(Tensor, default: Tensor<int64_t>). "
"Predictions from the network.");
......
......@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
......
......@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor)The input of clip op."
"The number of dimensions must be between [1, 9].");
......
......@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
......
......@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel {
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
AddAttr<int>("axis",
......
......@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
"whole sub-block will not be executed.")
......
......@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
library);
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv2DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -200,8 +199,7 @@ $$
)DOC");
}
Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv3DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......
......@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class ConvOp : public framework::OperatorWithKernel {
......
......@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension.");
......
......@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
layout_, library_);
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv2DTransposeOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator. "
......@@ -168,9 +166,7 @@ Example:
)DOC");
}
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv3DTransposeOpMaker::Make() {
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is "
......
......@@ -30,12 +30,12 @@ using DDim = framework::DDim;
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class ConvTransposeOp : public framework::OperatorWithKernel {
......
......@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op.");
......
......@@ -18,8 +18,7 @@ namespace paddle {
namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
"[N x D] where N is the size of the mini-batch and D is the total "
......
......@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7).");
......
......@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
" where N is the batch size and D is the number of classes. "
......
......@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length.");
......
......@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CumsumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Cumsum operator");
AddOutput("Out", "Output of Cumsum operator");
AddAttr<int>("axis",
......
......@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
......
......@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase {
class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of delete op").AsDuplicable();
AddComment(R"DOC(
Delete Operator.
......
......@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
......
......@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel {
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
......
......@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Hyps",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings.");
......
......@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Add", "Out = X + Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "Out = X / Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
ops::ElementwiseDivOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
REGISTER_OP_CPU_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = max(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
ops::ElementwiseMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
REGISTER_OP_CPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = min(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
ops::ElementwiseMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
REGISTER_OP_CPU_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,27 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "Out = X \\odot\\ Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
ops::ElementwiseMulOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y");
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op.");
......@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
comment_ = R"DOC(
Limited Elementwise {name} Operator.
AddComment(string::Sprintf(R"DOC(
Limited Elementwise %s Operator.
The equation is:
$${equation}$$
$$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$.
......@@ -100,26 +99,13 @@ For example
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
)DOC";
AddComment(comment_);
)DOC",
GetName(), GetEquation()));
}
protected:
std::string comment_;
void Replace(std::string* src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string equation) {
Replace(&comment_, "{name}", name);
Replace(&comment_, "{equation}", equation);
}
virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0;
};
class ElementwiseOpGrad : public framework::OperatorWithKernel {
......@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
};
} // namespace operators
} // namespace paddle
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
......@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise_pow_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwisePowOpMaker : public ElementwiseOpMaker {
public:
ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Pow", "Out = X ^ Y");
AddComment(comment_);
}
protected:
std::string GetName() const override { return "Pow"; }
std::string GetEquation() const override { return "Out = X ^ Y"; }
};
} // namespace operators
} // namespace paddle
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "Out = X - Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_sub, ops::ElementwiseOp,
ops::ElementwiseSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel {
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
......
......@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
layout, library);
}
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void FCOpMaker::Make() {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
......
......@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel {
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FCOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
} // namespace operators
......
......@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op");
AddAttr<int>("col", "(int) The column of feed");
......
......@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase {
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op");
AddAttr<int>("col", "(int) The column of fetch");
......
......@@ -31,8 +31,7 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
void Make() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
......
......@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase {
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
......
......@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase {
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddComment(R"DOC(Fill operator
Fill a tensor with `value` and `shape`. The type of the tensor is specified by
......
......@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillZerosLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with zeros.");
AddComment(R"DOC(
......
......@@ -64,8 +64,7 @@ class FTRLOp : public framework::OperatorWithKernel {
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FTRLOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
......
......@@ -67,8 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GatherOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of gather op");
......
......@@ -33,8 +33,7 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
void Make() override {
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
......
......@@ -70,8 +70,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GaussianRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "Output matrix of gaussian random op");
AddAttr<std::vector<int>>("shape",
......
......@@ -78,8 +78,7 @@ class GetPlacesOp : public framework::OperatorBase {
class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "vector of Place");
AddAttr<int>("device_count", "device count").SetDefault(0);
AddAttr<std::string>("device_type", "device type")
......
......@@ -89,8 +89,7 @@ class GoOp : public framework::OperatorBase {
class GoOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GoOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kX,
"A set of variables, which are required by operators inside the "
"block of Go Op.")
......
......@@ -71,8 +71,7 @@ class GRUOp : public framework::OperatorWithKernel {
class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports "
"variable-time length input sequence. The underlying tensor in "
......
......@@ -71,8 +71,7 @@ class GRUUnitOp : public framework::OperatorWithKernel {
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"input.");
......
......@@ -46,8 +46,7 @@ class HingeLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class HingeLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HingeLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Logits",
"The input value (Logits) of Hinge loss op."
"Logits is a 2-D tensor with shape [batch_size, 1].");
......
......@@ -45,8 +45,7 @@ class HuberLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input value of huber loss op."
"X is a 2-D tensor with shape [batch_size, 1].");
......
......@@ -54,8 +54,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Im2SequenceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor has NCHW format."
"N: batch size"
......
......@@ -47,8 +47,7 @@ class IncrementOp : public framework::OperatorWithKernel {
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IncrementOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input tensor of increment operator");
AddOutput("Out", "(Tensor) The output tensor of increment operator.");
AddAttr<float>("step",
......
......@@ -42,8 +42,7 @@ class IOUSimilarityOp : public framework::OperatorWithKernel {
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) "
"Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
......
......@@ -48,8 +48,7 @@ class IsEmptyOp : public framework::OperatorBase {
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
IsEmptyOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
AddComment(R"DOC(
......
......@@ -48,8 +48,7 @@ class L1NormGradOp : public framework::OperatorWithKernel {
class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
L1NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input of l1_norm op.");
AddOutput("Out", "(Scalar) The output of l1_norm op.");
AddComment(R"DOC(
......
......@@ -47,8 +47,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LabelSmoothOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor) The input labels of LabelSmooth operator. This "
"input can be batched labels in one-hot encoding or output from "
......
......@@ -61,8 +61,7 @@ class LayerNormOp : public framework::OperatorWithKernel {
class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) The input tensor.");
AddInput("Scale",
"(Tensor, optional) Scale is a 1-dimensional tensor of size "
......
......@@ -19,8 +19,7 @@ namespace operators {
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LinearChainCRFOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Emission",
"(LoDTensor, default LoDTensor<float>) "
"A 2-D LoDTensor with shape [N x D], where N is the size of the "
......
......@@ -343,8 +343,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator
......
......@@ -77,8 +77,7 @@ class LoadCombineOp : public framework::OperatorBase {
class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadCombineOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput(
"Out",
"(vector) The output LoDTensors that will be read from the input file.")
......
......@@ -64,8 +64,7 @@ class LoadOp : public framework::OperatorBase {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) The tensor need to be loaded");
AddAttr<std::string>("file_path",
"(string) "
......
......@@ -40,8 +40,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDArrayLengthProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensorArray) The input tensor array.");
AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t");
AddComment(R"DOC(
......
......@@ -38,8 +38,7 @@ class LoDRankTableOp : public framework::OperatorBase {
class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDRankTableOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor) input lod tensor, must contain lod information.");
AddOutput("Out", "(LoDRankTable) The rank table of specific level.");
......
......@@ -47,8 +47,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, LoDTensor) Input variable of LoDResetOp which "
"could be a Tensor or LoDTensor, where the data of output "
......
......@@ -105,8 +105,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDTensorToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "");
AddInput("RankTable", "");
AddOutput("Out", "");
......
......@@ -46,8 +46,7 @@ class LogLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Predicted",
"The input value (Predicted) of Log loss op."
"Predicted is a 2-D tensor with shape [batch_size, 1].");
......
......@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BinaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) Left hand operand of %s operator",
......@@ -45,8 +44,7 @@ Each element of Out is calculated by %s
template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
UnaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
comment.type));
......
......@@ -105,8 +105,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupSparseTableOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("W",
"(SelectedRows) The input represents embedding table, "
"which is a learnable parameter.");
......
......@@ -58,8 +58,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
......
......@@ -169,8 +169,7 @@ class LRNOp : public framework::OperatorWithKernel {
template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LRNOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input of LRN operator. "
"It must be a 4D tenor with NCHW format.");
......
......@@ -103,8 +103,7 @@ class LSTMOp : public framework::OperatorWithKernel {
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
......
......@@ -48,8 +48,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"Lstm unit only applies non-linear activations, please make sure"
"that linear tranformation has already been applied to `X`. "
......
......@@ -120,8 +120,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LoDTensor) the input for sequence data, which supports "
"variable-time length input sequence. The underlying tensor in "
......
......@@ -42,8 +42,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
template <typename T>
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MarginRankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X1",
"(2-D tensor with shape [batch_size x 1]) The score for "
"one item X1 to be ranked, from pairwise ranking model.");
......
......@@ -159,8 +159,7 @@ class MatMulOp : public framework::OperatorWithKernel {
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op");
AddOutput("Out", "The output of MatMul op");
......
......@@ -41,8 +41,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxSeqenceLenOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("RankTable", "The lod_rank_table.");
AddOutput("Out", "The max sequence length.");
AddComment(
......
......@@ -22,8 +22,7 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxOutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of maxout operator. "
......
......@@ -32,8 +32,7 @@ class MeanOp : public framework::OperatorWithKernel {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MeanOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op");
AddComment(R"DOC(
......
......@@ -121,8 +121,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MergeLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input LoDTensor, contains complete lod information to "
"construct the output");
......
......@@ -253,8 +253,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MineHardExamplesOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"ClsLoss",
"(Tensor, default Tensor<float>), The classification loss with shape "
......
......@@ -48,8 +48,7 @@ class MinusOp : public framework::OperatorWithKernel {
class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator.");
AddOutput("Out", "The output tensor of minus operator.");
......
......@@ -39,8 +39,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ModifiedHuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input tensor of modified huber loss op. "
"X is 2-D tensor with shape [batch_size, 1].");
......
......@@ -62,8 +62,7 @@ class MomentumOp : public framework::OperatorWithKernel {
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MomentumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated");
......
......@@ -96,8 +96,7 @@ class MulOp : public framework::OperatorWithKernel {
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mul op.");
AddInput("Y", "(Tensor), The second input tensor of mul op.");
AddOutput("Out", "(Tensor), The output tensor of mul op.");
......
......@@ -309,8 +309,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("BBoxes",
"(Tensor) A 3-D Tensor with shape [N, M, 4] represents the "
"predicted locations of M bounding bboxes, N is the batch size. "
......
......@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiplexOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.")
.AsDuplicable();
......
......@@ -76,8 +76,7 @@ class NCCLInitOpShapeInference : public framework::InferShapeBase {
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kParallelScopes, "The working place of parallel do.");
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
......@@ -118,8 +117,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
// AllReduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
......@@ -165,8 +163,7 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
// ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op");
......@@ -214,8 +211,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
// BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Bcast");
......
......@@ -75,8 +75,7 @@ class NCEOp : public framework::OperatorWithKernel {
class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCEOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
AddInput(
"Label",
......
......@@ -19,8 +19,7 @@ namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of norm operator. "
......
......@@ -46,8 +46,7 @@ class OneHotOp : public framework::OperatorWithKernel {
class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
"The last dimension of X should be 1. Each value of X is an index "
......
......@@ -48,8 +48,7 @@ class PadOp : public framework::OperatorWithKernel {
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PadOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)");
......
......@@ -196,8 +196,7 @@ class ParallelDoOp : public framework::OperatorBase {
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ParallelDoOpProtoMaker(OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kInputs, "").AsDuplicable();
AddInput(kParameters, "").AsDuplicable();
AddInput(kPlaces, "");
......
......@@ -135,8 +135,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
library_);
}
Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Pool2dOpMaker::Make() {
AddInput(
"X",
"(Tensor) The input tensor of pooling operator. "
......@@ -229,8 +228,7 @@ Example:
)DOC");
}
Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Pool3dOpMaker::Make() {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is "
......
......@@ -50,12 +50,12 @@ class PoolOpGrad : public framework::OperatorWithKernel {
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
template <typename DeviceContext, typename T>
......
......@@ -100,8 +100,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool2dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of pooling operator. "
......@@ -177,8 +176,7 @@ Example:
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool3dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is "
......
......@@ -95,8 +95,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PositiveNegativePairOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Score",
"(Tensor, float) Model Score on an item (with "
"respect to QueryID). It's a 2-D tensor with shape [batch_size, "
......
......@@ -90,8 +90,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PrecisionRecallOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("MaxProbs",
"(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the max probability "
......
......@@ -64,8 +64,7 @@ class PrefetchOp : public framework::OperatorBase {
class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PrefetchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
......
......@@ -38,8 +38,7 @@ class PReluOp : public framework::OperatorWithKernel {
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input tensor of prelu operator.");
AddInput("Alpha", "The alpha weight of prelu operator.");
AddOutput("Out", "The output tensor of prelu operator.");
......
......@@ -209,8 +209,7 @@ class TensorPrintOp : public framework::OperatorBase {
class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
public:
PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("In", "Input tensor to be displayed.");
AddAttr<int>("first_n", "Only log `first_n` number of times.");
AddAttr<std::string>("message", "A string message to print as a prefix.");
......
......@@ -79,8 +79,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(Tensor, default Tensor<float>), "
"the input feature data of PriorBoxOp, The layout is NCHW.");
......
......@@ -66,8 +66,7 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated.");
......
......@@ -54,8 +54,7 @@ class ProximalGDOp : public framework::OperatorWithKernel {
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
......
......@@ -46,8 +46,7 @@ class RankLossOp : public framework::OperatorWithKernel {
class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Label",
"(2-D Tensor with shape [batch_size x 1]) "
"The label indicating A ranked higher than B or not.");
......
......@@ -79,8 +79,7 @@ class ReadOp : public framework::OperatorBase {
class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
void Make() override {
AddInput("Reader", "(ReaderHolder) The executed reader.");
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
AddComment(R"DOC(
......
......@@ -52,9 +52,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
};
class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
......
......@@ -113,14 +113,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
};
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddComment(R"DOC(
CreateDoubleBufferReader Operator
A double buffer reader takes another reader as its 'underlying reader'.
It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training.
)DOC");
std::unordered_set<std::string> enum_range;
......
......@@ -65,20 +65,19 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
};
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
AddComment(R"DOC(
CreateMultiPassReader Operator
This operator creates a multi-pass reader. A multi-pass reader
is used to yield data for several pass training continuously.
It takes the number of passes to run as one of its attributes
('pass_num'), and maintains a pass counter to record how many
passes it has completed. When the underlying reader reaches the
EOF, the multi-pass reader checks whether it has completed training
of the given number of pass. If not, the underlying reader will
be re-initialized and starts a new pass automatically.
)DOC");
}
......
......@@ -84,9 +84,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
};
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
public:
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
......
......@@ -76,9 +76,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
};
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
public:
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<std::string>("filename", "The filename of record io reader");
AddComment(R"DOC(
CreateRecordIOReader Operator
......
......@@ -92,9 +92,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
};
class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
AddComment(R"DOC(
CreateShuffleReader Operator
......
......@@ -53,17 +53,16 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
......
......@@ -185,9 +185,8 @@ class OpenFilesOp : public framework::OperatorBase {
};
class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
......@@ -196,7 +195,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
......
......@@ -53,10 +53,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
return std::unique_ptr<framework::ReaderBase>(reader);
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
void FileReaderMakerBase::Make() {
AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable();
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>(
......@@ -68,6 +65,7 @@ FileReaderMakerBase::FileReaderMakerBase(
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
Apply();
}
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
......@@ -127,13 +125,11 @@ void DecoratedReaderInferVarType::operator()(
out_reader->SetDataTypes(in_reader->GetDataTypes());
}
DecoratedReaderMakerBase::DecoratedReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
void DecoratedReaderMakerBase::Make() {
AddInput("UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
Apply();
}
} // namespace reader
......
......@@ -47,7 +47,10 @@ extern std::vector<framework::DDim> RestoreShapes(
class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public:
FileReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
void Make() final;
protected:
virtual void Apply() = 0;
};
class FileReaderInferShape : public framework::InferShapeBase {
......@@ -76,7 +79,10 @@ class DecoratedReaderInferVarType : public framework::VarTypeInference {
class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public:
DecoratedReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
void Make() final;
protected:
virtual void Apply() = 0;
};
} // namespace reader
......
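A minimal sketch of how a concrete reader maker plugs into the reworked bases above (CreateFooReaderOpMaker and its attribute are hypothetical, not part of this commit): Make() is final in FileReaderMakerBase/DecoratedReaderMakerBase, adds the common inputs and outputs, and then calls the subclass's protected Apply() for op-specific attributes and the comment.
// Hypothetical maker, for illustration only; it mirrors the pattern used by
// CreateBatchReaderOpMaker and the other decorated-reader makers in this diff.
class CreateFooReaderOpMaker : public DecoratedReaderMakerBase {
 protected:
  void Apply() override {
    // Op-specific attribute; the common UnderlyingReader/Out slots are
    // already added by DecoratedReaderMakerBase::Make() before Apply() runs.
    AddAttr<int>("foo_size", "Illustrative attribute of the foo reader.")
        .GreaterThan(0);
    AddComment(R"DOC(
CreateFooReader Operator (illustrative only).
)DOC");
  }
};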
......@@ -508,8 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
RecurrentOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kInputs, "rnn inputs").AsDuplicable();
AddInput(kInitialStates, "rnn initial states").AsDuplicable();
AddInput(kParameters,
......
......@@ -53,8 +53,7 @@ class RecvOp : public framework::OperatorBase {
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddComment(R"DOC(
Recv operator
......
......@@ -90,8 +90,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("X",
"(Tensor) The input tensor. Tensors with rank at most 6 are "
"supported.");
......@@ -111,78 +110,20 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) "
"If true, output a scalar reduced along all dimensions.")
.SetDefault(false);
comment_ = R"DOC(
{ReduceOp} Operator.
AddComment(string::Sprintf(R"DOC(
%s Operator.
This operator computes the {reduce} of input tensor along the given dimension.
This operator computes the %s of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar.
)DOC";
AddComment(comment_);
)DOC",
GetOpType(), GetName()));
}
protected:
std::string comment_;
void Replace(std::string *src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string op) {
Replace(&comment_, "{ReduceOp}", name);
Replace(&comment_, "{reduce}", op);
}
};
class ReduceSumOpMaker : public ReduceOpMaker {
public:
ReduceSumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceSum", "sum");
AddComment(comment_);
}
};
class ReduceMeanOpMaker : public ReduceOpMaker {
public:
ReduceMeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMean", "mean");
AddComment(comment_);
}
};
class ReduceMaxOpMaker : public ReduceOpMaker {
public:
ReduceMaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMax", "max");
AddComment(comment_);
}
};
class ReduceMinOpMaker : public ReduceOpMaker {
public:
ReduceMinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMin", "min");
AddComment(comment_);
}
};
class ReduceProdOpMaker : public ReduceOpMaker {
public:
ReduceProdOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceProd", "production");
AddComment(comment_);
}
virtual std::string GetName() const = 0;
virtual std::string GetOpType() const = 0;
};
} // namespace operators
......@@ -190,25 +131,21 @@ class ReduceProdOpMaker : public ReduceOpMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_min_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_prod, ops::ReduceOp, ops::ReduceProdOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp);
#define REGISTER_REDUCE_OP(op_name) \
class __##op_name##Maker__ : public ops::ReduceOpMaker { \
protected: \
virtual std::string GetName() const { return #op_name; } \
virtual std::string GetOpType() const { return "Reduce " #op_name; } \
}; \
REGISTER_OPERATOR(reduce_##op_name, ops::ReduceOp, __##op_name##Maker__, \
paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(reduce_##op_name##_grad, ops::ReduceGradOp)
REGISTER_REDUCE_OP(sum);
REGISTER_REDUCE_OP(mean);
REGISTER_REDUCE_OP(max);
REGISTER_REDUCE_OP(min);
REGISTER_REDUCE_OP(prod);
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL(reduce_type, \
......
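For reference, REGISTER_REDUCE_OP(sum) above expands roughly to the following; this is an approximate hand expansion for illustration, the exact tokens come from the macro itself.
// Approximate expansion of REGISTER_REDUCE_OP(sum); illustrative only.
class __sumMaker__ : public ops::ReduceOpMaker {
 protected:
  virtual std::string GetName() const { return "sum"; }
  virtual std::string GetOpType() const { return "Reduce sum"; }
};
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, __sumMaker__,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp);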
......@@ -23,9 +23,7 @@ namespace operators {
class ReorderLoDTensorByRankTableOpProtoMaker
: public framework::OpProtoAndCheckerMaker {
public:
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), the input lod tensor to be reordered according to "
"Input(RankTable).");
......
......@@ -22,8 +22,7 @@ namespace operators {
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor). The input tensor of reshape operator.");
AddInput("Shape",
"(Tensor<int32>, optional). If provided, reshape according to "
......
......@@ -63,8 +63,7 @@ class RmspropOp : public framework::OperatorWithKernel {
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RmspropOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
......
......@@ -59,8 +59,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
RNNMemoryHelperOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "");
AddOutput("Out", "");
AddAttr<int>("dtype",
......@@ -117,8 +116,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
class RNNMemoryHelperGradOpInfoMaker
: public framework::OpProtoAndCheckerMaker {
public:
RNNMemoryHelperGradOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(framework::GradVarName("Out"), "");
AddInput("X", "");
AddInput("Out", "");
......
......@@ -98,8 +98,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor), "
"the input of ROIPoolOp. "
......
......@@ -76,8 +76,7 @@ class RowConvGradOp : public framework::OperatorWithKernel {
class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowConvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), the input(X) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor "
......
......@@ -109,8 +109,7 @@ class SaveCombineOp : public framework::OperatorBase {
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveCombineOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(vector) Input LoDTensors that need to be saved together in a file.")
......
......@@ -117,8 +117,7 @@ class SaveOp : public framework::OperatorBase {
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor ) Input tensor to be saved");
AddComment(R"DOC(
Save operator
......
......@@ -37,8 +37,7 @@ class ScaleOp : public framework::OperatorWithKernel {
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddOutput("Out", "(Tensor) Output tensor of scale operator.");
AddComment(R"DOC(
......
......@@ -78,8 +78,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The source input of scatter op");
AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of updates op");
......
......@@ -380,8 +380,7 @@ class SelectOp : public framework::OperatorBase {
class SelectOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SelectOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kX,
"A set of variables, which are required by operators inside the "
"cases of Select Op")
......
......@@ -57,8 +57,7 @@ class SendBarrierOp : public framework::OperatorBase {
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendBarrierOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
......
......@@ -92,8 +92,7 @@ class SendOp : public framework::OperatorBase {
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable();
......
......@@ -66,8 +66,7 @@ class SendVarsOp : public framework::OperatorBase {
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendVarsOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("RPCClient",
......
......@@ -43,8 +43,7 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConcatOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LodTensorArray) Input is a vector of LoDTensor, "
"each of which is a variable-length sequence or nested sequence.")
......
......@@ -102,8 +102,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(LoDTensor) the input(X) is a LodTensor, which supports "
......
......@@ -37,8 +37,7 @@ class SequenceEraseOp : public framework::OperatorWithKernel {
class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(2-D LoDTensor with the 2nd dim. equal to 1) "
"Input LoDTensor of SequenceEraseOp.");
......
......@@ -94,8 +94,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"level is at most 1.");
......
......@@ -38,8 +38,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequencePoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
AddOutput("Out",
"(Tensor) The output of SequencePoolOp does not contain LoD "
......
......@@ -42,8 +42,7 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with shape "
"being [N, M].");
......
......@@ -79,8 +79,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSliceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), "
"the input of SequenceSliceOp.");
......
......@@ -57,8 +57,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
"of length 1.");
......
......@@ -68,8 +68,7 @@ class SGDOpInferVarType : public framework::VarTypeInference {
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
......
......@@ -69,8 +69,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ShrinkRNNMemoryOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
AddInput("I",
......
......@@ -86,9 +86,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
class SigmoidCrossEntropyWithLogitsOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SigmoidCrossEntropyWithLogitsOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
......
......@@ -34,8 +34,7 @@ class SignOp : public framework::OperatorWithKernel {
template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator.");
AddComment(R"DOC(
......
......@@ -46,8 +46,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>) A tensor with rank at least 2. "
"The input value of smooth l1 loss op with shape "
......
......@@ -77,8 +77,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
......
......@@ -20,8 +20,7 @@ namespace operators {
class SoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SoftmaxWithCrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
......
......@@ -64,8 +64,7 @@ class SplitByrefOp : public framework::OperatorWithKernel {
class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitByrefOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable();
......
......@@ -19,8 +19,7 @@ namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitIdsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddOutput("Out", "(LoDTensor) The outputs of the input Ids.")
.AsDuplicable();
......
......@@ -125,8 +125,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input LoDTensor");
AddInput("Mask", "A bool column vector which mask the input");
AddOutput("OutTrue", "True branch of input LoDTensor");
......
......@@ -70,8 +70,7 @@ class SplitOp : public framework::OperatorWithKernel {
class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable();
......
......@@ -19,8 +19,7 @@ namespace operators {
class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections",
......
......@@ -20,8 +20,7 @@ namespace operators {
class SppOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SppOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of spp operator. "
......
......@@ -56,8 +56,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2DistanceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input of SquaredL2DistanceOp.");
AddInput("Y", "(Tensor) Target of SquaredL2DistanceOp.");
AddOutput("sub_result",
......
......@@ -48,8 +48,7 @@ class SquaredL2NormGradOp : public framework::OperatorWithKernel {
class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input of squared_l2_norm op.");
AddOutput("Out", "(Scalar) The output of squared_l2_norm op.");
AddComment(R"DOC(
......
......@@ -112,8 +112,7 @@ class SumOp : public framework::OperatorWithKernel {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator.");
......
......@@ -65,8 +65,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. "
"Some elements in X will be assigned to Out based on the "
......
......@@ -57,8 +57,7 @@ class WriteToArrayOp : public ArrayOp {
class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
WriteToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
AddInput(
"I",
......@@ -148,8 +147,7 @@ class ReadFromArrayOp : public ArrayOp {
class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadFromArrayProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(TensorArray) the array will be read from.");
AddInput("I",
"(Tensor) the subscript index in tensor array. The number of "
......
......@@ -48,8 +48,7 @@ class TopkOp : public framework::OperatorWithKernel {
class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TopkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
......
......@@ -56,8 +56,7 @@ class TransposeOp : public framework::OperatorWithKernel {
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor, tensors with rank up to 6 are supported.");
......
......@@ -33,8 +33,7 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
UniformRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
void Make() override {
AddComment(R"DOC(
Uniform random operator
......
......@@ -85,8 +85,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
UniformRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) The output tensor of uniform random op");
AddComment(R"DOC(
Uniform random operator.
......
......@@ -20,8 +20,7 @@ namespace operators {
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of unpool operator. "
......
......@@ -53,8 +53,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Logits",
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
......
......@@ -68,8 +68,7 @@ class WhileOp : public framework::OperatorBase {
class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kX,
"A set of variables, which are required by operators inside the "
"block of While Op.")
......