{# ----------------------------- op maker ----------------------------------- #}
{# Emit `class <Api>OpMaker : public framework::OpProtoAndCheckerMaker` for an
   api. Make() registers every input and every output, and only those
   attributes whose name appears in api["kernel"]["param"]. The filters/tests
   used throughout this file (to_pascal_case, to_opmaker_name, vec, scalar, ...)
   are registered by the generator, not defined here. #}
{% macro op_maker(api) %}
  {% set api_name = api["name"] %}
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
  {% filter indent(4, True) %}
    {% for input in api["inputs"] %}
{{add_input(loop.index0, input, api_name)}};
    {% endfor %}
    {% for output in api["outputs"] %}
{{add_output(loop.index0, output, api_name)}};
    {% endfor %}
    {% for attr in api["attrs"] %}
      {% if attr["name"] in api["kernel"]["param"] %}
{{add_attr(loop.index0, attr, api_name)}};
      {% endif %}
    {% endfor %}
  {% endfilter %}
    AddComment(R"DOC(
TODO: Documentation of {{api_name}} op.
)DOC");
  }
};
{% endmacro %}

{# Emit one AddInput(...) expression (no trailing `;`, the caller adds it).
   Inputs whose typename passes the `vec` test are marked AsDuplicable();
   inputs with input["optional"] set are marked AsDispensable(). #}
{% macro add_input(i, input, op_name) %}{# inline #}
  {% set name = input["name"] %}
  {% set typename = input["typename"] %}
AddInput({{name| to_opmaker_name}}, "({{typename}}), input {{i}} of {{op_name}} op.")
  {%- if typename is vec %}

    .AsDuplicable()
  {%- endif %}
  {%- if input["optional"] %}

    .AsDispensable()
  {%- endif %}
{%- endmacro %}

{# Emit one AddOutput(...) expression (no trailing `;`). Outputs may be
   duplicable (typename passes `vec`) or intermediate (output["intermediate"]);
   optional outputs are not supported. #}
{% macro add_output(i, output, op_name) %}{# inline #}
  {% set name = output["name"] %}
  {% set typename = output["typename"] %}
  {% set is_intermediate = output["intermediate"] %}
AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name}} op.")
  {%- if typename is vec %}

    .AsDuplicable()
  {%- endif %}
  {%- if is_intermediate %}

    .AsIntermediate()
  {%- endif %}
{%- endmacro %}

{# Emit the registration of one attribute. Scalar attributes also get a
   dispensable "<Name>Tensor" input; IntArray attributes also get a dispensable
   "<Name>Tensor" input and a duplicable, dispensable "<Name>TensorList" input.
   Those extra AddInput statements end with `;`; the final AddAttr<...>(...)
   expression does not (the caller appends it). A default value is applied
   when attr has "default_value". #}
{% macro add_attr(i, attr, op_name) %}{# inline #}
  {% set name = attr["name"] %}
  {% set typename = attr["typename"] %}
  {% if typename is scalar %}
AddInput("{{name | to_pascal_case}}Tensor", "attribute {{i}} for {{op_name}} op from 0D Tensor.")
    .AsDispensable();
  {% elif typename == "IntArray" %}{# the type has been renamed #}
AddInput("{{name | to_pascal_case}}Tensor", "attribute {{i}} for {{op_name}} op from 1D integer Tensor.")
    .AsDispensable();
AddInput("{{name | to_pascal_case}}TensorList", "attribute {{i}} for {{op_name}} op from list fo 0D integer Tensors.")
    .AsDuplicable()
    .AsDispensable();
  {% endif %}
AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_type}}), attribute {{i}} for {{op_name}} op.")
  {% if "default_value" in attr %}
    .SetDefault({{process_default_value(attr)}})
  {%- endif %}
{%- endmacro %}

{# Translate an attribute's api-side default value into the form the OpMaker
   expects. DataType is converted back to a proto VarType and stored as int,
   DataLayout is cast to int, Place is turned into its int place-type, and
   everything else is passed through unchanged. #}
{% macro process_default_value(attr) %}{# inline #}
  {% set default_value = attr["default_value"] %}
  {% set typename = attr["typename"] %}
  {% if typename == "DataType" %}{# convert back to VarType #}
static_cast<int>(framework::TransToProtoVarType(experimental::{{default_value}}))
  {%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#}
static_cast<int>(experimental::{{default_value}})
  {%- elif typename == "Place" %}{# construct a Place to get the type #}
static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}{{default_value}}).GetType())
  {%- else %}{# pass through as-is #}
{{default_value}}
  {%- endif %}
{%- endmacro %}

{# --------------------------------------- name mapping ---------------------------------------------- #}
{# Emit `<Api>OpArgumentMapping`, which builds a phi KernelSignature from the
   op's inputs, attributes and outputs. When api["kernel"]["func"] lists two
   kernels, the second is chosen if the first kernel param is a SelectedRows
   input. The trailing block comment lists every possible signature for
   'get_compat_kernel_signature.py'. #}
{% macro name_map(api) %}
KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
  {% set kernel_args = api["kernel"]["param"] %}
  {{get_input_list(api["inputs"], kernel_args)}};
  paddle::small_vector<const char*> attrs;
  {% for attr in api["attrs"]%}
  {% filter indent(2)%}
  {{get_an_attr(attr)}};
  {% endfilter %}
  {% endfor %}
  {{get_output_list(api["outputs"], kernel_args)}};
  {% if api["kernel"]["func"] | length == 1 %}
  KernelSignature sig("{{api["name"]}}", std::move(inputs), std::move(attrs), std::move(outputs));
  return sig;
  {% else %}{# it has kernel for selected rows #}
  const char* kernel_name = ctx.IsSelectedRowsInput({{kernel_args[0] | to_opmaker_name_cstr}}) ? "{{api["kernel"]["func"][1]}}" : "{{api["kernel"]["func"][0]}}";
  KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
  return sig;
  {%endif%}
}

/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:

{{api | cartesian_prod_mapping}}
******************************************************************
*/
{% endmacro %}

{# Register the argument-mapping function emitted by `name_map`. #}
{% macro register_name_map(api) %}
PD_REGISTER_ARG_MAPPING_FN({{api["name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %}

{# Declare the `inputs` name list of a KernelSignature, keeping only inputs
   that are kernel params (no trailing `;`). #}
{% macro get_input_list(inputs, kernel_args) %}{# inline #}
paddle::small_vector<const char*> inputs {
{%- for input in inputs %}
{%- if input["name"] in kernel_args %}
{{input["name"] | to_opmaker_name_cstr}}{{", " if not loop.last}}
{%- endif %}
{%- endfor %}
}
{%- endmacro %}

{# Emit `attrs.emplace_back(...)` for one attribute (no trailing `;`).
   Scalar and IntArray attributes pick, at run time, the tensor input that
   `add_attr` registered for them ("<Name>Tensor" / "<Name>TensorList") when it
   is present, and fall back to the plain attribute name otherwise. #}
{% macro get_an_attr(attr) %}{# inline #}
{% set typename = attr["typename"] %}
{% set name = attr["name"] %}
{% if typename is scalar %}{# scalar correspond to a dispensable input and an attr in opmaker #}
{# the dispensable input is registered as "<Name>Tensor" by add_attr, so test that name #}
attrs.emplace_back(
  ctx.HasInput("{{name | to_pascal_case}}Tensor")
  ? "{{name | to_pascal_case}}Tensor"
  : "{{name}}"
)
{%- elif typename == "IntArray" %}
attrs.emplace_back(
  ctx.HasInput("{{name | to_pascal_case}}Tensor")
  ? "{{name | to_pascal_case}}Tensor"
  : ctx.InputSize("{{name | to_pascal_case}}TensorList") > 0
    ? "{{name | to_pascal_case}}TensorList"
    : "{{name}}"
)
{%- else %}
attrs.emplace_back("{{name}}")
{%- endif %}
{%- endmacro %}

{# Declare the `outputs` name list of a KernelSignature (no trailing `;`).
   All outputs are listed; kernel_args is accepted for symmetry with
   get_input_list but is not used here. #}
{% macro get_output_list(outputs, kernel_args) %}{# inline #}
paddle::small_vector<const char*> outputs {
{%- for output in outputs %}
{{output["name"] | to_opmaker_name_cstr}}{{", " if not loop.last}}
{%- endfor %}
}
{%- endmacro %}

{# Emit GetExpectedKernelType(). The data type comes from
   kernel["data_type"]["candidates"]:
   - one candidate naming an input: IndicateVarDataType on that input;
     otherwise the candidate is read as an int attribute holding a VarType.
   - two candidates: read the attribute named by the first; if it is -1, fall
     back to IndicateVarDataType on the input named by the second.
   NOTE(review): any other candidate count leaves `data_type` undeclared in the
   generated code; `operator` only calls this when kernel["data_type"] is set. #}
{% macro get_expected_kernel(api) %}
{% set kernel = api["kernel"] %}
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
  {% if kernel["data_type"]["candidates"] | length == 1 %}
    {% set data_type_arg = kernel["data_type"]["candidates"][0] %}
    {% set inputs = api["inputs"] | map(attribute="name") | list %}
    {% if data_type_arg in inputs %}
  auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
    {% else %}{# it is an attribute and probably named dtype#}
  auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_arg}}"));
    {% endif %}
  {% elif kernel["data_type"]["candidates"] | length == 2 %}
    {% set data_type_args = kernel["data_type"]["candidates"] %}
  auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_args[0]}}"));
  if (data_type == static_cast<framework::proto::VarType::Type>(-1)) {
    data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
  }
  {% endif %}
{% endif %}
  return framework::OpKernelType(data_type, ctx.GetPlace());
}
{% endmacro %}

{# --------------------------------------- operator ---------------------------------------------- #}
{# Emit the `<Api>Op` class (with GetExpectedKernelType only when
   kernel["data_type"] is set), its InferShape functor, and, when configured,
   the inplace and no-need-buffer inferers. #}
{% macro operator(api) %}
class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  {# ----------- get expected kernel type function -------------------------- #}
  {% set kernel = api["kernel"] %}
  {% if kernel["data_type"] is not none %}
 protected:
  {% filter indent(2, True)%}
{{get_expected_kernel(api)}}
  {% endfilter %}
  {% endif %}
};

DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}InferShapeFunctor,
                            PD_INFER_META(phi::{{api["infer_meta"]["func"]}}));
{# inplace inferer #}
{% if api["inplace"] is not none %}
  {% set inplace_map %}
  {% for source, target in api["inplace"].items() %}
{{"{"}}{{source | to_opmaker_name}}, {{target | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
  {%- endfor %}
  {%- endset %}
DECLARE_INPLACE_OP_INFERER({{api["name"] | to_pascal_case}}InplaceInferer,
                           {{inplace_map}});
{% endif %}

{# no_need_buffer inferer #}
{% if api["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["name"] | to_pascal_case}}NoNeedBufferVarInferer,
                                    {{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% endif %}
{% endmacro%}

{# Emit REGISTER_OPERATOR for an op. The OpMaker is listed only for forward
   apis (no "forward" key); the grad-op makers are instantiated for both the
   static-graph OpDesc and the imperative OpBase, or EmptyGradOpMaker is used
   when there is no backward. #}
{% macro register_op_with_components(api) %}
{% set name = api["name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in api %}{# it is a forward api #}
                  ops::{{name | to_pascal_case}}OpMaker,
{% endif %}
{% if "backward" in api and api["backward"] is not none %}{# backward #}
  {% set backward_name = api["backward"] %}
                  ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
                  ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
{% else %}
                  paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
                  paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
{% endif %}
{% if api is supports_inplace %}{# inplace#}
                  ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %}
{% if api is supports_no_need_buffer %}{# no_need_buffer #}
                  ops::{{name | to_pascal_case}}NoNeedBufferVarInferer,
{% endif %}
                  ops::{{name | to_pascal_case}}InferShapeFunctor);
{% endmacro %}

{# Emit REGISTER_OP_VERSION with one AddCheckpoint per entry of api["version"].
   The `)` closing each AddCheckpoint is emitted by the matching action of the
   last iteration of the inner loop. #}
{% macro register_op_version(api) %}
{% if "version" in api %}
{% set name = api["name"] %}
REGISTER_OP_VERSION({{name}})
  {% for checkpoint in api["version"]%}
  .AddCheckpoint(
    R"ROC({{checkpoint["checkpoint"]}})ROC",
      paddle::framework::compatible::OpVersionDesc()
    {% for action in checkpoint["action"]%}
      {% if "add_input" in action %}
        .NewInput("{{action["add_input"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "delete_input" in action %}
        .DeleteInput("{{action["delete_input"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "modify_input" in action %}
        .ModifyInput("{{action["modify_input"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "add_output" in action %}
        .NewOutput("{{action["add_output"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "delete_output" in action %}
        .DeleteOutput("{{action["delete_output"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "modify_output" in action %}
        .ModifyOutput("{{action["modify_output"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "add_attr" in action %}
        .NewAttr("{{action["add_attr"]}}", "{{action["comment"]}}", {{action["default"]}}){{")" if loop.last}}
      {% endif %}
      {% if "delete_attr" in action %}
        .DeleteAttr("{{action["delete_attr"]}}", "{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
      {% if "fix_bug" in action %}
        .BugfixWithBehaviorChanged("{{action["comment"]}}"){{")" if loop.last}}
      {% endif %}
    {% endfor %}
  {% endfor %};
{% endif %}
{% endmacro %}

{# --------------------------------------- backward op maker ---------------------------------------------- #}
{# Emit `template <typename T> class <Grad>OpMaker : SingleGradOpMaker<T>`.
   Apply() wires each grad-op input/output to the matching forward
   Input/Output/OutputGrad/InputGrad, using the forward op's original names
   (forward_api) since those are the names its OpMaker registered. Attributes
   shared with the forward op are copied from it; extra attributes are set to
   their default value. #}
{% macro backward_op_maker(api, forward_api) %}
  {% set name = api["name"] %}
  {% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %}
  {% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %}
  {% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %}
  {% set forward_input_orig_names = forward_api["inputs"] | map(attribute="name") | list %}
  {% set forward_output_orig_names = forward_api["outputs"] | map(attribute="name") | list %}
  {% set forward_attr_orig_names = forward_api["attrs"] | map(attribute="name") | list %}
template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("{{name}}");

  {% for input in api["inputs"] %}
    grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
      input["name"],
      forward_input_names,
      forward_output_names,
      forward_input_orig_names,
      forward_output_orig_names)}});
  {% endfor %}

  {% for output in api["outputs"] %}
    grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
      output["name"],
      forward_input_names,
      forward_output_names,
      forward_input_orig_names,
      forward_output_orig_names)}});
  {% endfor %}

  {% for attr in api["attrs"] %}
    {% set attr_name = attr["name"] %}
    {% if attr_name in forward_attr_names %}
      {% if attr["typename"] == "IntArray" %}
    grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
    grad_op->SetInput("{{attr_name | to_pascal_case}}TensorList", this->Input("{{attr_name | to_pascal_case}}TensorList"));
      {% elif attr["typename"] is scalar %}
    grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
      {% endif %}
    grad_op->SetAttr("{{attr_name}}", this->GetAttr("{{forward_attr_orig_names[forward_attr_names.index(attr_name)]}}"));
    {% else %}{# maybe something wrong: backward op has more attrs than the forward one#}
    {# A grad op has no AddAttr; only set the attribute to its default value. #}
    grad_op->SetAttr("{{attr_name}}", {{process_default_value(attr)}});
    {% endif %}
  {% endfor %}
  }
};
{% endmacro %}

{# Map a grad-op input name to the forward accessor expression:
   Input(<orig>) for a forward input, Output(<orig>) for a forward output, or
   OutputGrad(<orig>) for "<forward output>_grad". <orig> is the name in the
   forward op (forward_api). Emits nothing when no case matches. #}
{% macro extract_input_from_forward(name,
  input_names, output_names,
  input_orig_names, output_orig_names) %}{# inline #}
  {% if name in input_names %}
    {% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input("{{name_in_forward_orig | to_pascal_case}}")
  {%- elif name in output_names %}
    {% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name_in_forward_orig | to_pascal_case}}")
  {%- elif name.endswith("_grad") %}{# output grad#}
    {% set name_in_forward = name[:-5] %}
    {% if name_in_forward in output_names %}
      {% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig | to_pascal_case}}")
    {%- endif %}
  {%- endif %}
{%- endmacro %}

{# Map a grad-op output name to InputGrad(<orig>), where <orig> is the forward
   op's original name of the input the gradient belongs to: either the name
   with its 5-character suffix (e.g. "_grad") removed, or the name passed
   through the `to_input_name` filter. Emits nothing when neither matches. #}
{% macro extract_output_from_forward(name, input_names, output_names,
  input_orig_names, output_orig_names) %}{# inline #}
  {% if name[:-5] in input_names %}
    {% set name_in_forward = name[:-5] %}
    {% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name_in_forward_orig | to_pascal_case}}")
  {%- elif (name | to_input_name) in input_names %}
    {% set name_in_forward = name | to_input_name %}
    {% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name_in_forward_orig | to_pascal_case}}")
  {%- endif %}
{%- endmacro %}

{# Emit `this->GetAttr("<name>")`; attr_names and attr_origin_names are
   currently unused. #}
{% macro extract_attr_from_forward(name, attr_names, attr_origin_names) %}
this->GetAttr("{{name}}")
{%- endmacro %}