提交 c23af80a 编写于 作者: Y Yu Yang

Change macro

上级 e3a642e0
......@@ -79,6 +79,7 @@ class GradOpDescMakerBase {
class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<OpDescBind> operator()() const { return {this->Apply()}; }
protected:
......@@ -86,6 +87,9 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual OpDescBind Apply() const {
OpDescBind grad;
......
......@@ -24,14 +24,27 @@ limitations under the License. */
#include "paddle/framework/details/op_registry.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which,
// however, are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_OP macros to
// call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker.
void Touch() {}
};
template <typename... ARGS>
struct OperatorRegistrar {
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
......@@ -70,19 +83,6 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
};
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which,
// however, are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_OP macros to
// call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker.
void Touch() {}
};
template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar {
public:
......@@ -138,33 +138,43 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \
msg)
/**
* Macro to register Operator.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
#define VA_ARGS(...) , ##__VA_ARGS__
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
class _OpGradClass_##op_type##_ : public grad_op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \
}; \
static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \
__VA_ARGS__)> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
/**
* Macro to register Operator.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker { \
using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \
\
protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \
}; \
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class)
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
/**
* Macro to register OperatorKernel.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册