Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f7ecca45
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f7ecca45
编写于
7月 14, 2022
作者:
Z
zyfncg
提交者:
GitHub
7月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
supoort set original op_name for api (#44317)
上级
e8d78a70
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
103 addition
and
196 deletion
+103
-196
paddle/fluid/operators/diag_v2_op.cc
paddle/fluid/operators/diag_v2_op.cc
+0
-117
paddle/phi/api/yaml/api.yaml
paddle/phi/api/yaml/api.yaml
+9
-0
paddle/phi/api/yaml/api_compat.yaml
paddle/phi/api/yaml/api_compat.yaml
+8
-0
paddle/phi/api/yaml/backward.yaml
paddle/phi/api/yaml/backward.yaml
+12
-0
paddle/phi/api/yaml/generator/generate_op.py
paddle/phi/api/yaml/generator/generate_op.py
+38
-27
paddle/phi/api/yaml/generator/templates/ks.c.j2
paddle/phi/api/yaml/generator/templates/ks.c.j2
+4
-1
paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
+12
-9
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+0
-8
paddle/phi/kernels/diag_kernel.h
paddle/phi/kernels/diag_kernel.h
+20
-0
paddle/phi/ops/compat/diag_sig.cc
paddle/phi/ops/compat/diag_sig.cc
+0
-34
未找到文件。
paddle/fluid/operators/diag_v2_op.cc
已删除
100644 → 0
浏览文件 @
e8d78a70
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
class
DiagV2Op
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
};
class
DiagV2OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input tensor. Its shape is either 1-D or 2-D."
);
AddOutput
(
"Out"
,
"The output tensor. A square matrix or a vector."
);
AddAttr
<
int
>
(
"offset"
,
"The diagonal offset. A positive value represents "
"superdiagonal, 0 represents the main diagonal, and a "
"negative value represents subdiagonal."
)
.
SetDefault
(
0
);
AddAttr
<
float
>
(
"padding_value"
,
"Use this value to fill the area outside the specified "
"diagonal band. Only takes effect when the input is a 1-D "
"Tensor. The default value is 0."
)
.
SetDefault
(
0.0
f
);
AddComment
(
R"DOC(
If ``x`` is a vector (1-D tensor), a 2-D square tensor with the elements of ``x`` as the diagonal is returned.
If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal elements of ``x`` is returned.
The argument ``offset`` controls the diagonal offset:
If ``offset`` = 0, it is the main diagonal.
If ``offset`` > 0, it is superdiagonal.
If ``offset`` < 0, it is subdiagonal.
)DOC"
);
}
};
class
DiagV2GradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"X"
,
"X"
,
"DiagV2Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"DiagV2Grad"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
)),
ctx
.
GetPlace
());
}
};
template
<
typename
T
>
class
DiagV2GradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
grad_op
)
const
override
{
grad_op
->
SetType
(
"diag_v2_grad"
);
grad_op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
this
->
Attrs
());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER
(
DiagGradV2NoNeedBufferVarsInferer
,
"X"
);
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
diag_v2
,
DiagInferShapeFunctor
,
PD_INFER_META
(
phi
::
DiagInferMeta
));
REGISTER_OPERATOR
(
diag_v2
,
ops
::
DiagV2Op
,
ops
::
DiagV2OpMaker
,
ops
::
DiagV2GradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
DiagV2GradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
DiagInferShapeFunctor
);
REGISTER_OPERATOR
(
diag_v2_grad
,
ops
::
DiagV2GradOp
,
ops
::
DiagGradV2NoNeedBufferVarsInferer
);
paddle/phi/api/yaml/api.yaml
浏览文件 @
f7ecca45
...
@@ -43,6 +43,15 @@
...
@@ -43,6 +43,15 @@
data_type
:
x
data_type
:
x
backward
:
cross_grad
backward
:
cross_grad
-
api
:
diag
args
:
(Tensor x, int offset = 0, float padding_value = 0.0)
output
:
Tensor
infer_meta
:
func
:
DiagInferMeta
kernel
:
func
:
diag
backward
:
diag_grad
-
api
:
diagonal
-
api
:
diagonal
args
:
(Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
args
:
(Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output
:
Tensor
output
:
Tensor
...
...
paddle/phi/api/yaml/api_compat.yaml
浏览文件 @
f7ecca45
...
@@ -12,6 +12,14 @@
...
@@ -12,6 +12,14 @@
outputs
:
outputs
:
out
:
Out
out
:
Out
-
api
:
diag
op_name
:
diag_v2
grad_op_name
:
diag_v2_grad
inputs
:
x
:
X
outputs
:
out
:
Out
-
api
:
diagonal
-
api
:
diagonal
inputs
:
inputs
:
x
:
Input
x
:
Input
...
...
paddle/phi/api/yaml/backward.yaml
浏览文件 @
f7ecca45
...
@@ -39,6 +39,18 @@
...
@@ -39,6 +39,18 @@
func
:
cross_grad
func
:
cross_grad
data_type
:
out_grad
data_type
:
out_grad
-
backward_api
:
diag_grad
forward
:
diag (Tensor x, int offset, float padding_value) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad, int offset)
output
:
Tensor(x_grad)
infer_meta
:
func
:
UnchangedInferMeta
param
:
[
x
]
kernel
:
func
:
diag_grad
data_type
:
out_grad
no_need_buffer
:
x
-
backward_api
:
diagonal_grad
-
backward_api
:
diagonal_grad
forward
:
diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
forward
:
diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1)
args
:
(Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1)
...
...
paddle/phi/api/yaml/generator/generate_op.py
浏览文件 @
f7ecca45
...
@@ -54,34 +54,21 @@ def restruct_io(api):
...
@@ -54,34 +54,21 @@ def restruct_io(api):
return
api
return
api
def
main
(
api_yaml_path
,
backward_yaml_path
,
api_compat_yaml_path
,
# replace name of op and params for OpMaker
api_version_yaml_path
,
output_op_path
,
output_arg_map_path
):
def
replace_compat_name
(
api_op_map
,
forward_api_dict
,
backward_api_dict
):
with
open
(
api_yaml_path
,
"rt"
)
as
f
:
for
api_args
in
api_op_map
:
apis
=
yaml
.
safe_load
(
f
)
apis
=
[
restruct_io
(
api
)
for
api
in
apis
]
forward_api_dict
=
to_named_dict
(
apis
)
with
open
(
backward_yaml_path
,
"rt"
)
as
f
:
backward_apis
=
yaml
.
safe_load
(
f
)
backward_apis
=
[
restruct_io
(
api
)
for
api
in
backward_apis
]
backward_api_dict
=
to_named_dict
(
backward_apis
)
with
open
(
api_version_yaml_path
,
"rt"
)
as
f
:
api_versions
=
yaml
.
safe_load
(
f
)
# add api version info into api
for
api_version
in
api_versions
:
forward_api_dict
[
api_version
[
'api'
]][
'version'
]
=
api_version
[
'version'
]
with
open
(
api_compat_yaml_path
,
"rt"
)
as
f
:
api_args_map
=
yaml
.
safe_load
(
f
)
# replace args name for OpMaker
for
api_args
in
api_args_map
:
if
api_args
[
'api'
]
not
in
forward_api_dict
:
if
api_args
[
'api'
]
not
in
forward_api_dict
:
continue
continue
forward_api_item
=
forward_api_dict
[
api_args
[
'api'
]]
forward_api_item
=
forward_api_dict
[
api_args
[
'api'
]]
has_backward
=
True
if
forward_api_item
[
'backward'
]
else
False
has_backward
=
True
if
forward_api_item
[
'backward'
]
else
False
if
has_backward
:
if
has_backward
:
backward_api_item
=
backward_api_dict
[
forward_api_item
[
'backward'
]]
backward_api_item
=
backward_api_dict
[
forward_api_item
[
'backward'
]]
if
'op_name'
in
api_args
:
forward_api_item
[
'op_name'
]
=
api_args
[
'op_name'
]
if
'grad_op_name'
in
api_args
and
has_backward
:
forward_api_item
[
'backward'
]
=
api_args
[
'grad_op_name'
]
backward_api_item
[
'op_name'
]
=
api_args
[
'grad_op_name'
]
key_set
=
[
'inputs'
,
'attrs'
,
'outputs'
]
key_set
=
[
'inputs'
,
'attrs'
,
'outputs'
]
args_map
=
{}
args_map
=
{}
for
key
in
key_set
:
for
key
in
key_set
:
...
@@ -175,6 +162,35 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
...
@@ -175,6 +162,35 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
for
param
in
backward_api_item
[
'no_need_buffer'
]
for
param
in
backward_api_item
[
'no_need_buffer'
]
]
]
def
main
(
api_yaml_path
,
backward_yaml_path
,
api_compat_yaml_path
,
api_version_yaml_path
,
output_op_path
,
output_arg_map_path
):
with
open
(
api_yaml_path
,
"rt"
)
as
f
:
apis
=
yaml
.
safe_load
(
f
)
apis
=
[
restruct_io
(
api
)
for
api
in
apis
]
forward_api_dict
=
to_named_dict
(
apis
)
with
open
(
backward_yaml_path
,
"rt"
)
as
f
:
backward_apis
=
yaml
.
safe_load
(
f
)
backward_apis
=
[
restruct_io
(
api
)
for
api
in
backward_apis
]
backward_api_dict
=
to_named_dict
(
backward_apis
)
with
open
(
api_version_yaml_path
,
"rt"
)
as
f
:
api_versions
=
yaml
.
safe_load
(
f
)
# add api version info into api
for
api_version
in
api_versions
:
forward_api_dict
[
api_version
[
'api'
]][
'version'
]
=
api_version
[
'version'
]
with
open
(
api_compat_yaml_path
,
"rt"
)
as
f
:
api_op_map
=
yaml
.
safe_load
(
f
)
for
api
in
apis
:
api
[
'op_name'
]
=
api
[
'name'
]
for
bw_api
in
backward_apis
:
bw_api
[
'op_name'
]
=
bw_api
[
'name'
]
replace_compat_name
(
api_op_map
,
forward_api_dict
,
backward_api_dict
)
# fill backward field for an api if another api claims it as forward
# fill backward field for an api if another api claims it as forward
for
name
,
backward_api
in
backward_api_dict
.
items
():
for
name
,
backward_api
in
backward_api_dict
.
items
():
forward_name
=
backward_api
[
"forward"
][
"name"
]
forward_name
=
backward_api
[
"forward"
][
"name"
]
...
@@ -183,11 +199,6 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
...
@@ -183,11 +199,6 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path,
if
forward_api
[
"backward"
]
is
None
:
if
forward_api
[
"backward"
]
is
None
:
forward_api
[
"backward"
]
=
name
forward_api
[
"backward"
]
=
name
if
forward_name
in
backward_api_dict
:
forward_api
=
backward_api_dict
[
forward_name
]
if
forward_api
[
"backward"
]
is
None
:
forward_api
[
"backward"
]
=
name
api_dict
=
{}
api_dict
=
{}
api_dict
.
update
(
forward_api_dict
)
api_dict
.
update
(
forward_api_dict
)
api_dict
.
update
(
backward_api_dict
)
api_dict
.
update
(
backward_api_dict
)
...
...
paddle/phi/api/yaml/generator/templates/ks.c.j2
浏览文件 @
f7ecca45
{% from "operator_utils.c.j2" import name_map, register_name_map %}
{% from "operator_utils.c.j2" import name_map, register_name_map
, register_base_kernel_name
%}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
#include "paddle/utils/small_vector.h"
...
@@ -18,6 +18,9 @@ namespace phi {
...
@@ -18,6 +18,9 @@ namespace phi {
} // namespace phi
} // namespace phi
{% for api in apis + backward_apis %}
{% for api in apis + backward_apis %}
{% if api["name"] != api["op_name"] %}
{{register_base_kernel_name(api)}}
{% endif %}
{% if api is base_api %}
{% if api is base_api %}
{{register_name_map(api)}}
{{register_name_map(api)}}
{% endif %}
{% endif %}
...
...
paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
浏览文件 @
f7ecca45
{# ----------------------------- op maker ----------------------------------- #}
{# ----------------------------- op maker ----------------------------------- #}
{% macro op_maker(api) %}
{% macro op_maker(api) %}
{% set api_name = api["name"] %}
{% set api_name = api["
op_
name"] %}
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void Make() override {
void Make() override {
...
@@ -124,9 +124,12 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg
...
@@ -124,9 +124,12 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg
*/
*/
{% endmacro %}
{% endmacro %}
{% macro register_base_kernel_name(api) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}});
{%- endmacro %}
{% macro register_name_map(api) %}
{% macro register_name_map(api) %}
PD_REGISTER_ARG_MAPPING_FN({{api["name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN({{api["
op_
name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %}
{%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #}
{% macro get_input_list(inputs, kernel_args) %}{# inline #}
...
@@ -196,7 +199,7 @@ framework::OpKernelType GetExpectedKernelType(
...
@@ -196,7 +199,7 @@ framework::OpKernelType GetExpectedKernelType(
{# --------------------------------------- operator ---------------------------------------------- #}
{# --------------------------------------- operator ---------------------------------------------- #}
{% macro operator(api) %}
{% macro operator(api) %}
class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
class {{api["
op_
name"] | to_pascal_case}}Op : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #}
{# ----------- get expected kernel type function -------------------------- #}
...
@@ -209,7 +212,7 @@ class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel
...
@@ -209,7 +212,7 @@ class {{api["name"] | to_pascal_case}}Op : public framework::OperatorWithKernel
{% endif %}
{% endif %}
};
};
DECLARE_INFER_SHAPE_FUNCTOR({{api["
name"]}}, {{api["
name"] | to_pascal_case}}InferShapeFunctor,
DECLARE_INFER_SHAPE_FUNCTOR({{api["
op_name"]}}, {{api["op_
name"] | to_pascal_case}}InferShapeFunctor,
PD_INFER_META(phi::{{api["infer_meta"]["func"]}}));
PD_INFER_META(phi::{{api["infer_meta"]["func"]}}));
{# inplace inferer #}
{# inplace inferer #}
{% if api["inplace"] is not none %}
{% if api["inplace"] is not none %}
...
@@ -218,19 +221,19 @@ DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}Inf
...
@@ -218,19 +221,19 @@ DECLARE_INFER_SHAPE_FUNCTOR({{api["name"]}}, {{api["name"] | to_pascal_case}}Inf
{{"{"}}{{source | to_opmaker_name}}, {{target | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
{{"{"}}{{source | to_opmaker_name}}, {{target | to_opmaker_name}}{{"}"}}{{", " if not loop.last}}
{%- endfor %}
{%- endfor %}
{%- endset %}
{%- endset %}
DECLARE_INPLACE_OP_INFERER({{api["name"] | to_pascal_case}}InplaceInferer,
DECLARE_INPLACE_OP_INFERER({{api["
op_
name"] | to_pascal_case}}InplaceInferer,
{{inplace_map}});
{{inplace_map}});
{% endif %}
{% endif %}
{# no_need_buffer inferer #}
{# no_need_buffer inferer #}
{% if api["no_need_buffer"] is not none %}
{% if api["no_need_buffer"] is not none %}
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["name"] | to_pascal_case}}NoNeedBufferVarInferer,
DECLARE_NO_NEED_BUFFER_VARS_INFERER({{api["
op_
name"] | to_pascal_case}}NoNeedBufferVarInferer,
{{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{{api["no_need_buffer"] | map("to_opmaker_name") | join(", ")}});
{% endif %}
{% endif %}
{% endmacro%}
{% endmacro%}
{% macro register_op_with_components(api) %}
{% macro register_op_with_components(api) %}
{% set name = api["name"] %}
{% set name = api["
op_
name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in api %}{# it is a forward api #}
{% if not "forward" in api %}{# it is a forward api #}
ops::{{name | to_pascal_case}}OpMaker,
ops::{{name | to_pascal_case}}OpMaker,
...
@@ -254,7 +257,7 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
...
@@ -254,7 +257,7 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% macro register_op_version(api) %}
{% macro register_op_version(api) %}
{% if "version" in api %}
{% if "version" in api %}
{% set name = api["name"] %}
{% set name = api["
op_
name"] %}
REGISTER_OP_VERSION({{name}})
REGISTER_OP_VERSION({{name}})
{% for checkpoint in api["version"]%}
{% for checkpoint in api["version"]%}
.AddCheckpoint(
.AddCheckpoint(
...
@@ -296,7 +299,7 @@ REGISTER_OP_VERSION({{name}})
...
@@ -296,7 +299,7 @@ REGISTER_OP_VERSION({{name}})
{# --------------------------------------- backward op maker ---------------------------------------------- #}
{# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(api, forward_api) %}
{% macro backward_op_maker(api, forward_api) %}
{% set name = api["name"] %}
{% set name = api["
op_
name"] %}
{% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_input_names = api["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_output_names = api["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %}
{% set forward_attr_names = api["forward"]["attrs"] | map(attribute="name") | list %}
...
...
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
f7ecca45
...
@@ -498,14 +498,6 @@
...
@@ -498,14 +498,6 @@
func
:
determinant
func
:
determinant
backward
:
det_grad
backward
:
det_grad
-
api
:
diag
args
:
(Tensor x, int offset, float padding_value)
output
:
Tensor
infer_meta
:
func
:
DiagInferMeta
kernel
:
func
:
diag
-
api
:
divide
-
api
:
divide
args
:
(Tensor x, Tensor y)
args
:
(Tensor x, Tensor y)
output
:
Tensor
output
:
Tensor
...
...
paddle/phi/kernels/diag_kernel.h
浏览文件 @
f7ecca45
...
@@ -18,6 +18,26 @@
...
@@ -18,6 +18,26 @@
namespace
phi
{
namespace
phi
{
/**
* @brief If ``x`` is a vector (1-D tensor), a 2-D square tensor with the
* elements of ``x`` as the diagonal is returned.
* If ``x`` is a matrix (2-D tensor), a 1-D tensor with the diagonal
* elements of ``x`` is returned.
*
* The argument ``offset`` controls the diagonal offset:
* If ``offset`` = 0, it is the main diagonal.
* If ``offset`` > 0, it is superdiagonal. If ``offset`` < 0,
* it is subdiagonal.
* @param ctx device context
* @param x The input tensor. Its shape is either 1-D or 2-D.
* @param offset The diagonal offset. A positive value represents
* superdiagonal, 0 represents the main diagonal, and a
* negative value represents subdiagonal.
* @param padding_value Use this value to fill the area outside the specified
* diagonal band. Only takes effect when the input is a
* 1-D Tensor. The default value is 0.
* @param out The output tensor. A square matrix or a vector.
*/
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
DiagKernel
(
const
Context
&
dev_ctx
,
void
DiagKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
...
...
paddle/phi/ops/compat/diag_sig.cc
已删除
100644 → 0
浏览文件 @
e8d78a70
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
DiagOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"diag"
,
{
"X"
},
{
"offset"
,
"padding_value"
},
{
"Out"
});
}
KernelSignature
DiagGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"diag_grad"
,
{
"X"
,
"Out@GRAD"
},
{
"offset"
},
{
"X@GRAD"
});
}
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
diag_v2
,
diag
);
PD_REGISTER_BASE_KERNEL_NAME
(
diag_v2_grad
,
diag_grad
);
PD_REGISTER_ARG_MAPPING_FN
(
diag_v2
,
phi
::
DiagOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
diag_v2_grad
,
phi
::
DiagGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录