#pragma once #include #include #include #include #include #include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/operator.h" namespace paddle { namespace framework { // helper class to set attribute type struct AttrTypeHelper { template static void SetAttrType(AttrProto* attr); static Attribute GetAttrValue(const AttrDesc& attr_desc) { switch (attr_desc.type()) { case paddle::framework::AttrType::INT: { return attr_desc.i(); } case paddle::framework::AttrType::FLOAT: { return attr_desc.f(); } case paddle::framework::AttrType::STRING: { return attr_desc.s(); } case paddle::framework::AttrType::INTS: { std::vector val(attr_desc.ints_size()); for (int i = 0; i < attr_desc.ints_size(); ++i) { val[i] = attr_desc.ints(i); } return val; } case paddle::framework::AttrType::FLOATS: { std::vector val(attr_desc.floats_size()); for (int i = 0; i < attr_desc.floats_size(); ++i) { val[i] = attr_desc.floats(i); } return val; } case paddle::framework::AttrType::STRINGS: { std::vector val(attr_desc.strings_size()); for (int i = 0; i < attr_desc.strings_size(); ++i) { val[i] = attr_desc.strings(i); } return val; } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); return boost::blank(); } }; // this class not only make proto but also init attribute checkers. class OpProtoAndCheckerMaker { public: OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : proto_(proto), op_checker_(op_checker) {} ~OpProtoAndCheckerMaker() { PADDLE_ENFORCE(validated_, "should call Validate after build"); } void Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); } protected: void AddInput(const std::string& name, const std::string& comment, bool multiple = false) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; input->set_multiple(multiple); if (multiple) { SetHasMultipleInput(); } } void AddInputs(const std::string& name, const std::string& comment) { AddInput(name, comment, true); } void AddOutput(const std::string& name, const std::string& comment, bool temporary = false, bool multiple = false) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; output->set_multiple(multiple); if (multiple) { SetHasMultipleOutput(); } output->set_temporary(temporary); if (temporary) { SetHasTemporaryOutput(); } } void AddOutputs(const std::string& name, const std::string& comment, bool temporary = false) { AddOutput(name, comment, temporary, true); } template TypedAttrChecker& AddAttr(const std::string& name, const std::string& comment, bool generated = false) { auto attr = proto_->mutable_attrs()->Add(); *attr->mutable_name() = name; *attr->mutable_comment() = comment; attr->set_generated(generated); AttrTypeHelper::SetAttrType(attr); return op_checker_->AddAttrChecker(name); } void AddComment(const std::string& comment) { *(proto_->mutable_comment()) = comment; } private: void SetHasMultiple(const std::string& in_out, bool* flag) { if (!*flag) { AddAttr>(in_out + "_format", "The multiple index of " + in_out + "\n" R"DOC( This attribute is used by Paddle core framework. Paddle's Op support each input or output could be a list of variable. This attribute is used to show how that list organized. e.g. input = ["a", "b", "c", "d", "e", "f"] input_format = [0, 4, 5, 6] means The number of all input variables this op is six, and they are segmented into three inputs. The first input is input[0:4], second is input[4:5], third is input[5:6]. )DOC", /*generated*/ true); *flag = true; } } void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); } void SetHasMultipleOutput() { SetHasMultiple("output", &has_multiple_output_); } void SetHasTemporaryOutput() { if (!has_temporary_output_) { AddAttr>("temporary_index", R"DOC(The temporary index of output. Not all output of Paddle Op is used by user. For faster computation, each op could output some its internal state to other op, other op could take that output to make compute faster. Add a mark to which output is temporary is helpful for future optimization. )DOC", /*generated*/ true) .SetDefault(std::vector()); has_temporary_output_ = true; } } void CheckNoDuplicatedInOutAttrs() { std::unordered_set names; auto checker = [&](const std::string& name) { PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); names.insert(name); }; for (auto& attr : proto_->attrs()) { checker(attr.name()); } for (auto& input : proto_->inputs()) { checker(input.name()); } for (auto& output : proto_->outputs()) { checker(output.name()); } } OpProto* proto_; OpAttrChecker* op_checker_; bool validated_{false}; bool has_multiple_input_{false}; bool has_multiple_output_{false}; bool has_temporary_output_{false}; }; class OpRegistry { using OpCreator = std::function; public: template static void RegisterOp(const std::string& op_type) { creators()[op_type] = [] { return new OpType; }; OpProto& op_proto = protos()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type]; auto maker = ProtoMakerType(&op_proto, &op_checker); maker.Validate(); *op_proto.mutable_type() = op_type; PADDLE_ENFORCE( op_proto.IsInitialized(), "Fail to initialize %s's OpProto, because %s is not initialized", op_type, op_proto.InitializationErrorString()); } static OperatorPtr CreateOp(const OpDesc& op_desc) { //! Create a OpPtr by type. std::string op_type = op_desc.type(); OperatorPtr op(creators().at(op_type)()); //! Fill op's data member. Not use constructor because it will be noising //! for Op developer. const OpProto& op_proto = protos().at(op_type); op->type_ = op_desc.type(); // set op's inputs_ from desc. op->inputs_.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::back_inserter(op->inputs_)); // set op's outputs_ from desc. op->outputs_.reserve((size_t)op_desc.outputs_size()); std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), std::back_inserter(op->outputs_)); //! Fill attrs, and validate attrs. for (auto& attr : op_desc.attrs()) { op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); } op_checkers().at(op_type).Check(op->attrs_); //! Convert Temporary variable name to an unique variable name. GenerateTempVariableName(op.get()); // set argument offsets stored in op. CreateInOutOffsetMap(op, op_proto); //! Other op's custom Init for a complex Op. For simple Op, the Init //! method do nothing. op->Init(); return op; } // init op.in_out_idxs_ to accelerate argument's offset lookup. static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) { op->CreateInOutOffsetMap(proto); } static std::unordered_map& protos() { static std::unordered_map protos_; return protos_; }; private: static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); for (auto& outname : op->outputs_) { if (outname == OperatorBase::TMP_VAR_NAME()) { outname += op->type_; outname += "@"; outname += std::to_string(gUniqId.fetch_add(1)); } } } static std::unordered_map& creators() { static std::unordered_map creators_; return creators_; } static std::unordered_map& op_checkers() { static std::unordered_map op_checkers_; return op_checkers_; }; }; template class OpRegisterHelper { public: OpRegisterHelper(const char* op_type) { OpRegistry::RegisterOp(op_type); } }; /** * check if MACRO is used in GLOBAL NAMESPACE. */ #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ struct __test_global_namespace_##uniq_name##__ {}; \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ __test_global_namespace_##uniq_name##__>::value, \ msg) /** * Macro to Register Operator. */ #define REGISTER_OP(__op_type, __op_class, __op_maker_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \ "REGISTER_OP must be in global namespace"); \ static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \ __op_register_##__op_type##__(#__op_type); \ int __op_register_##__op_type##_handle__() { return 0; } /** * Macro to Register OperatorKernel. */ #define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op_kernel_##type##_##DEVICE_TYPE##__, \ "REGISTER_OP_KERNEL must be in global namespace"); \ struct __op_kernel_register__##type##__ { \ __op_kernel_register__##type##__() { \ ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ key.place_ = PlaceType(); \ ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ .reset(new __VA_ARGS__()); \ } \ }; \ static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; } // (type, KernelType) #define REGISTER_OP_GPU_KERNEL(type, ...) \ REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) // (type, KernelType) #define REGISTER_OP_CPU_KERNEL(type, ...) \ REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** * Macro to mark what Operator and Kernel we will use and tell the compiler to * link them into target. */ #define USE_OP_WITHOUT_KERNEL(op_type) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __use_op_without_kernel_##op_type, \ "USE_OP_WITHOUT_KERNEL must be in global namespace"); \ extern int __op_register_##op_type##_handle__(); \ static int __use_op_ptr_##op_type##_without_kernel__ \ __attribute__((unused)) = __op_register_##op_type##_handle__() #define USE_OP_KERNEL(op_type, DEVICE_TYPE) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ "USE_OP_KERNEL must be in global namespace"); \ extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \ static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \ __attribute__((unused)) = \ __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__() // use Operator with only cpu kernel. #define USE_OP_CPU(op_type) \ USE_OP_WITHOUT_KERNEL(op_type); \ USE_OP_KERNEL(op_type, CPU) #ifdef PADDLE_ONLY_CPU #define USE_OP(op_type) USE_OP_CPU(op_type) #else #define USE_OP(op_type) \ USE_OP_CPU(op_type); \ USE_OP_KERNEL(op_type, GPU) #endif } // namespace framework } // namespace paddle