未验证 提交 046405e0 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #10486 from reyoung/feature/clean_op_maker

Clean OpProtoAndCheckerMaker
......@@ -57,7 +57,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
......
......@@ -32,8 +32,7 @@ struct AddFunctor {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
......
......@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
T maker;
maker.SetProto(info->proto_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
......
......@@ -14,56 +14,57 @@ limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
using OpProto = proto::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual void Make() = 0;
virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
CHECK(validated_) << "should call Validate after build";
}
void SetProto(proto::OpProto *proto) { proto_ = proto; }
void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
proto::OpProto::Var *var_;
VariableBuilder& AsDuplicable() {
VariableBuilder &AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
VariableBuilder &AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& AsDispensable() {
VariableBuilder &AsDispensable() {
var_->set_dispensable(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddInput(const std::string &name, const std::string &comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
VariableBuilder AddOutput(const std::string &name,
const std::string &comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
TypedAttrChecker<T> &AddAttr(const std::string &name,
const std::string &comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
auto *attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
......@@ -71,21 +72,14 @@ class OpProtoAndCheckerMaker {
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
void AddComment(const std::string &comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
proto::OpProto *proto_;
OpAttrChecker *op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework
} // namespace paddle
......@@ -18,9 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op");
}
......@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of test op");
AddInput("input", "input of test op");
}
......@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
......@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
......@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
......@@ -212,10 +210,7 @@ namespace framework {
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("NoGradOp, same input output. no Grad");
}
void Make() { AddComment("NoGradOp, same input output. no Grad"); }
};
class OpWithKernelTest : public OperatorWithKernel {
......@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::DeviceContext;
namespace paddle {
namespace framework {
......
......@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("input", "input of test op");
AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op");
......@@ -98,8 +97,7 @@ namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("x", "input of test op");
AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op")
......@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker {
public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").AsDuplicable();
......
......@@ -24,8 +24,7 @@ namespace framework {
class SumOpMaker : public OpProtoAndCheckerMaker {
public:
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "").AsDuplicable();
AddOutput("Out", "");
AddComment("");
......
......@@ -166,6 +166,8 @@ function(op_library TARGET)
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif()
......
......@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
// TODO(typhoonzero): support both inference value and indices.
AddInput("Out", "The network output of topk (inferences)");
AddInput("Indices", "The the network output of topk (indices)");
......
......@@ -23,8 +23,7 @@ namespace operators {
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
void Make() override { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \
......@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
......@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
......@@ -242,8 +239,7 @@ $$
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
......@@ -265,8 +261,7 @@ $$
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of BRelu operator");
AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
......@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
......@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of ELU operator");
AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
......@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
......@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Pow operator");
AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
......@@ -353,8 +344,7 @@ $out = x^{factor}$
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of STanh operator");
AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
......@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
......@@ -394,8 +383,7 @@ $$
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
......@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
......
......@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
......
......@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel {
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
......
......@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel {
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
......
......@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate");
......
......@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor.");
......
......@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.")
......
......@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignValueOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) Output tensor of assign_value operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) "
......
......@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AucOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Out",
"A floating point 2D tensor, values are in the range [0, 1]."
"Each row is sorted in descending order. This input should be the"
......
......@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("param", "(Tensor), The parameter to be accumulated.");
AddInput("in_sum_1",
"(Tensor), A tensor used to store the parameter "
......
......@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "")
......
......@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
......@@ -68,7 +67,11 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
Apply();
}
protected:
virtual void Apply() = 0;
};
} // namespace operators
......
......@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids",
"(LodTensorArray)"
"score of the candidate words in each step");
......
......@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]");
......
......@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)");
......
......@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight",
......
......@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
......
......@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"PriorBox",
"(Tensor, default Tensor<float>) "
......
......@@ -21,8 +21,7 @@ namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type");
......
......@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kChannel,
"The Channel Variable that should be closed by"
" the ChannelClose Op.");
......
......@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput(kOutput,
"The object of a Channel type created by ChannelCreate Op.");
AddAttr<int>("capacity", "The size of the buffer of Channel.")
......
......@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase {
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelRecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"receives\" the a value sent"
"to it by a channel_send op.")
......
......@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase {
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChannelSendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"sends\" the passed in value to "
"a listening receiver.")
......
......@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Inference",
"(Tensor, default: Tensor<int64_t>). "
"Predictions from the network.");
......
......@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
......
......@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor)The input of clip op."
"The number of dimensions must be between [1, 9].");
......
......@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
......
......@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel {
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
AddAttr<int>("axis",
......
......@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
"whole sub-block will not be executed.")
......
......@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
library);
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv2DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -200,8 +199,7 @@ $$
)DOC");
}
Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv3DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......
......@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class ConvOp : public framework::OperatorWithKernel {
......
......@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension.");
......
......@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
layout_, library_);
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv2DTransposeOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator. "
......@@ -168,9 +166,7 @@ Example:
)DOC");
}
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Conv3DTransposeOpMaker::Make() {
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is "
......
......@@ -30,12 +30,12 @@ using DDim = framework::DDim;
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class ConvTransposeOp : public framework::OperatorWithKernel {
......
......@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op.");
......
......@@ -18,8 +18,7 @@ namespace paddle {
namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
"[N x D] where N is the size of the mini-batch and D is the total "
......
......@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7).");
......
......@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
" where N is the batch size and D is the number of classes. "
......
......@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length.");
......
......@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CumsumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "Input of Cumsum operator");
AddOutput("Out", "Output of Cumsum operator");
AddAttr<int>("axis",
......
......@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
......
......@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase {
class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of delete op").AsDuplicable();
AddComment(R"DOC(
Delete Operator.
......
......@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
......
......@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel {
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
......
......@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Hyps",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings.");
......
......@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Add", "Out = X + Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "Out = X / Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
ops::ElementwiseDivOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
REGISTER_OP_CPU_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = max(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
ops::ElementwiseMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
REGISTER_OP_CPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = min(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
ops::ElementwiseMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
REGISTER_OP_CPU_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -14,27 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "Out = X \\odot\\ Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
ops::ElementwiseMulOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y");
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op.");
......@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
"for broadcasting Y onto X.")
.SetDefault(-1)
.EqualGreaterThan(-1);
comment_ = R"DOC(
Limited Elementwise {name} Operator.
AddComment(string::Sprintf(R"DOC(
Limited Elementwise %s Operator.
The equation is:
$${equation}$$
$$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$.
......@@ -100,26 +99,13 @@ For example
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$.
)DOC";
AddComment(comment_);
)DOC",
GetName(), GetEquation()));
}
protected:
std::string comment_;
void Replace(std::string* src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string equation) {
Replace(&comment_, "{name}", name);
Replace(&comment_, "{equation}", equation);
}
virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0;
};
class ElementwiseOpGrad : public framework::OperatorWithKernel {
......@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
};
} // namespace operators
} // namespace paddle
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
......@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise_pow_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwisePowOpMaker : public ElementwiseOpMaker {
public:
ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Pow", "Out = X ^ Y");
AddComment(comment_);
}
protected:
std::string GetName() const override { return "Pow"; }
std::string GetEquation() const override { return "Out = X ^ Y"; }
};
} // namespace operators
} // namespace paddle
......
......@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "Out = X - Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_sub, ops::ElementwiseOp,
ops::ElementwiseSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel {
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded.");
......
......@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
layout, library);
}
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void FCOpMaker::Make() {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
......
......@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel {
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FCOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
} // namespace operators
......
......@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op");
AddAttr<int>("col", "(int) The column of feed");
......
......@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase {
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op");
AddAttr<int>("col", "(int) The column of fetch");
......
......@@ -30,9 +30,8 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
};
class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
......
......@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase {
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
......
......@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase {
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddComment(R"DOC(Fill operator
Fill an tensor with `value` and `shape`. The type of the tensor is specify by
......
......@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillZerosLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with zeros.");
AddComment(R"DOC(
......
......@@ -64,8 +64,7 @@ class FTRLOp : public framework::OperatorWithKernel {
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FTRLOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
......
......@@ -67,8 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GatherOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of gather op");
......
......@@ -32,9 +32,8 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
};
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
protected:
void Apply() override {
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
......
......@@ -70,8 +70,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GaussianRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "Output matrix of gaussian random op");
AddAttr<std::vector<int>>("shape",
......
......@@ -78,8 +78,7 @@ class GetPlacesOp : public framework::OperatorBase {
class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "vector of Place");
AddAttr<int>("device_count", "device count").SetDefault(0);
AddAttr<std::string>("device_type", "device type")
......
......@@ -89,8 +89,7 @@ class GoOp : public framework::OperatorBase {
class GoOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GoOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kX,
"A set of variables, which are required by operators inside the "
"block of Go Op.")
......
......@@ -71,8 +71,7 @@ class GRUOp : public framework::OperatorWithKernel {
class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports "
"variable-time length input sequence. The underlying tensor in "
......
......@@ -71,8 +71,7 @@ class GRUUnitOp : public framework::OperatorWithKernel {
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"input.");
......
......@@ -46,8 +46,7 @@ class HingeLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class HingeLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HingeLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Logits",
"The input value (Logits) of Hinge loss op."
"Logits is a 2-D tensor with shape [batch_size, 1].");
......
......@@ -45,8 +45,7 @@ class HuberLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input value of huber loss op."
"X is a 2-D tensor with shape [batch_size, 1].");
......
......@@ -54,8 +54,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Im2SequenceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor has NCHW format."
"N: batch size"
......
......@@ -47,8 +47,7 @@ class IncrementOp : public framework::OperatorWithKernel {
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IncrementOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input tensor of increment operator");
AddOutput("Out", "(Tensor) The output tensor of increment operator.");
AddAttr<float>("step",
......
......@@ -42,8 +42,7 @@ class IOUSimilarityOp : public framework::OperatorWithKernel {
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) "
"Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
......
......@@ -48,8 +48,7 @@ class IsEmptyOp : public framework::OperatorBase {
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
IsEmptyOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
AddComment(R"DOC(
......
......@@ -48,8 +48,7 @@ class L1NormGradOp : public framework::OperatorWithKernel {
class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
L1NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input of l1_norm op.");
AddOutput("Out", "(Scalar) The output of l1_norm op.");
AddComment(R"DOC(
......
......@@ -47,8 +47,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LabelSmoothOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor) The input labels of LabelSmooth operator. This "
"input can be batched labels in one-hot encoding or output from "
......
......@@ -61,8 +61,7 @@ class LayerNormOp : public framework::OperatorWithKernel {
class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) The input tensor.");
AddInput("Scale",
"(Tensor, optional) Scale is a 1-dimensional tensor of size "
......
......@@ -19,8 +19,7 @@ namespace operators {
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LinearChainCRFOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Emission",
"(LoDTensor, default LoDTensor<float>) "
"A 2-D LoDTensor with shape [N x D], where N is the size of the "
......
......@@ -343,8 +343,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator
......
......@@ -77,8 +77,7 @@ class LoadCombineOp : public framework::OperatorBase {
class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadCombineOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput(
"Out",
"(vector) The output LoDTensors that will be read from the input file.")
......
......@@ -73,8 +73,7 @@ class LoadOp : public framework::OperatorBase {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) The tensor need to be loaded");
AddAttr<bool>(
"load_as_fp16",
......
......@@ -40,8 +40,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDArrayLengthProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensorArray) The input tensor array.");
AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t");
AddComment(R"DOC(
......
......@@ -38,8 +38,7 @@ class LoDRankTableOp : public framework::OperatorBase {
class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDRankTableOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor) input lod tensor, must contain lod information.");
AddOutput("Out", "(LoDRankTable) The rank table of specific level.");
......
......@@ -47,8 +47,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, LoDTensor) Input variable of LoDResetOp which "
"could be a Tensor or LoDTensor, where the data of output "
......
......@@ -105,8 +105,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDTensorToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "");
AddInput("RankTable", "");
AddOutput("Out", "");
......
......@@ -46,8 +46,7 @@ class LogLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Predicted",
"The input value (Predicted) of Log loss op."
"Predicted is a 2-D tensor with shape [batch_size, 1].");
......
......@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BinaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf("(LoDTensor) Left hand operand of %s operator",
......@@ -45,8 +44,7 @@ Each element of Out is calculated by %s
template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
UnaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
OpComment comment;
AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
comment.type));
......
......@@ -105,8 +105,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupSparseTableOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("W",
"(SelectedRows) The input represents embedding table, "
"which is a learnable parameter.");
......
......@@ -58,8 +58,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
......
......@@ -169,8 +169,7 @@ class LRNOp : public framework::OperatorWithKernel {
template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LRNOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input of LRN operator. "
"It must be a 4D tenor with NCHW format.");
......
......@@ -103,8 +103,7 @@ class LSTMOp : public framework::OperatorWithKernel {
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
......
......@@ -48,8 +48,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"Lstm unit only applies non-linear activations, please make sure"
"that linear tranformation has already been applied to `X`. "
......
......@@ -120,8 +120,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(LoDTensor) the input for sequence data, which supports "
"variable-time length input sequence. The underlying tensor in "
......
......@@ -42,8 +42,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
template <typename T>
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MarginRankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X1",
"(2-D tensor with shape [batch_size x 1]) The score for "
"one item X1 to be ranked, from pairwise ranking model.");
......
......@@ -322,8 +322,7 @@ class MatMulOp : public framework::OperatorWithKernel {
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op");
AddOutput("Out", "The output of MatMul op");
......
......@@ -41,8 +41,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxSeqenceLenOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("RankTable", "The lod_rank_table.");
AddOutput("Out", "The max sequence length.");
AddComment(
......
......@@ -22,8 +22,7 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxOutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of maxout operator. "
......
......@@ -32,8 +32,7 @@ class MeanOp : public framework::OperatorWithKernel {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MeanOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op");
AddComment(R"DOC(
......
......@@ -121,8 +121,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MergeLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input LoDTensor, contains complete lod information to "
"construct the output");
......
......@@ -253,8 +253,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MineHardExamplesOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"ClsLoss",
"(Tensor, default Tensor<float>), The classification loss with shape "
......
......@@ -48,8 +48,7 @@ class MinusOp : public framework::OperatorWithKernel {
class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator.");
AddOutput("Out", "The output tensor of minus operator.");
......
......@@ -39,8 +39,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ModifiedHuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input tensor of modified huber loss op. "
"X is 2-D tensor with shape [batch_size, 1].");
......
......@@ -62,8 +62,7 @@ class MomentumOp : public framework::OperatorWithKernel {
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MomentumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated");
......
......@@ -96,8 +96,7 @@ class MulOp : public framework::OperatorWithKernel {
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mul op.");
AddInput("Y", "(Tensor), The second input tensor of mul op.");
AddOutput("Out", "(Tensor), The output tensor of mul op.");
......
......@@ -309,8 +309,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("BBoxes",
"(Tensor) A 3-D Tensor with shape [N, M, 4] represents the "
"predicted locations of M bounding bboxes, N is the batch size. "
......
......@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiplexOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.")
.AsDuplicable();
......
......@@ -76,8 +76,7 @@ class NCCLInitOpShapeInference : public framework::InferShapeBase {
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kParallelScopes, "The working place of parallel do.");
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
......@@ -118,8 +117,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
// AllReduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
......@@ -165,8 +163,7 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
// ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op");
......@@ -214,8 +211,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
// BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Bcast");
......
......@@ -75,8 +75,7 @@ class NCEOp : public framework::OperatorWithKernel {
class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCEOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
AddInput(
"Label",
......
......@@ -19,8 +19,7 @@ namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of norm operator. "
......
......@@ -46,8 +46,7 @@ class OneHotOp : public framework::OperatorWithKernel {
class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
"The last dimension of X should be 1. Each value of X is an index "
......
......@@ -48,8 +48,7 @@ class PadOp : public framework::OperatorWithKernel {
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PadOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)");
......
......@@ -196,8 +196,7 @@ class ParallelDoOp : public framework::OperatorBase {
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ParallelDoOpProtoMaker(OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kInputs, "").AsDuplicable();
AddInput(kParameters, "").AsDuplicable();
AddInput(kPlaces, "");
......
......@@ -135,8 +135,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
library_);
}
Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Pool2dOpMaker::Make() {
AddInput(
"X",
"(Tensor) The input tensor of pooling operator. "
......@@ -229,8 +228,7 @@ Example:
)DOC");
}
Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Pool3dOpMaker::Make() {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is "
......
......@@ -50,12 +50,12 @@ class PoolOpGrad : public framework::OperatorWithKernel {
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
void Make() override;
};
template <typename DeviceContext, typename T>
......
......@@ -100,8 +100,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool2dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of pooling operator. "
......@@ -177,8 +176,7 @@ Example:
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool3dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is "
......
......@@ -95,8 +95,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PositiveNegativePairOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Score",
"(Tensor, float) Model Score on an item (with "
"respect to QueryID). It's a 2-D tensor with shape [batch_size, "
......
......@@ -90,8 +90,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PrecisionRecallOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("MaxProbs",
"(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
"where N is the batch size. Each row contains the max probability "
......
......@@ -64,8 +64,7 @@ class PrefetchOp : public framework::OperatorBase {
class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PrefetchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
......
......@@ -38,8 +38,7 @@ class PReluOp : public framework::OperatorWithKernel {
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input tensor of prelu operator.");
AddInput("Alpha", "The alpha weight of prelu operator.");
AddOutput("Out", "The output tensor of prelu operator.");
......
......@@ -209,8 +209,7 @@ class TensorPrintOp : public framework::OperatorBase {
class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
public:
PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("In", "Input tensor to be displayed.");
AddAttr<int>("first_n", "Only log `first_n` number of times.");
AddAttr<std::string>("message", "A string message to print as a prefix.");
......
......@@ -79,8 +79,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Input",
"(Tensor, default Tensor<float>), "
"the input feature data of PriorBoxOp, The layout is NCHW.");
......
......@@ -66,8 +66,7 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter that has to be updated.");
......
......@@ -54,8 +54,7 @@ class ProximalGDOp : public framework::OperatorWithKernel {
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
......
......@@ -46,8 +46,7 @@ class RankLossOp : public framework::OperatorWithKernel {
class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Label",
"(2-D Tensor with shape [batch_size x 1]) "
"The label indicating A ranked higher than B or not.");
......
......@@ -79,8 +79,7 @@ class ReadOp : public framework::OperatorBase {
class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
void Make() override {
AddInput("Reader", "(ReaderHolder) The executed reader.");
AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
AddComment(R"DOC(
......
......@@ -52,9 +52,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
};
class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
......
......@@ -113,9 +113,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
};
class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateDoubleBufferReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddComment(R"DOC(
CreateDoubleBufferReader Operator
......
......@@ -65,9 +65,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
};
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
AddComment(R"DOC(
CreateMultiPassReader Operator
......
......@@ -84,9 +84,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
};
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
public:
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
......
......@@ -76,9 +76,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
};
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
public:
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<std::string>("filename", "The filename of record io reader");
AddComment(R"DOC(
CreateRecordIOReader Operator
......
......@@ -92,9 +92,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
};
class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
AddComment(R"DOC(
CreateShuffleReader Operator
......
......@@ -53,9 +53,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddComment(R"DOC(
CreateThreadedReader Operator
......
......@@ -185,9 +185,8 @@ class OpenFilesOp : public framework::OperatorBase {
};
class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
protected:
void Apply() override {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);
......
......@@ -53,10 +53,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
return std::unique_ptr<framework::ReaderBase>(reader);
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
void FileReaderMakerBase::Make() {
AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable();
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>(
......@@ -68,6 +65,7 @@ FileReaderMakerBase::FileReaderMakerBase(
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
Apply();
}
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
......@@ -127,13 +125,11 @@ void DecoratedReaderInferVarType::operator()(
out_reader->SetDataTypes(in_reader->GetDataTypes());
}
DecoratedReaderMakerBase::DecoratedReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
void DecoratedReaderMakerBase::Make() {
AddInput("UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
Apply();
}
} // namespace reader
......
......@@ -47,7 +47,10 @@ extern std::vector<framework::DDim> RestoreShapes(
class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public:
FileReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
void Make() final;
protected:
virtual void Apply() = 0;
};
class FileReaderInferShape : public framework::InferShapeBase {
......@@ -76,7 +79,10 @@ class DecoratedReaderInferVarType : public framework::VarTypeInference {
class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public:
DecoratedReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
void Make() final;
protected:
virtual void Apply() = 0;
};
} // namespace reader
......
......@@ -508,8 +508,7 @@ class RecurrentGradOp : public RecurrentBase {
class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
RecurrentOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kInputs, "rnn inputs").AsDuplicable();
AddInput(kInitialStates, "rnn initial states").AsDuplicable();
AddInput(kParameters,
......
......@@ -53,8 +53,7 @@ class RecvOp : public framework::OperatorBase {
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddComment(R"DOC(
Recv operator
......
......@@ -90,8 +90,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() final {
AddInput("X",
"(Tensor) The input tensor. Tensors with rank at most 6 are "
"supported.");
......@@ -111,78 +110,20 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) "
"If true, output a scalar reduced along all dimensions.")
.SetDefault(false);
comment_ = R"DOC(
{ReduceOp} Operator.
AddComment(string::Sprintf(R"DOC(
%s Operator.
This operator computes the {reduce} of input tensor along the given dimension.
This operator computes the %s of input tensor along the given dimension.
The result tensor has 1 fewer dimension than the input unless keep_dim is true.
If reduce_all is true, just reduce along all dimensions and output a scalar.
)DOC";
AddComment(comment_);
)DOC",
GetOpType(), GetName()));
}
protected:
std::string comment_;
void Replace(std::string *src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string op) {
Replace(&comment_, "{ReduceOp}", name);
Replace(&comment_, "{reduce}", op);
}
};
class ReduceSumOpMaker : public ReduceOpMaker {
public:
ReduceSumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceSum", "sum");
AddComment(comment_);
}
};
class ReduceMeanOpMaker : public ReduceOpMaker {
public:
ReduceMeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMean", "mean");
AddComment(comment_);
}
};
class ReduceMaxOpMaker : public ReduceOpMaker {
public:
ReduceMaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMax", "max");
AddComment(comment_);
}
};
class ReduceMinOpMaker : public ReduceOpMaker {
public:
ReduceMinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMin", "min");
AddComment(comment_);
}
};
class ReduceProdOpMaker : public ReduceOpMaker {
public:
ReduceProdOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceProd", "production");
AddComment(comment_);
}
virtual std::string GetName() const = 0;
virtual std::string GetOpType() const = 0;
};
} // namespace operators
......@@ -190,25 +131,21 @@ class ReduceProdOpMaker : public ReduceOpMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_min_grad, ops::ReduceGradOp);
REGISTER_OPERATOR(reduce_prod, ops::ReduceOp, ops::ReduceProdOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reduce_prod_grad, ops::ReduceGradOp);
#define REGISTER_REDUCE_OP(op_name) \
class __##op_name##Maker__ : public ops::ReduceOpMaker { \
protected: \
virtual std::string GetName() const { return #op_name; } \
virtual std::string GetOpType() const { return "Reduce " #op_name; } \
}; \
REGISTER_OPERATOR(reduce_##op_name, ops::ReduceOp, __##op_name##Maker__, \
paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(reduce_##op_name##_grad, ops::ReduceGradOp)
REGISTER_REDUCE_OP(sum);
REGISTER_REDUCE_OP(mean);
REGISTER_REDUCE_OP(max);
REGISTER_REDUCE_OP(min);
REGISTER_REDUCE_OP(prod);
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL(reduce_type, \
......
......@@ -23,9 +23,7 @@ namespace operators {
class ReorderLoDTensorByRankTableOpProtoMaker
: public framework::OpProtoAndCheckerMaker {
public:
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), the input lod tensor to be reordered according to "
"Input(RankTable).");
......
......@@ -22,8 +22,7 @@ namespace operators {
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor). The input tensor of reshape operator.");
AddInput("Shape",
"(Tensor<int32>, optional). If provided, reshape according to "
......
......@@ -63,8 +63,7 @@ class RmspropOp : public framework::OperatorWithKernel {
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RmspropOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated.");
......
......@@ -59,8 +59,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
RNNMemoryHelperOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "");
AddOutput("Out", "");
AddAttr<int>("dtype",
......@@ -117,8 +116,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
class RNNMemoryHelperGradOpInfoMaker
: public framework::OpProtoAndCheckerMaker {
public:
RNNMemoryHelperGradOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(framework::GradVarName("Out"), "");
AddInput("X", "");
AddInput("Out", "");
......
......@@ -98,8 +98,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor), "
"the input of ROIPoolOp. "
......
......@@ -76,8 +76,7 @@ class RowConvGradOp : public framework::OperatorWithKernel {
class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowConvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), the input(X) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor "
......
......@@ -127,8 +127,7 @@ class SaveCombineOp : public framework::OperatorBase {
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveCombineOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(vector) Input LoDTensors that need to be saved together in a file.")
......
......@@ -117,8 +117,7 @@ class SaveOp : public framework::OperatorBase {
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor ) Input tensor to be saved");
AddComment(R"DOC(
Save operator
......
......@@ -37,8 +37,7 @@ class ScaleOp : public framework::OperatorWithKernel {
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddOutput("Out", "(Tensor) Output tensor of scale operator.");
AddComment(R"DOC(
......
......@@ -78,8 +78,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The source input of scatter op");
AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of updates op");
......
......@@ -380,8 +380,7 @@ class SelectOp : public framework::OperatorBase {
class SelectOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SelectOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kX,
"A set of variables, which are required by operators inside the "
"cases of Select Op")
......
......@@ -57,8 +57,7 @@ class SendBarrierOp : public framework::OperatorBase {
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendBarrierOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
......
......@@ -92,8 +92,7 @@ class SendOp : public framework::OperatorBase {
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable();
......
......@@ -66,8 +66,7 @@ class SendVarsOp : public framework::OperatorBase {
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendVarsOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("RPCClient",
......
......@@ -43,8 +43,7 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConcatOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LodTensorArray) Input is a vector of LoDTensor, "
"each of which is a variable-length sequence or nested sequence.")
......
......@@ -102,8 +102,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(LoDTensor) the input(X) is a LodTensor, which supports "
......
......@@ -37,8 +37,7 @@ class SequenceEraseOp : public framework::OperatorWithKernel {
class SequenceEraseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceEraseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(2-D LoDTensor with the 2nd dim. equal to 1) "
"Input LoDTensor of SequenceEraseOp.");
......
......@@ -94,8 +94,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
"level is at most 1.");
......
......@@ -38,8 +38,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequencePoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
AddOutput("Out",
"(Tensor) The output of SequencePoolOp does not contain LoD "
......
......@@ -42,8 +42,7 @@ class SequenceReshapeOp : public framework::OperatorWithKernel {
class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with shape "
"being [N, M].");
......
......@@ -79,8 +79,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSliceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), "
"the input of SequenceSliceOp.");
......
......@@ -57,8 +57,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
"of length 1.");
......
......@@ -68,8 +68,7 @@ class SGDOpInferVarType : public framework::VarTypeInference {
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
......
......@@ -69,8 +69,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ShrinkRNNMemoryOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
AddInput("I",
......
......@@ -86,9 +86,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
class SigmoidCrossEntropyWithLogitsOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SigmoidCrossEntropyWithLogitsOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
......
......@@ -34,8 +34,7 @@ class SignOp : public framework::OperatorWithKernel {
template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator.");
AddComment(R"DOC(
......
......@@ -46,8 +46,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>) A tensor with rank at least 2. "
"The input value of smooth l1 loss op with shape "
......
......@@ -77,8 +77,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
......
......@@ -20,8 +20,7 @@ namespace operators {
class SoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SoftmaxWithCrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
......
......@@ -64,8 +64,7 @@ class SplitByrefOp : public framework::OperatorWithKernel {
class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitByrefOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable();
......
......@@ -19,8 +19,7 @@ namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitIdsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddOutput("Out", "(LoDTensor) The outputs of the input Ids.")
.AsDuplicable();
......
......@@ -125,8 +125,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input LoDTensor");
AddInput("Mask", "A bool column vector which mask the input");
AddOutput("OutTrue", "True branch of input LoDTensor");
......
......@@ -70,8 +70,7 @@ class SplitOp : public framework::OperatorWithKernel {
class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
.AsDuplicable();
......
......@@ -19,8 +19,7 @@ namespace operators {
class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections",
......
......@@ -20,8 +20,7 @@ namespace operators {
class SppOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SppOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of spp operator. "
......
......@@ -56,8 +56,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2DistanceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) Input of SquaredL2DistanceOp.");
AddInput("Y", "(Tensor) Target of SquaredL2DistanceOp.");
AddOutput("sub_result",
......
......@@ -48,8 +48,7 @@ class SquaredL2NormGradOp : public framework::OperatorWithKernel {
class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input of squared_l2_norm op.");
AddOutput("Out", "(Scalar) The output of squared_l2_norm op.");
AddComment(R"DOC(
......
......@@ -112,8 +112,7 @@ class SumOp : public framework::OperatorWithKernel {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator.");
......
......@@ -65,8 +65,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X",
"(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. "
"Some elements in X will be assigned to Out based on the "
......
......@@ -57,8 +57,7 @@ class WriteToArrayOp : public ArrayOp {
class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
WriteToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
AddInput(
"I",
......@@ -148,8 +147,7 @@ class ReadFromArrayOp : public ArrayOp {
class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadFromArrayProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(TensorArray) the array will be read from.");
AddInput("I",
"(Tensor) the subscript index in tensor array. The number of "
......
......@@ -48,8 +48,7 @@ class TopkOp : public framework::OperatorWithKernel {
class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TopkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
......
......@@ -56,8 +56,7 @@ class TransposeOp : public framework::OperatorWithKernel {
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor, tensors with rank up to 6 are supported.");
......
......@@ -32,9 +32,8 @@ class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
};
class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
UniformRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
protected:
void Apply() override {
AddComment(R"DOC(
Uniform random operator
......
......@@ -85,8 +85,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
UniformRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddOutput("Out", "(Tensor) The output tensor of uniform random op");
AddComment(R"DOC(
Uniform random operator.
......
......@@ -20,8 +20,7 @@ namespace operators {
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of unpool operator. "
......
......@@ -53,8 +53,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput("Logits",
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
......
......@@ -68,8 +68,7 @@ class WhileOp : public framework::OperatorBase {
class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
void Make() override {
AddInput(kX,
"A set of variables, which are required by operators inside the "
"block of While Op.")
......
......@@ -113,7 +113,7 @@ def generate_layer_fn(op_type):
if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be",
"automatically generated.")
"automatically generated. {0}".format(op_type))
if not_intermediate_outputs[0].duplicable:
raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册