Commit ab08575a authored by fengjiayi

WIP

Parent 55fac551
@@ -76,8 +76,16 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
 }
 
 OperatorBase* BuildGradOp(const OperatorBase* op) {
-  std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
-  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
+  auto it = OpRegistry::op_info_map().find(op->type_);
+  PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
+                 "'%s' has not been registered.", op->type_);
+  std::string grad_op_type = it->second.grad_op_type_;
+  PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
+                 op->type_);
+  it = OpRegistry::op_info_map().find(grad_op_type);
+  PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
+                 "'%s' has not been registered.", grad_op_type);
+  OperatorBase* grad_op = it->second.creator_();
   grad_op->type_ = grad_op_type;
   grad_op->attrs_ = op->attrs_;
   grad_op->attrs_.erase("input_format");
...
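With the per-purpose registries gone, BuildGradOp resolves everything through op_info_map in two steps: find the forward op's OpInfo, read grad_op_type_ (empty means the op has no gradient), then find the gradient op's OpInfo and call its creator_. Below is a minimal standalone sketch of that flow with mock types — illustrative names only, not Paddle's actual headers:

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct Op {
  std::string type_;
};

struct Info {
  std::function<Op*()> creator_;
  std::string grad_op_type_;  // empty means "no gradient operator"
};

std::unordered_map<std::string, Info>& registry() {
  static std::unordered_map<std::string, Info> m;
  return m;
}

Op* BuildGrad(const Op& op) {
  // Step 1: forward op -> its registry entry.
  auto it = registry().find(op.type_);
  assert(it != registry().end() && "forward op not registered");
  // Step 2: entry -> name of the gradient op.
  const std::string grad_type = it->second.grad_op_type_;
  assert(!grad_type.empty() && "op has no gradient operator");
  // Step 3: gradient op name -> its entry -> instantiate via creator_.
  it = registry().find(grad_type);
  assert(it != registry().end() && "grad op not registered");
  Op* grad = it->second.creator_();
  grad->type_ = grad_type;
  return grad;
}

int main() {
  registry()["mul"] = {[] { return new Op; }, "mul_grad"};
  registry()["mul_grad"] = {[] { return new Op; }, ""};
  std::unique_ptr<Op> g(BuildGrad(Op{"mul"}));
  assert(g->type_ == "mul_grad");
}
```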
@@ -17,6 +17,8 @@ limitations under the License. */
 #include <algorithm>
 #include <atomic>
 #include <type_traits>
+#include <typeindex>
+#include <typeinfo>
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attribute.h"
@@ -174,6 +175,15 @@ Add a mark to which output is temporary is helpful for future optimization.
   bool has_temporary_output_{false};
 };
 
+class NOPMaker : public OpProtoAndCheckerMaker {
+ public:
+  NOPMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {}
+};
+
+struct OpInfo {
+  std::function<OperatorBase*()> creator_;
+  std::string grad_op_type_;
+  OpProto* proto_{nullptr};  // left null for ops registered with NOPMaker
+  OpAttrChecker* checker_{nullptr};
+};
+
 class OpRegistry {
   using OpCreator = std::function<OperatorBase*()>;
   using VarIndexMap = std::unordered_map<std::string, int>;
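What OpInfo buys: the commit folds four parallel registries keyed by op type — protos(), grad_ops(), op_creators(), and op_checkers() — into a single map entry, so registration writes one record and every lookup reads one. A compact runnable mock of the consolidation (mock types; "add_two" is an illustrative name):

```cpp
#include <functional>
#include <map>
#include <string>

struct Proto {};    // stands in for OpProto
struct Checker {};  // stands in for OpAttrChecker
struct Base {};     // stands in for OperatorBase

// Before: four maps that had to stay in sync by convention.
std::map<std::string, Proto> protos;
std::map<std::string, std::string> grad_ops;
std::map<std::string, std::function<Base*()>> creators;
std::map<std::string, Checker> checkers;

// After: one struct per op, one map overall.
struct Info {
  std::function<Base*()> creator_;
  std::string grad_op_type_;
  Proto* proto_ = nullptr;  // null for ops registered without a maker
  Checker* checker_ = nullptr;
};
std::map<std::string, Info> op_info_map;

int main() {
  Info info;
  info.creator_ = [] { return new Base; };
  info.grad_op_type_ = "add_two_grad";  // hypothetical example names
  info.proto_ = new Proto;
  info.checker_ = new Checker;
  op_info_map.emplace("add_two", info);
}
```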
@@ -181,18 +191,25 @@ class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType>
-  static void RegisterOp(const std::string& op_type) {
-    op_creators()[op_type] = [] { return new OpType; };
-    OpAttrChecker& op_checker = op_checkers()[op_type];
-    OpProto& op_proto = protos()[op_type];
-    auto maker = ProtoMakerType(&op_proto, &op_checker);
-    maker.Validate();
-    *op_proto.mutable_type() = op_type;
-    PADDLE_ENFORCE(
-        op_proto.IsInitialized(),
-        "Fail to initialize %s's OpProto, because %s is not initialized",
-        op_type, op_proto.InitializationErrorString());
-    VarIndexMaps()[op_type].reset(new VarIndexMap());
-    auto& varmap = *VarIndexMaps()[op_type];
-    int idx = 0;
+  static void RegisterOp(const std::string& op_type,
+                         const std::string& grad_op_type) {
+    PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
+                   "'%s' is registered more than once.", op_type);
+    OpInfo op_info;
+    op_info.creator_ = [] { return new OpType; };
+    op_info.grad_op_type_ = grad_op_type;
+    if (std::type_index(typeid(ProtoMakerType)) !=
+        std::type_index(typeid(NOPMaker))) {
+      op_info.proto_ = new OpProto;
+      op_info.checker_ = new OpAttrChecker;
+      auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
+      maker.Validate();
+      *op_info.proto_->mutable_type() = op_type;
+      PADDLE_ENFORCE(
+          op_info.proto_->IsInitialized(),
+          "Fail to initialize %s's OpProto, because %s is not initialized",
+          op_type, op_info.proto_->InitializationErrorString());
+      //======will be refactored in following PRs============//
+      VarIndexMaps()[op_type].reset(new VarIndexMap());
+      auto& varmap = *VarIndexMaps()[op_type];
+      int idx = 0;
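RegisterOp decides whether to build a proto and checker by comparing the maker template parameter against the NOPMaker sentinel through std::type_index (declared in <typeindex>, which is why that include appears above). A tiny self-contained demo of the dispatch:

```cpp
#include <iostream>
#include <typeindex>
#include <typeinfo>

struct NOPMaker {};
struct RealMaker {};

template <typename Maker>
void Register(const char* name) {
  // Same comparison as in RegisterOp: is the maker the no-op sentinel?
  if (std::type_index(typeid(Maker)) != std::type_index(typeid(NOPMaker))) {
    std::cout << name << ": building proto and checker\n";
  } else {
    std::cout << name << ": sentinel maker, skipping proto\n";
  }
}

int main() {
  Register<RealMaker>("mul");      // prints "building proto and checker"
  Register<NOPMaker>("mul_grad");  // prints "sentinel maker, skipping proto"
}
```

Note that this is a runtime check: both branches are compiled for every instantiation, which is why NOPMaker must still expose a constructor with the same shape as a real maker's.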
@@ -203,30 +220,26 @@ class OpRegistry {
-    for (auto& var : op_proto.outputs()) {
-      varmap[var.name()] = idx++;
-    }
-  }
+      for (auto& var : op_info.proto_->outputs()) {
+        varmap[var.name()] = idx++;
+      }
+      //================================================//
+    }
+    op_info_map().insert(std::make_pair(op_type, op_info));
+  }
 
-  template <typename GradOpType>
-  static void RegisterGradOp(const std::string& op_type,
-                             const std::string& grad_op_type) {
-    op_creators()[grad_op_type] = [] { return new GradOpType; };
-    grad_ops()[op_type] = grad_op_type;
-  }
-
   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameList& inputs,
                                                 const VarNameList& outputs,
                                                 const AttributeMap& attrs) {
-    auto op_create_it = op_creators().find(type);
-    PADDLE_ENFORCE(op_create_it != op_creators().end(),
-                   "Operator %s cannot be found.", type);
-    auto op = op_create_it->second();
+    auto it = op_info_map().find(type);
+    PADDLE_ENFORCE(it != op_info_map().end(), "'%s' has not been registered.",
+                   type);
+    auto op = it->second.creator_();
     op->type_ = type;
     op->inputs_ = inputs;
     op->outputs_ = outputs;
     op->attrs_ = attrs;
-    op_checkers().at(type).Check(op->attrs_);
+    if (it->second.checker_ != nullptr) it->second.checker_->Check(op->attrs_);
 
     GenerateTempVariableName(op);
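CreateOp follows the same single-lookup pattern: one find() on op_info_map yields both the factory and the attribute checker, and an unknown type fails fast with the uniform "has not been registered" message. A compact runnable mock of that path (illustrative names):

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

struct Op {
  std::string type_;
  int scale = 1;  // stands in for an attribute
};

struct Info {
  std::function<Op*()> creator_;
  std::function<void(const Op&)> check_;  // stands in for OpAttrChecker::Check
};

std::map<std::string, Info> reg;

Op* CreateOp(const std::string& type, int scale) {
  auto it = reg.find(type);
  assert(it != reg.end() && "type has not been registered");
  Op* op = it->second.creator_();
  op->type_ = type;
  op->scale = scale;
  it->second.check_(*op);  // attribute validation runs after fields are set
  return op;
}

int main() {
  reg["scale"] = {[] { return new Op; },
                  [](const Op& op) { assert(op.scale > 0); }};
  delete CreateOp("scale", 2);
}
```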
@@ -268,14 +281,9 @@ class OpRegistry {
     return grad_op;
   }
 
-  static std::unordered_map<std::string, OpProto>& protos() {
-    static std::unordered_map<std::string, OpProto> protos_;
-    return protos_;
-  }
-
-  static std::unordered_map<std::string, std::string>& grad_ops() {
-    static std::unordered_map<std::string, std::string> grad_ops_;
-    return grad_ops_;
+  static std::unordered_map<std::string, const OpInfo>& op_info_map() {
+    static std::unordered_map<std::string, const OpInfo> op_info_map_;
+    return op_info_map_;
   }
 
   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
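Like the accessors it replaces, op_info_map() wraps a function-local static (a Meyers singleton): the map is constructed on first use, so registrar objects in other translation units — whose constructors run before main() — can safely insert into it regardless of static initialization order. A minimal sketch:

```cpp
#include <string>
#include <unordered_map>

struct Info {};

std::unordered_map<std::string, Info>& op_info_map() {
  static std::unordered_map<std::string, Info> instance;  // built on first use
  return instance;
}

// A file-scope registrar like the ones REGISTER_OP generates; even though its
// constructor runs before main(), op_info_map() is guaranteed to be ready.
struct Registrar {
  Registrar() { op_info_map().emplace("example_op", Info{}); }
};
static Registrar auto_registered;

int main() { return op_info_map().count("example_op") == 1 ? 0 : 1; }
```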
@@ -284,17 +292,7 @@ class OpRegistry {
     return maps_;
   }
 
-  static std::unordered_map<std::string, OpCreator>& op_creators() {
-    static std::unordered_map<std::string, OpCreator> op_creators_;
-    return op_creators_;
-  }
-
  private:
-  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
-    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
-    return op_checkers_;
-  }
-
   static void GenerateTempVariableName(OperatorBase* op) {
     static std::atomic<size_t> gUniqId(0UL);
     for (auto& outname : op->outputs_) {
@@ -323,16 +321,9 @@ class Registrar {
 template <typename OpType, typename ProtoMakerType>
 class OpRegistrar : public Registrar {
  public:
-  explicit OpRegistrar(const char* op_type) {
-    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
-  }
-};
-
-template <typename GradOpType>
-class GradOpRegistrar : public Registrar {
- public:
-  GradOpRegistrar(const char* op_type, const char* grad_op_type) {
-    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
+  explicit OpRegistrar(const char* op_type) : OpRegistrar(op_type, "") {}
+  OpRegistrar(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
   }
 };
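The registrar now carries two constructors, with the one-argument form delegating to the two-argument form (a C++11 delegating constructor; writing `OpRegistrar(op_type, "");` inside the body would only create and destroy a temporary). A standalone miniature of the pattern, with illustrative names:

```cpp
#include <iostream>
#include <string>

class Registrar {
 public:
  // Delegating constructor: reuse the two-argument path for grad-less ops.
  explicit Registrar(const char* op) : Registrar(op, "") {}
  Registrar(const char* op, const char* grad) {
    std::cout << "register " << op
              << (grad[0] ? std::string(" with grad ") + grad : "") << "\n";
  }
};

int main() {
  Registrar a("mul", "mul_grad");  // prints: register mul with grad mul_grad
  Registrar b("fill_zeros_like");  // prints: register fill_zeros_like
}
```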
@@ -358,30 +349,21 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class)                        \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type)          \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
       __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
   static ::paddle::framework::OpRegistrar<op_class, op_maker_class>           \
-      __op_registrar_##op_type##__(#op_type);                                 \
+      __op_registrar_##op_type##__(#op_type, #grad_op_type);                  \
   int TouchOpRegistrar_##op_type() {                                          \
     __op_registrar_##op_type##__.Touch();                                     \
     return 0;                                                                 \
   }
 
-/**
- * Macro to register Gradient Operator.
- */
-#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class)           \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
-      __reg_gradient_op__##op_type##_##grad_op_type,                         \
-      "REGISTER_GRADIENT_OP must be called in global namespace");            \
-  static ::paddle::framework::GradOpRegistrar<grad_op_class>                 \
-      __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type,       \
-                                                             #grad_op_type); \
-  int TouchOpGradientRegistrar_##op_type() {                                 \
-    __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch();          \
-    return 0;                                                                 \
-  }
+#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
+  REGISTER_OP(op_type, op_class, op_maker_class, )
+
+#define REGISTER_GRADIENT_OP(op_type, op_class) \
+  REGISTER_OP(op_type, op_class, ::paddle::framework::NOPMaker, )
 
 /**
  * Macro to register OperatorKernel.
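Hand-expanding the macros makes the mechanics visible: each REGISTER_OP site generates a file-scope registrar whose constructor performs the registration, plus a Touch function other translation units can reference; REGISTER_GRADIENT_OP is now just REGISTER_OP with the NOPMaker sentinel and an empty grad_op_type. A simplified self-contained miniature (not Paddle's actual macro):

```cpp
#include <iostream>

struct Registrar {
  Registrar(const char* op, const char* grad) {
    std::cout << "register " << op << " grad:[" << grad << "]\n";
  }
  void Touch() {}  // anchors the static against linker dead-stripping
};

#define MY_REGISTER_OP(op_type, grad_op_type)                   \
  static Registrar __op_registrar_##op_type##__(#op_type,       \
                                                #grad_op_type); \
  int TouchOpRegistrar_##op_type() {                            \
    __op_registrar_##op_type##__.Touch();                       \
    return 0;                                                   \
  }

MY_REGISTER_OP(add_two, add_two_grad)  // defines TouchOpRegistrar_add_two
MY_REGISTER_OP(fill_zeros_like, )      // empty grad_op_type, like the
                                       // REGISTER_OP_WITHOUT_GRADIENT wrapper

int main() {
  return TouchOpRegistrar_add_two() + TouchOpRegistrar_fill_zeros_like();
}
```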
@@ -400,10 +382,12 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to Forbid user register Gradient Operator.
  */
+/*
 #define NO_GRADIENT(op_type)                           \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                      \
       __reg_gradient_op__##op_type##_##op_type##_grad, \
       "NO_GRADIENT must be called in global namespace")
+*/
 
 #define REGISTER_OP_GPU_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
@@ -423,23 +407,6 @@ class OpKernelRegistrar : public Registrar {
   static int use_op_itself_##op_type##_ __attribute__((unused)) = \
       TouchOpRegistrar_##op_type()
 
-// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
-// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
-// be compiled. `NO_GRAD` should be removed after all gradient ops are
-// compeleted.
-#define NO_GRAD
-#ifndef NO_GRAD
-#define USE_OP_GRADIENT(op_type)                                    \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                   \
-      __use_op_gradient_##op_type,                                  \
-      "USE_OP_GRADIENT must be called in global namespace");        \
-  extern int TouchOpGradientRegistrar_##op_type();                  \
-  static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
-      TouchOpGradientRegistrar_##op_type()
-#else
-#define USE_OP_GRADIENT(op_type)
-#endif
-
 #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                  \
       __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
@@ -459,18 +426,13 @@ class OpKernelRegistrar : public Registrar {
   USE_OP_DEVICE_KERNEL(op_type, GPU)
 #endif
 
-#define USE_NO_GRAD_OP(op_type) \
-  USE_OP_ITSELF(op_type);       \
-  USE_OP_KERNEL(op_type)
-
-#define USE_CPU_OP(op_type)           \
-  USE_OP_ITSELF(op_type);             \
-  USE_OP_DEVICE_KERNEL(op_type, CPU); \
-  USE_OP_GRADIENT(op_type)
+#define USE_CPU_ONLY_OP(op_type) \
+  USE_OP_ITSELF(op_type);        \
+  USE_OP_DEVICE_KERNEL(op_type, CPU);
 
-#define USE_OP(op_type)    \
-  USE_NO_GRAD_OP(op_type); \
-  USE_OP_GRADIENT(op_type)
+#define USE_OP(op_type)   \
+  USE_OP_ITSELF(op_type); \
+  USE_OP_KERNEL(op_type)
 
 }  // namespace framework
 }  // namespace paddle
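The USE_* macros solve a linking problem: a static registrar that lives in a library object file is silently dropped unless the final binary references some symbol from that file, so USE_OP_ITSELF declares the Touch function extern and stores its result. A one-file miniature of the trick (illustrative names; __attribute__((unused)) as in the header above, so GCC/Clang only):

```cpp
#include <iostream>

// --- normally in mul_op.cc: the registrar's Touch hook ---
int TouchOpRegistrar_mul() {
  std::cout << "mul registrar touched\n";
  return 0;
}

// --- normally in the consuming binary: what USE_OP_ITSELF(mul) expands to ---
// The reference forces the linker to keep the object file defining the symbol.
extern int TouchOpRegistrar_mul();
static int use_op_itself_mul_ __attribute__((unused)) = TouchOpRegistrar_mul();

int main() {}
```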
...
@@ -173,13 +173,14 @@ All parameter, weight, gradient are variables in Paddle.
   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
   m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
-    auto &protos = OpRegistry::protos();
+    auto &op_info_map = OpRegistry::op_info_map();
     std::vector<py::bytes> ret_values;
-    for (auto it = protos.begin(); it != protos.end(); ++it) {
-      PADDLE_ENFORCE(it->second.IsInitialized(),
-                     "OpProto must all be initialized");
+    for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
+      const OpProto *proto = it->second.proto_;
+      if (proto == nullptr) continue;  // makerless (gradient) ops have no proto
+      PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
       std::string str;
-      PADDLE_ENFORCE(it->second.SerializeToString(&str),
+      PADDLE_ENFORCE(proto->SerializeToString(&str),
                      "Serialize OpProto Error. This could be a bug of Paddle.");
       ret_values.push_back(py::bytes(str));
     }
...
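As the @note above warns, pybind11 turns a returned std::string into a Python unicode str, which would mangle serialized protobuf bytes — hence the py::bytes wrapper. A minimal sketch of the distinction, assuming a modern pybind11 (module and function names are illustrative):

```cpp
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

PYBIND11_MODULE(demo, m) {
  // Arrives in Python as unicode str (decoded, lossy for raw bytes).
  m.def("as_str", [] { return std::string("\x08\x01"); });
  // Arrives in Python as bytes (exact payload preserved).
  m.def("as_bytes", [] { return py::bytes("\x08\x01"); });
}
```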