Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
36eb5cde
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36eb5cde
编写于
7月 14, 2023
作者:
R
RedContritio
提交者:
GitHub
7月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support auto generate for static op elementwise_min (#55008)
上级
36b2c5e5
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
67 addition
and
209 deletion
+67
-209
paddle/fluid/operators/elementwise/elementwise_min_op.cc
paddle/fluid/operators/elementwise/elementwise_min_op.cc
+0
-168
paddle/fluid/operators/elementwise/unity_build_rule.cmake
paddle/fluid/operators/elementwise/unity_build_rule.cmake
+2
-3
paddle/phi/api/yaml/backward.yaml
paddle/phi/api/yaml/backward.yaml
+11
-0
paddle/phi/api/yaml/legacy_backward.yaml
paddle/phi/api/yaml/legacy_backward.yaml
+0
-10
paddle/phi/api/yaml/legacy_ops.yaml
paddle/phi/api/yaml/legacy_ops.yaml
+0
-10
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+12
-0
paddle/phi/api/yaml/op_version.yaml
paddle/phi/api/yaml/op_version.yaml
+8
-0
paddle/phi/api/yaml/ops.yaml
paddle/phi/api/yaml/ops.yaml
+10
-0
paddle/phi/api/yaml/static_backward.yaml
paddle/phi/api/yaml/static_backward.yaml
+12
-0
paddle/phi/api/yaml/static_ops.yaml
paddle/phi/api/yaml/static_ops.yaml
+9
-0
paddle/phi/ops/compat/elementwise_sig.cc
paddle/phi/ops/compat/elementwise_sig.cc
+3
-18
未找到文件。
paddle/fluid/operators/elementwise/elementwise_min_op.cc
已删除
100644 → 0
浏览文件 @
36b2c5e5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace
paddle
{
namespace
framework
{
class
OpDesc
;
}
// namespace framework
namespace
imperative
{
class
OpBase
;
}
// namespace imperative
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
ElementwiseMinOpMaker
:
public
ElementwiseOpMaker
{
protected:
std
::
string
GetName
()
const
override
{
return
"Min"
;
}
std
::
string
GetEquation
()
const
override
{
return
"Out = min(X, Y)"
;
}
void
AddInputX
()
override
{
AddInput
(
"X"
,
"The first tensor holding the elements to be compared."
);
}
void
AddInputY
()
override
{
AddInput
(
"Y"
,
"The second tensor holding the elements to be compared."
);
}
std
::
string
GetOpFunctionality
()
const
override
{
return
"Compare two tensors and returns a new tensor containing the "
"element-wise minima."
;
}
};
class
ElementwiseFMinOpMaker
:
public
ElementwiseOpMaker
{
protected:
std
::
string
GetName
()
const
override
{
return
"FMin"
;
}
std
::
string
GetEquation
()
const
override
{
return
"Out = fmin(X, Y)"
;
}
void
AddInputX
()
override
{
AddInput
(
"X"
,
"The first tensor holding the elements to be compared."
);
}
void
AddInputY
()
override
{
AddInput
(
"Y"
,
"The second tensor holding the elements to be compared."
);
}
std
::
string
GetOpFunctionality
()
const
override
{
return
"Compare two tensors and returns a new tensor containing the "
"element-wise minima. If the element of one tensor is nan, "
"return the element value of the other tensor, if both are nan, "
"return the first nan"
;
}
};
class
ElementwiseMinCompositeGradOpMaker
:
public
prim
::
CompositeGradOpMakerBase
{
using
prim
::
CompositeGradOpMakerBase
::
CompositeGradOpMakerBase
;
public:
void
Apply
()
override
{
paddle
::
Tensor
x
=
this
->
GetSingleForwardInput
(
"X"
);
paddle
::
Tensor
y
=
this
->
GetSingleForwardInput
(
"Y"
);
paddle
::
Tensor
out_grad
=
this
->
GetSingleOutputGrad
(
"Out"
);
paddle
::
Tensor
dx
=
this
->
GetSingleInputGrad
(
"X"
);
auto
*
dx_ptr
=
this
->
GetOutputPtr
(
&
dx
);
std
::
string
dx_name
=
this
->
GetOutputName
(
dx
);
paddle
::
Tensor
dy
=
this
->
GetSingleInputGrad
(
"Y"
);
auto
*
dy_ptr
=
this
->
GetOutputPtr
(
&
dy
);
std
::
string
dy_name
=
this
->
GetOutputName
(
dy
);
VLOG
(
6
)
<<
"Runing minimum_grad composite func"
;
int
axis
=
static_cast
<
int
>
(
this
->
Attr
<
int
>
(
"axis"
));
PADDLE_ENFORCE_EQ
(
axis
,
-
1
,
phi
::
errors
::
InvalidArgument
(
"We only support axis = -1 in composite minimum_grad but we got: "
,
axis
));
prim
::
minimum_grad
<
prim
::
DescTensor
>
(
x
,
y
,
out_grad
,
dx_ptr
,
dy_ptr
);
this
->
RecoverOutputName
(
dx
,
dx_name
);
this
->
RecoverOutputName
(
dy
,
dy_name
);
}
};
template
<
typename
T
>
class
ElementwiseMinGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"elementwise_min_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
};
template
<
typename
T
>
class
ElementwiseFMinGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"elementwise_fmin_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
op
->
SetAttrMap
(
this
->
Attrs
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
elementwise_min
,
ops
::
ElementwiseOp
,
ops
::
ElementwiseMinOpMaker
,
ops
::
ElementwiseOpInferVarType
,
ops
::
ElementwiseMinGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
ElementwiseMinGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
ElementwiseMinCompositeGradOpMaker
);
REGISTER_OPERATOR
(
elementwise_min_grad
,
ops
::
ElementwiseOpGrad
);
REGISTER_OP_VERSION
(
elementwise_min
)
.
AddCheckpoint
(
R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
().
NewAttr
(
"Scale_y"
,
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_min."
,
1.0
f
));
REGISTER_OPERATOR
(
elementwise_fmin
,
ops
::
ElementwiseOp
,
ops
::
ElementwiseFMinOpMaker
,
ops
::
ElementwiseOpInferVarType
,
ops
::
ElementwiseFMinGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
ElementwiseFMinGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
elementwise_fmin_grad
,
ops
::
ElementwiseOpGrad
);
paddle/fluid/operators/elementwise/unity_build_rule.cmake
浏览文件 @
36eb5cde
...
...
@@ -4,9 +4,8 @@
# Generally, the combination rules in this file do not need to be modified.
# If there are some redefined error in compiling with the source file which
# in combination rule, you can remove the source file from the following rules.
register_unity_group
(
cc elementwise_add_op.cc elementwise_div_op.cc elementwise_min_op.cc
elementwise_mul_op.cc elementwise_sub_op.cc
)
register_unity_group
(
cc elementwise_add_op.cc elementwise_div_op.cc
elementwise_mul_op.cc elementwise_sub_op.cc
)
register_unity_group
(
cu
elementwise_add_op.cu
...
...
paddle/phi/api/yaml/backward.yaml
浏览文件 @
36eb5cde
...
...
@@ -870,6 +870,17 @@
func
:
fmax_grad
data_type
:
out_grad
-
backward_op
:
fmin_grad
forward
:
fmin(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
infer_meta
:
func
:
GeneralBinaryGradInferMeta
param
:
[
x
,
y
]
kernel
:
func
:
fmin_grad
data_type
:
out_grad
-
backward_op
:
fold_grad
forward
:
fold (Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
...
...
paddle/phi/api/yaml/legacy_backward.yaml
浏览文件 @
36eb5cde
...
...
@@ -232,16 +232,6 @@
func
:
UnchangedInferMeta
invoke
:
zeros_like(out_grad)
-
backward_op
:
fmin_grad
forward
:
fmin(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
infer_meta
:
func
:
GeneralBinaryGradInferMeta
param
:
[
x
,
y
]
kernel
:
func
:
fmin_grad
-
backward_op
:
frobenius_norm_grad
forward
:
frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out)
args
:
(Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all)
...
...
paddle/phi/api/yaml/legacy_ops.yaml
浏览文件 @
36eb5cde
...
...
@@ -326,16 +326,6 @@
kernel
:
func
:
floor_divide
-
op
:
fmin
args
:
(Tensor x, Tensor y)
output
:
Tensor(out)
infer_meta
:
param
:
[
x
,
y
]
func
:
ElementwiseInferMeta
kernel
:
func
:
fmin
backward
:
fmin_grad
-
op
:
frobenius_norm
args
:
(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
output
:
Tensor(out)
...
...
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
36eb5cde
...
...
@@ -1107,9 +1107,15 @@
-
op
:
fmin (elementwise_fmin)
backward
:
fmin_grad (elementwise_fmin_grad)
inputs
:
{
x
:
X
,
y
:
Y
}
outputs
:
{
out
:
Out
}
extra
:
attrs
:
[
bool use_mkldnn = false
,
str mkldnn_data_type = "float32"
,
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
complex_promote
:
[
X
,
Y
]
manual_signature
:
[
fmin
]
-
op
:
fold
inputs
:
...
...
@@ -1839,9 +1845,15 @@
-
op
:
minimum (elementwise_min)
backward
:
minimum_grad (elementwise_min_grad)
inputs
:
{
x
:
X
,
y
:
Y
}
outputs
:
{
out
:
Out
}
extra
:
attrs
:
[
bool use_mkldnn = false
,
str x_data_format = ""
,
str y_data_format = ""
,
str mkldnn_data_type = "float32"
,
bool use_quantizer = false
,
float Scale_x = 1.0f
,
float Scale_y = 1.0f
,
float Scale_out = 1.0f
]
complex_promote
:
[
X
,
Y
]
manual_signature
:
[
minimum
]
-
op
:
mish
backward
:
mish_grad
...
...
paddle/phi/api/yaml/op_version.yaml
浏览文件 @
36eb5cde
...
...
@@ -181,6 +181,14 @@
comment
:
In order to support the function of scaling the input Y when using the operator of elementwise_max.
default
:
1.0
-
op
:
elementwise_min
version
:
-
checkpoint
:
Register elementwise_min for adding the attribute of Scale_y.
action
:
-
add_attr
:
Scale_y
comment
:
In order to support the function of scaling the input Y when using the operator of elementwise_min.
default
:
1.0
-
op
:
elementwise_mod
version
:
-
checkpoint
:
Register elementwise_mod for adding the attribute of Scale_y
...
...
paddle/phi/api/yaml/ops.yaml
浏览文件 @
36eb5cde
...
...
@@ -941,6 +941,16 @@
func
:
fmax
backward
:
fmax_grad
-
op
:
fmin
args
:
(Tensor x, Tensor y)
output
:
Tensor(out)
infer_meta
:
func
:
ElementwiseInferMeta
param
:
[
x
,
y
]
kernel
:
func
:
fmin
backward
:
fmin_grad
-
op
:
fold
args
:
(Tensor x, int[] output_sizes, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)
output
:
Tensor(out)
...
...
paddle/phi/api/yaml/static_backward.yaml
浏览文件 @
36eb5cde
...
...
@@ -211,6 +211,18 @@
kernel
:
func
:
min_grad
-
backward_op
:
minimum_grad
forward
:
minimum(Tensor x, Tensor y, int axis = -1) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
infer_meta
:
func
:
GeneralBinaryGradInferMeta
param
:
[
x
,
y
]
kernel
:
func
:
minimum_grad
data_type
:
out_grad
composite
:
minimum_grad(x, y, out_grad, x_grad, y_grad)
-
backward_op
:
norm_grad
forward
:
norm (Tensor x, int axis, float epsilon=1.0e-10f, bool is_test=false) -> Tensor(out), Tensor(norm)
args
:
(Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test)
...
...
paddle/phi/api/yaml/static_ops.yaml
浏览文件 @
36eb5cde
...
...
@@ -407,6 +407,15 @@
param
:
[
x
,
axis
,
keepdim
,
reduce_all
]
backward
:
min_grad
-
op
:
minimum
args
:
(Tensor x, Tensor y, int axis = -1)
output
:
Tensor(out)
infer_meta
:
func
:
ElementwiseRawInferMeta
kernel
:
func
:
minimum
backward
:
minimum_grad
-
op
:
norm
args
:
(Tensor x, int axis, float epsilon=1.0e-10f, bool is_test=false)
output
:
Tensor(out), Tensor(norm)
...
...
paddle/phi/ops/compat/elementwise_sig.cc
浏览文件 @
36eb5cde
...
...
@@ -78,6 +78,9 @@ KernelSignature ElementwiseMaxOpArgumentMapping(
KernelSignature
ElementwiseMinOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
IsForInferShape
())
{
return
KernelSignature
(
"minimum_raw"
,
{
"X"
,
"Y"
},
{
"axis"
},
{
"Out"
});
}
int
axis
=
paddle
::
any_cast
<
int
>
(
ctx
.
Attr
(
"axis"
));
if
(
axis
==
-
1
)
{
return
KernelSignature
(
"minimum"
,
{
"X"
,
"Y"
},
{},
{
"Out"
});
...
...
@@ -162,12 +165,6 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
{
"X@GRAD"
,
"Y@GRAD"
});
}
KernelSignature
ElementwiseFMinGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"fmin_grad"
,
{
"X"
,
"Y"
,
"Out@GRAD"
},
{},
{
"X@GRAD"
,
"Y@GRAD"
});
}
KernelSignature
ElementwiseDivDoubleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"divide_double_grad"
,
...
...
@@ -209,12 +206,6 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
{
"D_X"
,
"D_Y"
,
"D_DOut"
,
"D_DDX"
,
"D_DDY"
});
}
KernelSignature
ElementwiseMinGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"minimum_grad"
,
{
"X"
,
"Y"
,
"Out@GRAD"
},
{},
{
"X@GRAD"
,
"Y@GRAD"
});
}
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_add
,
add
);
...
...
@@ -237,8 +228,6 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_mul_triple_grad
,
multiply_triple_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_fmax
,
fmax
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_fmin
,
fmin
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_fmin_grad
,
fmin_grad
);
PD_REGISTER_BASE_KERNEL_NAME
(
elementwise_min_grad
,
minimum_grad
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_add
,
phi
::
ElementwiseAddOpArgumentMapping
);
...
...
@@ -282,8 +271,4 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
phi
::
ElementwiseFMaxOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_fmin
,
phi
::
ElementwiseFMinOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_fmin_grad
,
phi
::
ElementwiseFMinGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
elementwise_min_grad
,
phi
::
ElementwiseMinGradOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
grad_add
,
phi
::
ElementwiseGradAddOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录