未验证 提交 f7ecca45 编写于 作者: Z zyfncg 提交者: GitHub

supoort set original op_name for api (#44317)

上级 e8d78a70
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
// Forward operator for diag_v2. It carries no InferShape override because
// shape inference is supplied externally via DiagInferShapeFunctor
// (declared with DECLARE_INFER_SHAPE_FUNCTOR and passed to REGISTER_OPERATOR).
class DiagV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
// Proto maker for diag_v2: declares the op's input X, output Out, the two
// attributes (offset, padding_value) and the user-facing documentation.
// NOTE: the Add* call order defines the proto field order — do not reorder.
class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor. Its shape is either 1-D or 2-D.");
AddOutput("Out", "The output tensor. A square matrix or a vector.");
// offset selects which diagonal is read/written (default: main diagonal).
AddAttr<int>("offset",
"The diagonal offset. A positive value represents "
"superdiagonal, 0 represents the main diagonal, and a "
"negative value represents subdiagonal.")
.SetDefault(0);
// padding_value only applies to the 1-D input case (vector -> matrix).
AddAttr<float>("padding_value",
"Use this value to fill the area outside the specified "
"diagonal band. Only takes effect when the input is a 1-D "
"Tensor. The default value is 0.")
.SetDefault(0.0f);
AddComment(R"DOC(
If ``x`` is a vector (1-D tensor), a 2-D square tensor with the elements of ``x`` as the diagonal is returned.
If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned.
The argument ``offset`` controls the diagonal offset:
If ``offset`` = 0, it is the main diagonal.
If ``offset`` > 0, it is superdiagonal.
If ``offset`` < 0, it is subdiagonal.
)DOC");
}
};
// Gradient operator of diag_v2: X@GRAD always takes the shape of the
// forward input X.
class DiagV2GradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies X and X@GRAD are present, then propagates X's dims to X@GRAD.
  void InferShape(framework::InferShapeContext *ctx) const override {
    // Fix: the role label passed to OP_INOUT_CHECK was "X"; it must be
    // "Input" (mirroring the "Output" label below) so the generated error
    // message reads correctly. Only the diagnostic text is affected.
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DiagV2Grad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
                   "Output",
                   framework::GradVarName("X"),
                   "DiagV2Grad");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  // The kernel dtype is taken from Out@GRAD rather than X — X's buffer may
  // have been freed (it is declared no-need-buffer below), only its dims
  // are guaranteed to be available here.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }
};
// Builds the diag_v2_grad op description from a forward diag_v2 op:
// grad takes the forward input X and Out@GRAD, produces X@GRAD, and
// reuses the forward attributes (offset, padding_value).
template <typename T>
class DiagV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("diag_v2_grad");
// X is needed only for its shape in DiagV2GradOp::InferShape; its data
// buffer is marked unneeded by the no-need-buffer inferer in this file.
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
// diag_v2_grad only reads X's dims (not its data), so X's buffer can be freed.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagGradV2NoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Forward shape inference is delegated to the phi meta function DiagInferMeta.
DECLARE_INFER_SHAPE_FUNCTOR(diag_v2,
DiagInferShapeFunctor,
PD_INFER_META(phi::DiagInferMeta));
// Register forward op with grad makers for both static graph (OpDesc) and
// dygraph (OpBase) modes.
REGISTER_OPERATOR(diag_v2,
ops::DiagV2Op,
ops::DiagV2OpMaker,
ops::DiagV2GradOpMaker<paddle::framework::OpDesc>,
ops::DiagV2GradOpMaker<paddle::imperative::OpBase>,
DiagInferShapeFunctor);
REGISTER_OPERATOR(diag_v2_grad,
ops::DiagV2GradOp,
ops::DiagGradV2NoNeedBufferVarsInferer);
...@@ -43,6 +43,15 @@ ...@@ -43,6 +43,15 @@
data_type : x data_type : x
backward : cross_grad backward : cross_grad
# Forward `diag` API. The generated operator keeps the legacy name diag_v2
# through the op_name mapping in the compat yaml.
- api : diag
args : (Tensor x, int offset = 0, float padding_value = 0.0)
output : Tensor
infer_meta :
func : DiagInferMeta
kernel :
func : diag
backward : diag_grad
- api : diagonal - api : diagonal
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1) args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor output : Tensor
......
...@@ -12,6 +12,14 @@ ...@@ -12,6 +12,14 @@
outputs : outputs :
out : Out out : Out
# Compat entry for `diag`: keep the original (legacy) op names diag_v2 /
# diag_v2_grad for the generated operators, and map the new argument names
# (x, out) onto the legacy proto names (X, Out).
- api : diag
op_name : diag_v2
grad_op_name : diag_v2_grad
inputs :
x : X
outputs :
out : Out
- api : diagonal - api : diagonal
inputs : inputs :
x : Input x : Input
......
...@@ -39,6 +39,18 @@ ...@@ -39,6 +39,18 @@
func : cross_grad func : cross_grad
data_type : out_grad data_type : out_grad
# Backward of `diag`: x_grad shares x's shape (UnchangedInferMeta on x);
# x's data buffer is not needed, and the kernel dtype follows out_grad.
- backward_api : diag_grad
forward : diag (Tensor x, int offset, float padding_value) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : diag_grad
data_type : out_grad
no_need_buffer : x
- backward_api : diagonal_grad - backward_api : diagonal_grad
forward : diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) forward : diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1) args : (Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1)
......
...@@ -54,34 +54,21 @@ def restruct_io(api): ...@@ -54,34 +54,21 @@ def restruct_io(api):
return api return api
def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path, # replace name of op and params for OpMaker
api_version_yaml_path, output_op_path, output_arg_map_path): def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
with open(api_yaml_path, "rt") as f: for api_args in api_op_map:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
with open(api_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
forward_api_dict[api_version['api']]['version'] = api_version['version']
with open(api_compat_yaml_path, "rt") as f:
api_args_map = yaml.safe_load(f)
# replace args name for OpMaker
for api_args in api_args_map:
if api_args['api'] not in forward_api_dict: if api_args['api'] not in forward_api_dict:
continue continue
forward_api_item = forward_api_dict[api_args['api']] forward_api_item = forward_api_dict[api_args['api']]
has_backward = True if forward_api_item['backward'] else False has_backward = True if forward_api_item['backward'] else False
if has_backward: if has_backward:
backward_api_item = backward_api_dict[forward_api_item['backward']] backward_api_item = backward_api_dict[forward_api_item['backward']]
if 'op_name' in api_args:
forward_api_item['op_name'] = api_args['op_name']
if 'grad_op_name' in api_args and has_backward:
forward_api_item['backward'] = api_args['grad_op_name']
backward_api_item['op_name'] = api_args['grad_op_name']
key_set = ['inputs', 'attrs', 'outputs'] key_set = ['inputs', 'attrs', 'outputs']
args_map = {} args_map = {}
for key in key_set: for key in key_set:
...@@ -175,6 +162,35 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path, ...@@ -175,6 +162,35 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
for param in backward_api_item['no_need_buffer'] for param in backward_api_item['no_need_buffer']
] ]
def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
api_version_yaml_path, output_op_path, output_arg_map_path):
with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
with open(api_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
forward_api_dict[api_version['api']]['version'] = api_version['version']
with open(api_compat_yaml_path, "rt") as f:
api_op_map = yaml.safe_load(f)
for api in apis:
api['op_name'] = api['name']
for bw_api in backward_apis:
bw_api['op_name'] = bw_api['name']
replace_compat_name(api_op_map, forward_api_dict, backward_api_dict)
# fill backward field for an api if another api claims it as forward # fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items(): for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"] forward_name = backward_api["forward"]["name"]
...@@ -183,11 +199,6 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path, ...@@ -183,11 +199,6 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
if forward_api["backward"] is None: if forward_api["backward"] is None:
forward_api["backward"] = name forward_api["backward"] = name
if forward_name in backward_api_dict:
forward_api = backward_api_dict[forward_name]
if forward_api["backward"] is None:
forward_api["backward"] = name
api_dict = {} api_dict = {}
api_dict.update(forward_api_dict) api_dict.update(forward_api_dict)
api_dict.update(backward_api_dict) api_dict.update(backward_api_dict)
......
{% from "operator_utils.c.j2" import name_map, register_name_map %} {% from "operator_utils.c.j2" import name_map, register_name_map, register_base_kernel_name %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. // this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
...@@ -18,6 +18,9 @@ namespace phi { ...@@ -18,6 +18,9 @@ namespace phi {
} // namespace phi } // namespace phi
{% for api in apis + backward_apis %} {% for api in apis + backward_apis %}
{% if api["name"] != api["op_name"] %}
{{register_base_kernel_name(api)}}
{% endif %}
{% if api is base_api %} {% if api is base_api %}
{{register_name_map(api)}} {{register_name_map(api)}}
{% endif %} {% endif %}
......
{# ----------------------------- op maker ----------------------------------- #} {# ----------------------------- op maker ----------------------------------- #}
{% macro op_maker(api) %} {% macro op_maker(api) %}
{% set api_name = api["name"] %} {% set api_name = api["op_name"] %}
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker { class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -124,9 +124,12 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg ...@@ -124,9 +124,12 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg
*/ */
{% endmacro %} {% endmacro %}
{% macro register_base_kernel_name(api) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}});
{%- endmacro %}
{% macro register_name_map(api) %} {% macro register_name_map(api) %}
PD_REGISTER_ARG_MAPPING_FN({{api["name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %} {%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #} {% macro get_input_list(inputs, kernel_args) %}{# inline #}
...@@ -196,7 +199,7 @@ framework::OpKernelType GetExpectedKernelType( ...@@ -196,7 +199,7 @@ framework::OpKernelType GetExpectedKernelType(
{# --------------------------------------- operator ---------------------------------------------- #} {# --------------------------------------- operator ---------------------------------------------- #}
{% macro operator(api) %} {% macro operator(api) %}
class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel { class {{api["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #} {# ----------- get expected kernel type function -------------------------- #}
...@@ -209,7 +212,7 @@ class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel ...@@ -209,7 +212,7 @@ class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel
{% endif %} {% endif %}
}; };
DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}InferShapeFunctor, DECLARE_INFER_SHAPE_FUNCTOR({{api["op_name"]}}, {{api["op_name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{api["infer_meta"]["func"]}})); PD_INFER_META(phi::{{api["infer_meta"]["func"]}}));
{# inplace inferer #} {# inplace inferer #}
{% if api["inplace"] is not none %} {% if api["inplace"] is not none %}
...@@ -218,19 +221,19 @@ DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}Inf ...@@ -218,19 +221,19 @@ DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}Inf
{{"{"}}{{source | to_opmaker_name}}, {{target | to_opmaker_name}}{{"}"}}{{", " if not loop.last}} {{"{"}}{{source | to_opmaker_name}}, {{target | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
{%- endfor %} {%- endfor %}
{%- endset %} {%- endset %}
DECLARE_INPLACE_OP_INFERER({{api["name"] | to_pascal_case}}InplaceInferer, DECLARE_INPLACE_OP_INFERER({{api["op_name"] | to_pascal_case}}InplaceInferer,
{{inplace_map}}); {{inplace_map}});
{% endif %} {% endif %}
{# no_need_buffer inferer #} {# no_need_buffer inferer #}
{% if api["no_need_buffer"] is not none %} {% if api["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["name"] | to_pascal_case}}NoNeedBufferVarInferer, DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["op_name"] | to_pascal_case}}NoNeedBufferVarInferer,
{{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}}); {{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% endif %} {% endif %}
{% endmacro%} {% endmacro%}
{% macro register_op_with_components(api) %} {% macro register_op_with_components(api) %}
{% set name = api["name"] %} {% set name = api["op_name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in api %}{# it is a forward api #} {% if not "forward" in api %}{# it is a forward api #}
ops::{{name | to_pascal_case}}OpMaker, ops::{{name | to_pascal_case}}OpMaker,
...@@ -254,7 +257,7 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, ...@@ -254,7 +257,7 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% macro register_op_version(api) %} {% macro register_op_version(api) %}
{% if "version" in api %} {% if "version" in api %}
{% set name = api["name"] %} {% set name = api["op_name"] %}
REGISTER_OP_VERSION({{name}}) REGISTER_OP_VERSION({{name}})
{% for checkpoint in api["version"]%} {% for checkpoint in api["version"]%}
.AddCheckpoint( .AddCheckpoint(
...@@ -296,7 +299,7 @@ REGISTER_OP_VERSION({{name}}) ...@@ -296,7 +299,7 @@ REGISTER_OP_VERSION({{name}})
{# --------------------------------------- backward op maker ---------------------------------------------- #} {# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(api, forward_api) %} {% macro backward_op_maker(api, forward_api) %}
{% set name = api["name"] %} {% set name = api["op_name"] %}
{% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %} {% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %} {% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %} {% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %}
......
...@@ -498,14 +498,6 @@ ...@@ -498,14 +498,6 @@
func : determinant func : determinant
backward : det_grad backward : det_grad
- api : diag
args : (Tensor x, int offset, float padding_value)
output : Tensor
infer_meta :
func : DiagInferMeta
kernel :
func : diag
- api : divide - api : divide
args : (Tensor x, Tensor y) args : (Tensor x, Tensor y)
output : Tensor output : Tensor
......
...@@ -18,6 +18,26 @@ ...@@ -18,6 +18,26 @@
namespace phi { namespace phi {
/**
* @brief If ``x`` is a vector (1-D tensor), a 2-D square tensor with the
* elements of ``x`` as the diagonal is returned.
* If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal
* elements of ``x`` is returned.
*
* The argument ``offset`` controls the diagonal offset:
* If ``offset`` = 0, it is the main diagonal.
* If ``offset`` > 0, it is superdiagonal. If ``offset`` < 0,
* it is subdiagonal.
* @param ctx device context
* @param x The input tensor. Its shape is either 1-D or 2-D.
* @param offset The diagonal offset. A positive value represents
* superdiagonal, 0 represents the main diagonal, and a
* negative value represents subdiagonal.
* @param padding_value Use this value to fill the area outside the specified
* diagonal band. Only takes effect when the input is a
* 1-D Tensor. The default value is 0.
* @param out The output tensor. A square matrix or a vector.
*/
template <typename T, typename Context> template <typename T, typename Context>
void DiagKernel(const Context& dev_ctx, void DiagKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the fluid op diag_v2 onto the phi "diag" kernel:
// input {X}, attributes {offset, padding_value}, output {Out}.
// The mapping is unconditional, so the context is not inspected.
KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
  KernelSignature signature(
      "diag", {"X"}, {"offset", "padding_value"}, {"Out"});
  return signature;
}
// Maps the fluid op diag_v2_grad onto the phi "diag_grad" kernel:
// inputs {X, Out@GRAD}, attribute {offset}, output {X@GRAD}.
// The mapping is unconditional, so the context is not inspected.
KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  KernelSignature signature(
      "diag_grad", {"X", "Out@GRAD"}, {"offset"}, {"X@GRAD"});
  return signature;
}
} // namespace phi
// Translate the legacy op names to the phi kernel base names so kernel
// lookup resolves diag_v2 -> diag and diag_v2_grad -> diag_grad.
PD_REGISTER_BASE_KERNEL_NAME(diag_v2, diag);
PD_REGISTER_BASE_KERNEL_NAME(diag_v2_grad, diag_grad);
// Register the argument mapping functions under the legacy op names.
PD_REGISTER_ARG_MAPPING_FN(diag_v2, phi::DiagOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(diag_v2_grad, phi::DiagGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册