Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
93ff8e4c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
93ff8e4c
编写于
4月 19, 2023
作者:
W
Wang Xin
提交者:
GitHub
4月 19, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add autogen code support for mean_all op (#52855)
* add autogen code support for mean_all op * bug fixed * bug fixed * bug fixed
上级
e5506be6
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
37 addition
and
156 deletion
+37
-156
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/mean_op.cc
+0
-103
paddle/phi/api/yaml/backward.yaml
paddle/phi/api/yaml/backward.yaml
+12
-0
paddle/phi/api/yaml/legacy_backward.yaml
paddle/phi/api/yaml/legacy_backward.yaml
+0
-10
paddle/phi/api/yaml/legacy_ops.yaml
paddle/phi/api/yaml/legacy_ops.yaml
+0
-9
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+7
-0
paddle/phi/api/yaml/ops.yaml
paddle/phi/api/yaml/ops.yaml
+9
-0
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+7
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+2
-0
paddle/phi/ops/compat/mean_sig.cc
paddle/phi/ops/compat/mean_sig.cc
+0
-34
未找到文件。
paddle/fluid/operators/mean_op.cc
已删除
100644 → 0
浏览文件 @
e5506be6
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace
paddle
{
namespace
operators
{
class
MeanOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
};
class
MeanOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input of mean op"
);
AddOutput
(
"Out"
,
"(Tensor) The output of mean op"
);
AddComment
(
R"DOC(
Mean Operator calculates the mean of all elements in X.
)DOC"
);
}
};
class
MeanOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
const
override
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
class
MeanGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
}
phi
::
KernelKey
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
));
return
phi
::
KernelKey
(
input_data_type
,
ctx
.
GetPlace
());
}
};
template
<
typename
T
>
class
MeanGradMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
grad_op
)
const
override
{
grad_op
->
SetType
(
"mean_grad"
);
grad_op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER
(
MeanGradNoNeedBufferVarsInferer
,
"X"
);
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
mean
,
MeanInferShapeFunctor
,
PD_INFER_META
(
phi
::
MeanAllInferMeta
));
REGISTER_OPERATOR
(
mean
,
ops
::
MeanOp
,
ops
::
MeanOpMaker
,
ops
::
MeanOpInferVarType
,
ops
::
MeanGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
MeanGradMaker
<
paddle
::
imperative
::
OpBase
>
,
MeanInferShapeFunctor
);
REGISTER_OPERATOR
(
mean_grad
,
ops
::
MeanGradOp
,
ops
::
MeanGradNoNeedBufferVarsInferer
);
paddle/phi/api/yaml/backward.yaml
浏览文件 @
93ff8e4c
...
@@ -1102,6 +1102,18 @@
...
@@ -1102,6 +1102,18 @@
kernel
:
kernel
:
func
:
maxout_grad
func
:
maxout_grad
-
backward_op
:
mean_all_grad
forward
:
mean_all(Tensor x) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad)
output
:
Tensor(x_grad)
infer_meta
:
func
:
UnchangedExceptLayoutInferMeta
param
:
[
x
]
kernel
:
func
:
mean_all_grad
data_type
:
out_grad
no_need_buffer
:
x
-
backward_op
:
memory_efficient_attention_grad
-
backward_op
:
memory_efficient_attention_grad
forward
:
memory_efficient_attention (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test) -> Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
forward
:
memory_efficient_attention (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test) -> Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
args
:
(Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor output, Tensor logsumexp, Tensor seed_and_offset, Tensor output_grad, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale)
args
:
(Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor output, Tensor logsumexp, Tensor seed_and_offset, Tensor output_grad, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale)
...
...
paddle/phi/api/yaml/legacy_backward.yaml
浏览文件 @
93ff8e4c
...
@@ -586,16 +586,6 @@
...
@@ -586,16 +586,6 @@
func
:
maximum_grad
func
:
maximum_grad
composite
:
maximum_grad(x, y, out_grad, axis, x_grad, y_grad)
composite
:
maximum_grad(x, y, out_grad, axis, x_grad, y_grad)
-
backward_op
:
mean_all_grad
forward
:
mean_all(Tensor x) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad)
output
:
Tensor(x_grad)
infer_meta
:
func
:
UnchangedInferMeta
param
:
[
x
]
kernel
:
func
:
mean_all_grad
-
backward_op
:
mean_double_grad
-
backward_op
:
mean_double_grad
forward
:
mean_grad (Tensor x, Tensor grad_out, IntArray axis={}, bool keepdim=false, bool reduce_all =
false
) -> Tensor(grad_x)
forward
:
mean_grad (Tensor x, Tensor grad_out, IntArray axis={}, bool keepdim=false, bool reduce_all =
false
) -> Tensor(grad_x)
args
:
(Tensor grad_x_grad, IntArray axis={}, bool keepdim=false)
args
:
(Tensor grad_x_grad, IntArray axis={}, bool keepdim=false)
...
...
paddle/phi/api/yaml/legacy_ops.yaml
浏览文件 @
93ff8e4c
...
@@ -816,15 +816,6 @@
...
@@ -816,15 +816,6 @@
func
:
mean
func
:
mean
backward
:
mean_grad
backward
:
mean_grad
-
op
:
mean_all
args
:
(Tensor x)
output
:
Tensor
infer_meta
:
func
:
MeanAllInferMeta
kernel
:
func
:
mean_all
backward
:
mean_all_grad
-
op
:
merged_adam_
-
op
:
merged_adam_
args
:
(Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1, Scalar beta2, Scalar epsilon, bool multi_precision, bool use_global_beta_pow)
args
:
(Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1, Scalar beta2, Scalar epsilon, bool multi_precision, bool use_global_beta_pow)
output
:
Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()}
output
:
Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()}
...
...
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
93ff8e4c
...
@@ -1441,6 +1441,13 @@
...
@@ -1441,6 +1441,13 @@
extra
:
extra
:
attrs
:
[
bool use_mkldnn = false
]
attrs
:
[
bool use_mkldnn = false
]
-
op
:
mean_all (mean)
backward
:
mean_all_grad (mean_grad)
inputs
:
x
:
X
outputs
:
out
:
Out
-
op
:
merge_selected_rows
-
op
:
merge_selected_rows
inputs
:
inputs
:
x
:
X
x
:
X
...
...
paddle/phi/api/yaml/ops.yaml
浏览文件 @
93ff8e4c
...
@@ -1227,6 +1227,15 @@
...
@@ -1227,6 +1227,15 @@
func
:
maxout
func
:
maxout
backward
:
maxout_grad
backward
:
maxout_grad
-
op
:
mean_all
args
:
(Tensor x)
output
:
Tensor
infer_meta
:
func
:
MeanAllInferMeta
kernel
:
func
:
mean_all
backward
:
mean_all_grad
-
op
:
memory_efficient_attention
-
op
:
memory_efficient_attention
args
:
(Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test)
args
:
(Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test)
output
:
Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
output
:
Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
...
...
paddle/phi/infermeta/unary.cc
浏览文件 @
93ff8e4c
...
@@ -4436,6 +4436,13 @@ void TriuInferMeta(const MetaTensor& x, int diagonal, MetaTensor* out) {
...
@@ -4436,6 +4436,13 @@ void TriuInferMeta(const MetaTensor& x, int diagonal, MetaTensor* out) {
TrilTriuInferMeta
(
x
,
diagonal
,
false
,
out
);
TrilTriuInferMeta
(
x
,
diagonal
,
false
,
out
);
}
}
// Some operator having oneDnn kernel will be set layout in kernel.
void
UnchangedExceptLayoutInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
)
{
out
->
set_dims
(
x
.
dims
());
out
->
set_dtype
(
x
.
dtype
());
out
->
share_lod
(
x
);
}
void
UnchangedInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
)
{
void
UnchangedInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
)
{
out
->
share_meta
(
x
);
out
->
share_meta
(
x
);
}
}
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
93ff8e4c
...
@@ -628,6 +628,8 @@ void UnbindInferMeta(const MetaTensor& x,
...
@@ -628,6 +628,8 @@ void UnbindInferMeta(const MetaTensor& x,
int
axis
,
int
axis
,
std
::
vector
<
MetaTensor
*>
outs
);
std
::
vector
<
MetaTensor
*>
outs
);
void
UnchangedExceptLayoutInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
);
void
UnchangedInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
);
void
UnchangedInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
);
// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1]
// meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1]
...
...
paddle/phi/ops/compat/mean_sig.cc
已删除
100644 → 0
浏览文件 @
e5506be6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
MeanOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"mean_all"
,
{
"X"
},
{},
{
"Out"
});
}
KernelSignature
MeanGradOpGradArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"mean_all_grad"
,
{
"X"
,
"Out@GRAD"
},
{},
{
"X@GRAD"
});
}
}
// namespace phi
PD_REGISTER_BASE_KERNEL_NAME
(
mean
,
mean_all
);
PD_REGISTER_BASE_KERNEL_NAME
(
mean_grad
,
mean_all_grad
);
PD_REGISTER_ARG_MAPPING_FN
(
mean
,
phi
::
MeanOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
mean_grad
,
phi
::
MeanGradOpGradArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录