Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
42a75145
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
42a75145
编写于
2月 10, 2023
作者:
A
Aurelius84
提交者:
GitHub
2月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix inferMefer in transpose2_grad (#50388)
* Fix inferMefer in transpose2_grad * fix infershape * fix unittest
上级
ca520280
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
15 addition
and
30 deletion
+15
-30
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+15
-30
未找到文件。
paddle/fluid/operators/transpose_op.cc
浏览文件 @
42a75145
...
...
@@ -16,6 +16,10 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
...
...
@@ -179,19 +183,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"TransposeOpGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"TransposeOpGrad"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
}
}
protected:
phi
::
KernelKey
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -320,21 +311,6 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"XShape"
),
"Input"
,
"XShape"
,
"Transpose2OpGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"Transpose2OpGrad"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
auto
xshape_dim
=
ctx
->
GetInputDim
(
"XShape"
);
auto
x_shape_dim
=
phi
::
slice_ddim
(
xshape_dim
,
1
,
xshape_dim
.
size
());
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_shape_dim
);
ctx
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
}
}
protected:
phi
::
KernelKey
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -359,6 +335,13 @@ class TransposeGradInferVarType : public framework::VarTypeInference {
}
// namespace operators
}
// namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR
(
transpose_grad
,
TransposeGradInferShapeFunctor
,
PD_INFER_META
(
phi
::
TransposeGradInferMeta
));
DECLARE_INFER_SHAPE_FUNCTOR
(
transpose2_grad
,
Transpose2GradInferShapeFunctor
,
PD_INFER_META
(
phi
::
TransposeGradInferMeta
));
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
transpose
,
...
...
@@ -368,7 +351,8 @@ REGISTER_OPERATOR(
paddle
::
framework
::
DefaultGradOpMaker
<
paddle
::
imperative
::
OpBase
,
true
>
);
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
,
ops
::
TransposeGradInferVarType
);
ops
::
TransposeGradInferVarType
,
TransposeGradInferShapeFunctor
);
REGISTER_OPERATOR
(
transpose2
,
ops
::
Transpose2Op
,
...
...
@@ -379,4 +363,5 @@ REGISTER_OPERATOR(transpose2_grad,
ops
::
Transpose2OpGrad
,
ops
::
TransposeGradInferVarType
,
ops
::
Transpose2DoubleGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
Transpose2DoubleGradMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
Transpose2DoubleGradMaker
<
paddle
::
imperative
::
OpBase
>
,
Transpose2GradInferShapeFunctor
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录