提交 64305b3f 编写于 作者: Y yuyang18

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into...

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/polish_visit_data_type
...@@ -62,9 +62,9 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl ...@@ -62,9 +62,9 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl
## Installation ## Installation
It is recommended to check out the It is recommended to check out the
[Docker installation guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/docker_install_en.html) [Docker installation guide](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/docker_install_en.html)
before looking into the before looking into the
[build from source guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/build_from_source_en.html). [build from source guide](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/build_from_source_en.html).
## Documentation ## Documentation
......
...@@ -45,9 +45,9 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML") ...@@ -45,9 +45,9 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
ELSE() ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN") MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF() ENDIF()
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-unused-result")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} -Wno-error=strict-overflow") SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} -Wno-error=strict-overflow") SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
ExternalProject_Add( ExternalProject_Add(
${MKLDNN_PROJECT} ${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
...@@ -61,6 +61,7 @@ ExternalProject_Add( ...@@ -61,6 +61,7 @@ ExternalProject_Add(
CMAKE_ARGS -DMKLROOT=${MKLML_ROOT} CMAKE_ARGS -DMKLROOT=${MKLML_ROOT}
CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG} CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG} CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
CMAKE_ARGS -DWITH_TEST=OFF -DWITH_EXAMPLE=OFF
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR}
-DMKLROOT:PATH=${MKLML_ROOT} -DMKLROOT:PATH=${MKLML_ROOT}
) )
......
...@@ -27,7 +27,7 @@ ENDIF() ...@@ -27,7 +27,7 @@ ENDIF()
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(MKLML_PROJECT "extern_mklml") SET(MKLML_PROJECT "extern_mklml")
SET(MKLML_VER "mklml_lnx_2018.0.1.20171007") SET(MKLML_VER "mklml_lnx_2018.0.3.20180406")
SET(MKLML_URL "http://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.tgz") SET(MKLML_URL "http://paddlepaddledeps.bj.bcebos.com/${MKLML_VER}.tgz")
SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml") SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml")
SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}") SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}")
......
...@@ -155,7 +155,7 @@ into offsets ...@@ -155,7 +155,7 @@ into offsets
3 2+3 4+5 1+9 2+10 3+12 3 2+3 4+5 1+9 2+10 3+12
``` ```
so we know that the first sentence is from word 0 to word 3, and the second sentence from work 3 to word 5. so we know that the first sentence is from word 0 to word 3, and the second sentence from word 3 to word 5.
Similarly, the lengths in the top level LoD Similarly, the lengths in the top level LoD
......
...@@ -57,7 +57,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor ...@@ -57,7 +57,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost) cc_library(attribute SRCS attribute.cc DEPS framework_proto boost)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context) device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
......
...@@ -32,8 +32,7 @@ struct AddFunctor { ...@@ -32,8 +32,7 @@ struct AddFunctor {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op"); AddInput("input", "input1 of test op");
AddOutput("output", "output of test op"); AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false); AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
......
...@@ -36,7 +36,7 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -36,7 +36,7 @@ struct ComputationOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
virtual bool NeedWait(VarHandleBase *in_var); bool NeedWait(VarHandleBase *in_var) override;
private: private:
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
......
...@@ -42,7 +42,7 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -42,7 +42,7 @@ struct FetchOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
virtual void WaitInputVarGenerated(const platform::Place &place); void WaitInputVarGenerated(const platform::Place &place) override;
private: private:
FeedFetchList *data_; FeedFetchList *data_;
......
...@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale) platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes), local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs) { nccl_ctxs_(nccl_ctxs),
balance_parameter_opt_between_cards_(
balance_parameter_opt_between_cards) {
#else #else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale) const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards)
: loss_var_name_(loss_var_name), : loss_var_name_(loss_var_name),
places_(places), places_(places),
local_scopes_(local_scopes) { local_scopes_(local_scopes),
balance_parameter_opt_between_cards_(
balance_parameter_opt_between_cards) {
#endif #endif
for (auto &p : params) { for (auto &p : params) {
grad_names_.insert(GradVarName(p)); grad_names_.insert(GradVarName(p));
...@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
// Find "send" op first for split is in front of send. // Find "send" op first for split is in front of send.
OpDesc *send_op = GetSendOpDesc(program); OpDesc *send_op = GetSendOpDesc(program);
size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size());
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") { if (op->Type() == "send") {
...@@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
is_forwarding = false; is_forwarding = false;
} else { } else {
CreateComputationalOps(&result, *op, places_.size()); int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name);
}
}
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (IsSparseGradient(var_types, og)) { if (balance_parameter_opt_between_cards_) {
CreateReduceOp(&result, og, 0); CreateReduceOp(&result, og, cur_device_id);
CreateBroadcastOp(&result, og, 0); var_name_on_devices[cur_device_id].emplace(og);
bcast_var_name_set[cur_device_id].emplace(
og.substr(0, og.size() - strlen(kGradVarSuffix)));
cur_device_id = (cur_device_id + 1) % places_.size();
} else { } else {
InsertNCCLAllReduceOp(&result, og); if (IsSparseGradient(var_types, og)) {
CreateReduceOp(&result, og, 0);
CreateBroadcastOp(&result, og, 0);
} else {
InsertNCCLAllReduceOp(&result, og);
}
} }
} }
} }
...@@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} }
} }
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
/* /*
Dependency graph has been constructed. However, there are still data Dependency graph has been constructed. However, there are still data
harzaeds need to be handled. harzaeds need to be handled.
...@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( ...@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once; return is_pg_once;
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
if (!balance_parameter_opt_between_cards_) {
return -1;
}
int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break;
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
}
}
return var_dev_id;
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
......
...@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -36,13 +36,15 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs, platform::NCCLContextMap *nccl_ctxs,
bool use_default_grad_scale); bool use_default_grad_scale,
bool balance_parameter_opt_between_cards);
#else #else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places, MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name,
const std::unordered_set<std::string> &params, const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
bool use_default_grad_scale); bool use_default_grad_scale,
bool balance_parameter_opt_between_cards);
#endif #endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override; std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
...@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -60,6 +62,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_; platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
bool balance_parameter_opt_between_cards_;
bool use_default_grad_scale_; bool use_default_grad_scale_;
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
...@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -84,6 +87,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og, const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
......
...@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { ...@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new proto::OpProto; info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker(); info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_); T maker;
maker.SetProto(info->proto_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate(); maker.Validate();
info->proto_->set_type(op_type); info->proto_->set_type(op_type);
PADDLE_ENFORCE( PADDLE_ENFORCE(
......
...@@ -14,56 +14,57 @@ limitations under the License. */ ...@@ -14,56 +14,57 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker { class OpProtoAndCheckerMaker {
public: public:
using OpProto = proto::OpProto; virtual void Make() = 0;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
virtual ~OpProtoAndCheckerMaker() { virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build"); CHECK(validated_) << "should call Validate after build";
} }
void SetProto(proto::OpProto *proto) { proto_ = proto; }
void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
void Validate(); void Validate();
protected: protected:
struct VariableBuilder { struct VariableBuilder {
OpProto::Var* var_; proto::OpProto::Var *var_;
VariableBuilder& AsDuplicable() { VariableBuilder &AsDuplicable() {
var_->set_duplicable(true); var_->set_duplicable(true);
return *this; return *this;
} }
VariableBuilder& AsIntermediate() { VariableBuilder &AsIntermediate() {
var_->set_intermediate(true); var_->set_intermediate(true);
return *this; return *this;
} }
VariableBuilder& AsDispensable() { VariableBuilder &AsDispensable() {
var_->set_dispensable(true); var_->set_dispensable(true);
return *this; return *this;
} }
}; };
VariableBuilder AddInput(const std::string& name, const std::string& comment); VariableBuilder AddInput(const std::string &name, const std::string &comment);
VariableBuilder AddOutput(const std::string& name, VariableBuilder AddOutput(const std::string &name,
const std::string& comment); const std::string &comment);
template <typename T> template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name, TypedAttrChecker<T> &AddAttr(const std::string &name,
const std::string& comment, const std::string &comment,
bool generated = false) { bool generated = false) {
auto* attr = proto_->add_attrs(); auto *attr = proto_->add_attrs();
attr->set_name(name); attr->set_name(name);
attr->set_comment(comment); attr->set_comment(comment);
attr->set_generated(generated); attr->set_generated(generated);
...@@ -71,21 +72,14 @@ class OpProtoAndCheckerMaker { ...@@ -71,21 +72,14 @@ class OpProtoAndCheckerMaker {
return op_checker_->AddAttrChecker<T>(name); return op_checker_->AddAttrChecker<T>(name);
} }
void AddComment(const std::string& comment) { proto_->set_comment(comment); } void AddComment(const std::string &comment) { proto_->set_comment(comment); }
private: private:
void CheckNoDuplicatedInOutAttrs(); void CheckNoDuplicatedInOutAttrs();
OpProto* proto_; proto::OpProto *proto_;
OpAttrChecker* op_checker_; OpAttrChecker *op_checker_;
bool validated_{false}; bool validated_{false};
}; };
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -18,9 +18,7 @@ limitations under the License. */ ...@@ -18,9 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public: public:
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto, void Make() {
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op"); AddAttr<float>("scale", "scale of test op");
AddAttr<float>("scale", "scale of test op"); AddAttr<float>("scale", "scale of test op");
} }
...@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { ...@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedAttr) { TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::proto::OpProto op_proto; paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public: public:
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto, void Make() {
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
AddInput("input", "input of test op"); AddInput("input", "input of test op");
} }
...@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { ...@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
TEST(ProtoMaker, DuplicatedInOut) { TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::proto::OpProto op_proto; paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker; paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); TestAttrProtoMaker proto_maker;
proto_maker.SetProto(&op_proto);
proto_maker.SetChecker(&op_checker);
proto_maker.Make();
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
} }
...@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase { ...@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op"); AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op"); AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
...@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase { ...@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").AsDuplicable(); AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate(); AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) { auto my_checker = [](int i) {
...@@ -212,10 +210,7 @@ namespace framework { ...@@ -212,10 +210,7 @@ namespace framework {
class OpKernelTestMaker : public OpProtoAndCheckerMaker { class OpKernelTestMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() { AddComment("NoGradOp, same input output. no Grad"); }
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("NoGradOp, same input output. no Grad");
}
}; };
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
...@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) { ...@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
static int op_test_value = 0; static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext; using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext; using paddle::platform::CUDADeviceContext;
using paddle::platform::DeviceContext;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase { ...@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker { class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
AddOutput("output", "output of test op"); AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op"); AddAttr<float>("scale", "scale of cosine op");
...@@ -98,8 +97,7 @@ namespace framework { ...@@ -98,8 +97,7 @@ namespace framework {
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "input of test op"); AddInput("x", "input of test op");
AddOutput("y", "output of test op"); AddOutput("y", "output of test op");
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
...@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> { ...@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
class OpKernelTestMultiInputsProtoAndCheckerMaker class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker { : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, void Make() {
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("xs", "inputs of test op").AsDuplicable(); AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op"); AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").AsDuplicable(); AddOutput("ys", "outputs of test op").AsDuplicable();
......
...@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -58,7 +58,7 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay, Scope *scope, const std::vector<Scope *> &local_scopes, bool allow_op_delay,
bool use_default_grad_scale) bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
: member_(new ParallelExecutorPrivate(places)) { : member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope; member_->global_scope_ = scope;
...@@ -93,11 +93,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -93,11 +93,12 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
member_->nccl_ctxs_.get(), use_default_grad_scale); member_->nccl_ctxs_.get(), use_default_grad_scale,
balance_parameter_opt_between_cards);
#else #else
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, details::MultiDevSSAGraphBuilder builder(
params, member_->local_scopes_, member_->places_, loss_var_name, params, member_->local_scopes_,
use_default_grad_scale); use_default_grad_scale, balance_parameter_opt_between_cards);
#endif #endif
auto graph = builder.Build(main_program); auto graph = builder.Build(main_program);
......
...@@ -40,7 +40,8 @@ class ParallelExecutor { ...@@ -40,7 +40,8 @@ class ParallelExecutor {
const ProgramDesc& main_program, const ProgramDesc& main_program,
const std::string& loss_var_name, Scope* scope, const std::string& loss_var_name, Scope* scope,
const std::vector<Scope*>& local_scopes, const std::vector<Scope*>& local_scopes,
bool allow_op_delay, bool use_default_grad_scale); bool allow_op_delay, bool use_default_grad_scale,
bool balance_parameter_opt_between_cards);
~ParallelExecutor(); ~ParallelExecutor();
......
...@@ -24,8 +24,7 @@ namespace framework { ...@@ -24,8 +24,7 @@ namespace framework {
class SumOpMaker : public OpProtoAndCheckerMaker { class SumOpMaker : public OpProtoAndCheckerMaker {
public: public:
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "").AsDuplicable(); AddInput("X", "").AsDuplicable();
AddOutput("Out", ""); AddOutput("Out", "");
AddComment(""); AddComment("");
......
...@@ -166,6 +166,8 @@ function(op_library TARGET) ...@@ -166,6 +166,8 @@ function(op_library TARGET)
# NOTE(*): activation use macro to regist the kernels, set use_op manually. # NOTE(*): activation use macro to regist the kernels, set use_op manually.
if(${TARGET} STREQUAL "activation") if(${TARGET} STREQUAL "activation")
file(APPEND ${pybind_file} "USE_OP(relu);\n") file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif() endif()
...@@ -268,6 +270,11 @@ foreach(src ${READER_LIBRARY}) ...@@ -268,6 +270,11 @@ foreach(src ${READER_LIBRARY})
set(OP_LIBRARY ${src} ${OP_LIBRARY}) set(OP_LIBRARY ${src} ${OP_LIBRARY})
endforeach() endforeach()
add_subdirectory(detection)
foreach(src ${DETECTION_LIBRARY})
set(OP_LIBRARY ${src} ${OP_LIBRARY})
endforeach()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
......
...@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
// TODO(typhoonzero): support both inference value and indices. // TODO(typhoonzero): support both inference value and indices.
AddInput("Out", "The network output of topk (inferences)"); AddInput("Out", "The network output of topk (inferences)");
AddInput("Indices", "The the network output of topk (indices)"); AddInput("Indices", "The the network output of topk (indices)");
......
...@@ -19,19 +19,18 @@ limitations under the License. */ ...@@ -19,19 +19,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ #define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \ class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \ : public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \ public: \
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \ void Make() override { \
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \ AddInput("X", "Input of " #OP_NAME "operator"); \
AddInput("X", "Input of " #OP_NAME "operator"); \ AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \ AddAttr<bool>("use_mkldnn", \
AddAttr<bool>("use_mkldnn", \ "(bool, default false) Only used in mkldnn kernel") \
"(bool, default false) Only used in mkldnn kernel") \ .SetDefault(false); \
.SetDefault(false); \ AddComment(#OP_COMMENT); \
AddComment(#OP_COMMENT); \ } \
} \
} }
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \ #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
...@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$ ...@@ -204,8 +203,7 @@ $$out = \frac{x}{1 + |x|}$$
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator"); AddInput("X", "Input of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator"); AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f); AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
...@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$ ...@@ -220,8 +218,7 @@ $out = \max(x, \alpha * x)$
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator"); AddInput("X", "Input of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator"); AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f); AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
...@@ -242,8 +239,7 @@ $$ ...@@ -242,8 +239,7 @@ $$
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator"); AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator"); AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink") AddAttr<float>("threshold", "The value of threshold for HardShrink")
...@@ -265,8 +261,7 @@ $$ ...@@ -265,8 +261,7 @@ $$
class BReluOpMaker : public framework::OpProtoAndCheckerMaker { class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator"); AddInput("X", "Input of BRelu operator");
AddOutput("Out", "Output of BRelu operator"); AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu") AddAttr<float>("t_min", "The min marginal value of BRelu")
...@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$ ...@@ -284,8 +279,7 @@ $out = \max(\min(x, t_{min}), t_{max})$
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator"); AddInput("X", "Input of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator"); AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu") AddAttr<float>("threshold", "The threshold value of SoftRelu")
...@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$ ...@@ -301,8 +295,7 @@ $out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
class ELUOpMaker : public framework::OpProtoAndCheckerMaker { class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator"); AddInput("X", "Input of ELU operator");
AddOutput("Out", "Output of ELU operator"); AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f); AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
...@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$ ...@@ -320,8 +313,7 @@ $out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator"); AddInput("X", "Input of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator"); AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6") AddAttr<float>("threshold", "The threshold value of Relu6")
...@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$ ...@@ -337,8 +329,7 @@ $out = \min(\max(0, x), 6)$
class PowOpMaker : public framework::OpProtoAndCheckerMaker { class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator"); AddInput("X", "Input of Pow operator");
AddOutput("Out", "Output of Pow operator"); AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f); AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
...@@ -353,8 +344,7 @@ $out = x^{factor}$ ...@@ -353,8 +344,7 @@ $out = x^{factor}$
class STanhOpMaker : public framework::OpProtoAndCheckerMaker { class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator"); AddInput("X", "Input of STanh operator");
AddOutput("Out", "Output of STanh operator"); AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input") AddAttr<float>("scale_a", "The scale parameter of a for the input")
...@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$ ...@@ -372,8 +362,7 @@ $$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker { class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator"); AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator"); AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation") AddAttr<float>("threshold", "The threshold location of activation")
...@@ -394,8 +383,7 @@ $$ ...@@ -394,8 +383,7 @@ $$
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator"); AddInput("X", "Input of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator"); AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid") AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
...@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation. ...@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker { class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Swish operator"); AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator"); AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f); AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
......
...@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient"); AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
......
...@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel { ...@@ -67,8 +67,7 @@ class AdagradOp : public framework::OperatorWithKernel {
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment"); AddInput("Moment", "(Tensor) Second moment");
......
...@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -80,8 +80,7 @@ class AdamOp : public framework::OperatorWithKernel {
class AdamOpMaker : public framework::OpProtoAndCheckerMaker { class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate"); AddInput("LearningRate", "(Tensor) Learning rate");
......
...@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel { ...@@ -74,8 +74,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("LearningRate", "(Tensor) Learning rate"); AddInput("LearningRate", "(Tensor) Learning rate");
......
...@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -123,8 +123,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to " "(std::vector<LodTensor>) A vector of tensors that is going to "
"be casted to a big LoDTensor."); "be casted to a big LoDTensor.");
......
...@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase { ...@@ -94,8 +94,7 @@ class AssignOp : public framework::OperatorBase {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable " "(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.") "could be LoDTensor, SelectedRows or LoDTensorArray.")
......
...@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel { ...@@ -45,8 +45,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AssignValueOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) Output tensor of assign_value operator."); AddOutput("Out", "(Tensor) Output tensor of assign_value operator.");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
"(vector<int>) " "(vector<int>) "
......
...@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -50,8 +50,7 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker { class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AucOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Out", AddInput("Out",
"A floating point 2D tensor, values are in the range [0, 1]." "A floating point 2D tensor, values are in the range [0, 1]."
"Each row is sorted in descending order. This input should be the" "Each row is sorted in descending order. This input should be the"
......
...@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -111,8 +111,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "(Tensor), The parameter to be accumulated."); AddInput("param", "(Tensor), The parameter to be accumulated.");
AddInput("in_sum_1", AddInput("in_sum_1",
"(Tensor), A tensor used to store the parameter " "(Tensor), A tensor used to store the parameter "
......
...@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -126,8 +126,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<bool>("is_test", "").SetDefault(false); AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9); AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "") AddAttr<float>("epsilon", "")
......
...@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,7 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker { class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() final {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor) Tensor " "(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size"); "whose input_dim_idx'th dimension specifies the batch_size");
...@@ -68,7 +67,11 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,7 +67,11 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("output_dim_idx", AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension") "(int, default 0) The index of output's batch size dimension")
.SetDefault(0); .SetDefault(0);
Apply();
} }
protected:
virtual void Apply() = 0;
}; };
} // namespace operators } // namespace operators
......
...@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -134,8 +134,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker { class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", AddInput("Ids",
"(LodTensorArray)" "(LodTensorArray)"
"score of the candidate words in each step"); "score of the candidate words in each step");
......
...@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) { ...@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
// inputs and outputs stored in proto // inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step"); AddInput("pre_ids", "ids in previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]"); AddInput("ids", "a LoDTensor of shape of [None,k]");
......
...@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel { ...@@ -41,8 +41,7 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BilinearInterpOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input tensor of bilinear interpolation, " "(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"); "This is a 4-D tensor with shape of (N x C x h x w)");
......
...@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of bilinear_tensor_product operator."); AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator."); AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight", AddInput("Weight",
......
...@@ -21,8 +21,7 @@ namespace operators { ...@@ -21,8 +21,7 @@ namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker { class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of cast op"); AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op"); AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type"); AddAttr<int>("out_dtype", "output data type");
......
...@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase { ...@@ -50,8 +50,7 @@ class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kChannel, AddInput(kChannel,
"The Channel Variable that should be closed by" "The Channel Variable that should be closed by"
" the ChannelClose Op."); " the ChannelClose Op.");
......
...@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase { ...@@ -91,8 +91,7 @@ class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput(kOutput, AddOutput(kOutput,
"The object of a Channel type created by ChannelCreate Op."); "The object of a Channel type created by ChannelCreate Op.");
AddAttr<int>("capacity", "The size of the buffer of Channel.") AddAttr<int>("capacity", "The size of the buffer of Channel.")
......
...@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase { ...@@ -72,8 +72,7 @@ class ChannelRecvOp : public framework::OperatorBase {
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelRecvOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(Channel, AddInput(Channel,
"(Channel) A variable which \"receives\" the a value sent" "(Channel) A variable which \"receives\" the a value sent"
"to it by a channel_send op.") "to it by a channel_send op.")
......
...@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase { ...@@ -57,8 +57,7 @@ class ChannelSendOp : public framework::OperatorBase {
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChannelSendOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(Channel, AddInput(Channel,
"(Channel) A variable which \"sends\" the passed in value to " "(Channel) A variable which \"sends\" the passed in value to "
"a listening receiver.") "a listening receiver.")
......
...@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel { ...@@ -66,8 +66,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker { class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference", AddInput("Inference",
"(Tensor, default: Tensor<int64_t>). " "(Tensor, default: Tensor<int64_t>). "
"Predictions from the network."); "Predictions from the network.");
......
...@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel { ...@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor) The input of clip_by_norm op." "(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9]."); "The number of dimensions must be between [1, 9].");
......
...@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel { ...@@ -38,8 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
template <typename AttrType> template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker { class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor)The input of clip op." "(Tensor)The input of clip op."
"The number of dimensions must be between [1, 9]."); "The number of dimensions must be between [1, 9].");
......
...@@ -21,8 +21,7 @@ namespace operators { ...@@ -21,8 +21,7 @@ namespace operators {
template <typename OpComment> template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment; OpComment comment;
AddInput("X", AddInput("X",
string::Sprintf("(LoDTensor) the left hand operand of %s operator", string::Sprintf("(LoDTensor) the left hand operand of %s operator",
......
...@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -63,8 +63,7 @@ class ConcatOp : public framework::OperatorWithKernel {
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator."); AddOutput("Out", "Output tensor of concat operator.");
AddAttr<int>("axis", AddAttr<int>("axis",
......
...@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -108,8 +108,7 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The conditional variable of this operator. If X is empty, the " "The conditional variable of this operator. If X is empty, the "
"whole sub-block will not be executed.") "whole sub-block will not be executed.")
......
...@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -106,8 +106,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
library); library);
} }
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Conv2DOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
...@@ -200,8 +199,7 @@ $$ ...@@ -200,8 +199,7 @@ $$
)DOC"); )DOC");
} }
Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Conv3DOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
......
...@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim, ...@@ -60,12 +60,12 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code. // operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker { class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class ConvOp : public framework::OperatorWithKernel { class ConvOp : public framework::OperatorWithKernel {
......
...@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel { ...@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker { class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, " "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
"where B is the batch size and M is the data dimension."); "where B is the batch size and M is the data dimension.");
......
...@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -84,9 +84,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
layout_, library_); layout_, library_);
} }
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto, void Conv2DTransposeOpMaker::Make() {
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution transpose operator. " "(Tensor) The input tensor of convolution transpose operator. "
...@@ -168,9 +166,7 @@ Example: ...@@ -168,9 +166,7 @@ Example:
)DOC"); )DOC");
} }
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto, void Conv3DTransposeOpMaker::Make() {
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator." "(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is " "The format of input tensor is NCDHW. Where N is batch size, C is "
......
...@@ -30,12 +30,12 @@ using DDim = framework::DDim; ...@@ -30,12 +30,12 @@ using DDim = framework::DDim;
// operator implementations can reuse the code. // operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
class ConvTransposeOp : public framework::OperatorWithKernel { class ConvTransposeOp : public framework::OperatorWithKernel {
......
...@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel { ...@@ -62,8 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op."); AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op."); AddInput("Y", "The 2nd input of cos_sim op.");
AddOutput("Out", "The output of cos_sim op."); AddOutput("Out", "The output of cos_sim op.");
......
...@@ -18,8 +18,7 @@ namespace paddle { ...@@ -18,8 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker { class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Emission", AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape " "(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
"[N x D] where N is the size of the mini-batch and D is the total " "[N x D] where N is the size of the mini-batch and D is the total "
......
...@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel { ...@@ -52,8 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
class CropOpMaker : public framework::OpProtoAndCheckerMaker { class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CropOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"The input of pad op. " "The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)."); "The input should be a k-D tensor(k > 0 and k < 7).");
......
...@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D]," "(Tensor, default Tensor<float>), a 2-D tensor with shape [N x D],"
" where N is the batch size and D is the number of classes. " " where N is the batch size and D is the number of classes. "
......
...@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel { ...@@ -44,8 +44,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is " "(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length."); "[Lp, 1], where Lp is the sum of all input sequences' length.");
......
...@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel { ...@@ -29,8 +29,7 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CumsumOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Cumsum operator"); AddInput("X", "Input of Cumsum operator");
AddOutput("Out", "Output of Cumsum operator"); AddOutput("Out", "Output of Cumsum operator");
AddAttr<int>("axis", AddAttr<int>("axis",
......
...@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ...@@ -62,8 +62,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment"); AddInput("Moment", "(Tensor) Second moment");
......
...@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase { ...@@ -34,8 +34,7 @@ class DeleteVarOp : public framework::OperatorBase {
class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker { class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of delete op").AsDuplicable(); AddInput("X", "The input of delete op").AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Delete Operator. Delete Operator.
......
set(LOCAL_DETECTION_LIBS)
function(detection_library TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(options "")
set(common_deps op_registry)
set(pybind_flag 0)
cmake_parse_arguments(detection_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
op_library(${TARGET_NAME} SRCS ${detection_library_SRCS} DEPS ${common_deps} ${detection_library_DEPS})
set(LOCAL_DETECTION_LIBS
${TARGET_NAME}
${LOCAL_DETECTION_LIBS}
PARENT_SCOPE)
endfunction()
detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op.cu)
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
target_assign_op.cu)
# Export local libraries to parent
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
...@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -182,8 +182,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"DistMat", "DistMat",
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape " "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
......
...@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/box_coder_op.h" #include "paddle/fluid/operators/detection/box_coder_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { ...@@ -60,8 +60,7 @@ class BoxCoderOp : public framework::OperatorWithKernel {
class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker { class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BoxCoderOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"PriorBox", "PriorBox",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
......
...@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/box_coder_op.h" #include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle { namespace paddle {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/iou_similarity_op.h" #include "paddle/fluid/operators/detection/iou_similarity_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,8 +42,7 @@ class IOUSimilarityOp : public framework::OperatorWithKernel { ...@@ -42,8 +42,7 @@ class IOUSimilarityOp : public framework::OperatorWithKernel {
class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker { class IOUSimilarityOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
IOUSimilarityOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor, default LoDTensor<float>) " "(LoDTensor, default LoDTensor<float>) "
"Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, " "Box list X is a 2-D LoDTensor with shape [N, 4] holds N boxes, "
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/iou_similarity_op.h" #include "paddle/fluid/operators/detection/iou_similarity_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
......
...@@ -253,8 +253,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { ...@@ -253,8 +253,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MineHardExamplesOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"ClsLoss", "ClsLoss",
"(Tensor, default Tensor<float>), The classification loss with shape " "(Tensor, default Tensor<float>), The classification loss with shape "
......
...@@ -309,8 +309,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> { ...@@ -309,8 +309,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("BBoxes", AddInput("BBoxes",
"(Tensor) A 3-D Tensor with shape [N, M, 4] represents the " "(Tensor) A 3-D Tensor with shape [N, M, 4] represents the "
"predicted locations of M bounding bboxes, N is the batch size. " "predicted locations of M bounding bboxes, N is the batch size. "
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/prior_box_op.h" #include "paddle/fluid/operators/detection/prior_box_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -79,8 +79,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -79,8 +79,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PriorBoxOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(Tensor, default Tensor<float>), " "(Tensor, default Tensor<float>), "
"the input feature data of PriorBoxOp, The layout is NCHW."); "the input feature data of PriorBoxOp, The layout is NCHW.");
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/prior_box_op.h" #include "paddle/fluid/operators/detection/prior_box_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/target_assign_op.h" #include "paddle/fluid/operators/detection/target_assign_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -65,8 +65,7 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
TargetAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. " "(LoDTensor), This input is a 3D LoDTensor with shape [M, P, K]. "
"Some elements in X will be assigned to Out based on the " "Some elements in X will be assigned to Out based on the "
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/target_assign_op.h" #include "paddle/fluid/operators/detection/target_assign_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("DetectRes", AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: " "detections. Each row has 6 values: "
......
...@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -37,8 +37,7 @@ class DropoutOp : public framework::OperatorWithKernel {
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of dropout op."); AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op."); AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate(); AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
......
...@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel { ...@@ -49,8 +49,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Hyps", AddInput("Hyps",
"(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) " "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
"The indices for hypothesis strings."); "The indices for hypothesis strings.");
......
...@@ -14,26 +14,8 @@ limitations under the License. */ ...@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Add", "Out = X + Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add, elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,26 +14,8 @@ limitations under the License. */ ...@@ -14,26 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "Out = X / Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
ops::ElementwiseDivOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,25 +14,8 @@ limitations under the License. */ ...@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = max(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
ops::ElementwiseMaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_max, elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,25 +14,8 @@ limitations under the License. */ ...@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Max", "Out = min(X, Y)");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
ops::ElementwiseMinOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_min, elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -14,27 +14,8 @@ limitations under the License. */ ...@@ -14,27 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "Out = X \\odot\\ Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y");
ops::ElementwiseMulOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference { ...@@ -54,8 +54,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() final {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), The first input tensor of elementwise op."); AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op."); AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op."); AddOutput("Out", "The output of elementwise op.");
...@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,12 +63,12 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
"for broadcasting Y onto X.") "for broadcasting Y onto X.")
.SetDefault(-1) .SetDefault(-1)
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
comment_ = R"DOC( AddComment(string::Sprintf(R"DOC(
Limited Elementwise {name} Operator. Limited Elementwise %s Operator.
The equation is: The equation is:
$${equation}$$ $$%s$$
$X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be $X$ is a tensor of any dimension and the dimensions of tensor $Y$ must be
smaller than or equal to the dimensions of $X$. smaller than or equal to the dimensions of $X$.
...@@ -100,26 +99,13 @@ For example ...@@ -100,26 +99,13 @@ For example
Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details) Either of the inputs $X$ and $Y$ or none can carry the LoD (Level of Details)
information. However, the output only shares the LoD information with input $X$. information. However, the output only shares the LoD information with input $X$.
)DOC"; )DOC",
AddComment(comment_); GetName(), GetEquation()));
} }
protected: protected:
std::string comment_; virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0;
void Replace(std::string* src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src->find(from); pos != std::string::npos;
pos = src->find(from, pos + len_to)) {
src->replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string equation) {
Replace(&comment_, "{name}", name);
Replace(&comment_, "{equation}", equation);
}
}; };
class ElementwiseOpGrad : public framework::OperatorWithKernel { class ElementwiseOpGrad : public framework::OperatorWithKernel {
...@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -152,3 +138,16 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
...@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and ...@@ -13,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise_pow_op.h" #include "paddle/fluid/operators/elementwise_pow_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class ElementwisePowOpMaker : public ElementwiseOpMaker { class ElementwisePowOpMaker : public ElementwiseOpMaker {
public: protected:
ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker) std::string GetName() const override { return "Pow"; }
: ElementwiseOpMaker(proto, op_checker) { std::string GetEquation() const override { return "Out = X ^ Y"; }
SetComment("Pow", "Out = X ^ Y");
AddComment(comment_);
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -14,25 +14,8 @@ limitations under the License. */ ...@@ -14,25 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_sub_op.h" #include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "Out = X - Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_sub, ops::ElementwiseOp, REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
ops::ElementwiseSubOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_sub, elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,7 @@ class ExpandOp : public framework::OperatorWithKernel {
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]." "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"X is the input to be expanded."); "X is the input to be expanded.");
......
...@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -72,8 +72,7 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
layout, library); layout, library);
} }
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker) void FCOpMaker::Make() {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op."); AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
......
...@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel { ...@@ -45,7 +45,7 @@ class FCOpGrad : public framework::OperatorWithKernel {
class FCOpMaker : public framework::OpProtoAndCheckerMaker { class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FCOpMaker(OpProto* proto, OpAttrChecker* op_checker); void Make() override;
}; };
} // namespace operators } // namespace operators
......
...@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase { ...@@ -66,8 +66,7 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker { class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of feed op"); AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op"); AddOutput("Out", "The output of feed op");
AddAttr<int>("col", "(int) The column of feed"); AddAttr<int>("col", "(int) The column of feed");
......
...@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase { ...@@ -66,8 +66,7 @@ class FetchOp : public framework::OperatorBase {
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker { class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fetch op"); AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op"); AddOutput("Out", "The output of fetch op");
AddAttr<int>("col", "(int) The column of fetch"); AddAttr<int>("col", "(int) The column of fetch");
......
...@@ -30,9 +30,8 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -30,9 +30,8 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
}; };
class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public: protected:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Apply() override {
: BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
......
...@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -59,8 +59,7 @@ class FillConstantOp : public framework::OperatorBase {
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
......
...@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase { ...@@ -82,8 +82,7 @@ class FillOp : public framework::OperatorBase {
class FillOpMaker : public framework::OpProtoAndCheckerMaker { class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment(R"DOC(Fill operator AddComment(R"DOC(Fill operator
Fill an tensor with `value` and `shape`. The type of the tensor is specify by Fill an tensor with `value` and `shape`. The type of the tensor is specify by
......
...@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { ...@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FillZerosLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fill-zeros-like op."); AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with zeros."); AddOutput("Out", "The variable will be filled up with zeros.");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -64,8 +64,7 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,7 @@ class FTRLOp : public framework::OperatorWithKernel {
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
FTRLOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", AddInput("Param",
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated."); "Input parameter value that has to be updated.");
......
...@@ -67,8 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -67,8 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
class GatherOpMaker : public framework::OpProtoAndCheckerMaker { class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GatherOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op"); AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op"); AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of gather op"); AddOutput("Out", "The output of gather op");
......
...@@ -32,9 +32,8 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp { ...@@ -32,9 +32,8 @@ class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
}; };
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker { class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public: protected:
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) void Apply() override {
: BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<float>("mean", AddAttr<float>("mean",
"(float, default 0.0) " "(float, default 0.0) "
"mean of random tensor.") "mean of random tensor.")
......
...@@ -70,8 +70,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -70,8 +70,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GaussianRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker) void Make() override {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "Output matrix of gaussian random op"); AddOutput("Out", "Output matrix of gaussian random op");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册