Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b2390438
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b2390438
编写于
4月 13, 2022
作者:
Z
zyfncg
提交者:
GitHub
4月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix problem of infermeta with vector output (#41646)
* remove stack_grad infershape * fix bug of output with null * fix bug
上级
5f2c5b9e
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
76 addition
and
67 deletion
+76
-67
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+13
-3
paddle/fluid/framework/new_executor/new_executor_defs.cc
paddle/fluid/framework/new_executor/new_executor_defs.cc
+11
-6
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+2
-1
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+14
-5
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+11
-5
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+2
-1
paddle/fluid/imperative/infer_shape_context.h
paddle/fluid/imperative/infer_shape_context.h
+15
-5
paddle/fluid/operators/stack_op.cc
paddle/fluid/operators/stack_op.cc
+4
-39
paddle/phi/infermeta/backward.cc
paddle/phi/infermeta/backward.cc
+4
-2
未找到文件。
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
b2390438
...
...
@@ -597,7 +597,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
for
(
auto
&
out_name
:
output_names
)
{
if
(
ctx
->
HasOutputs
(
out_name
))
{
if
(
ctx
->
HasOutputs
(
out_name
,
true
))
{
auto
output_var
=
ctx
->
GetOutputVarPtrs
(
out_name
);
if
(
output_var
.
size
()
==
1
)
{
infer_meta_context
.
EmplaceBackOutput
(
std
::
make_shared
<
CompatMetaTensor
>
(
...
...
@@ -606,8 +606,18 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
paddle
::
SmallVector
<
std
::
shared_ptr
<
phi
::
MetaTensor
>>
outputs
;
outputs
.
reserve
(
output_var
.
size
());
for
(
const
auto
&
out
:
output_var
)
{
if
(
ctx
->
IsRuntime
())
{
if
(
BOOST_GET_CONST
(
Variable
*
,
out
))
{
outputs
.
emplace_back
(
std
::
make_shared
<
CompatMetaTensor
>
(
out
,
ctx
->
IsRuntime
()));
continue
;
}
}
else
if
(
BOOST_GET_CONST
(
VarDesc
*
,
out
))
{
outputs
.
emplace_back
(
std
::
make_shared
<
CompatMetaTensor
>
(
out
,
ctx
->
IsRuntime
()));
continue
;
}
outputs
.
emplace_back
(
nullptr
);
}
infer_meta_context
.
EmplaceBackOutputs
(
std
::
move
(
outputs
));
}
...
...
paddle/fluid/framework/new_executor/new_executor_defs.cc
浏览文件 @
b2390438
...
...
@@ -93,19 +93,24 @@ bool InterpretercoreInferShapeContext::HasInputs(
return
true
;
}
bool
InterpretercoreInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
)
const
{
bool
InterpretercoreInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
)
const
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
if
(
allow_null
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
if
(
output
!=
nullptr
)
return
true
;
}
return
false
;
}
else
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
return
false
;
}
return
true
;
}
}
AttrReader
InterpretercoreInferShapeContext
::
Attrs
()
const
{
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
b2390438
...
...
@@ -58,7 +58,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
;
AttrReader
Attrs
()
const
override
;
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
b2390438
...
...
@@ -39,7 +39,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
;
AttrReader
Attrs
()
const
override
;
...
...
@@ -882,7 +883,8 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
return
true
;
}
bool
CompileTimeInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
)
const
{
bool
CompileTimeInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
)
const
{
if
(
op_
.
Outputs
().
find
(
name
)
==
op_
.
Outputs
().
end
())
{
return
false
;
}
...
...
@@ -890,10 +892,17 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
if
(
output_names
.
empty
())
{
return
false
;
}
if
(
allow_null
)
{
for
(
auto
&
output
:
output_names
)
{
if
(
block_
.
HasVarRecursive
(
output
))
return
true
;
}
return
false
;
}
else
{
for
(
auto
&
output
:
output_names
)
{
if
(
!
block_
.
HasVarRecursive
(
output
))
return
false
;
}
return
true
;
}
}
AttrReader
CompileTimeInferShapeContext
::
Attrs
()
const
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
b2390438
...
...
@@ -718,19 +718,25 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
if
(
allow_null
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
if
(
output
!=
nullptr
)
return
true
;
}
return
false
;
}
else
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
return
false
;
}
return
true
;
}
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
b2390438
...
...
@@ -69,7 +69,8 @@ class InferShapeContext {
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
=
0
;
virtual
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/fluid/imperative/infer_shape_context.h
浏览文件 @
b2390438
...
...
@@ -95,11 +95,20 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
{
auto
it
=
var_map_out_
->
find
(
name
);
if
(
it
==
var_map_out_
->
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
if
(
allow_null
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
!=
nullptr
)
{
return
true
;
}
}
return
false
;
}
else
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
...
...
@@ -107,6 +116,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}
return
true
;
}
}
framework
::
AttrReader
Attrs
()
const
override
{
return
framework
::
AttrReader
(
*
attrs_
,
*
default_attrs_
);
...
...
paddle/fluid/operators/stack_op.cc
浏览文件 @
b2390438
...
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
plat
=
paddle
::
platform
;
...
...
@@ -68,44 +69,6 @@ Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inp
class
StackOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Y@Grad) not exist."
));
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
auto
dy_dim
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
dy_dim
.
size
();
PADDLE_ENFORCE_GE
(
axis
,
-
rank
,
platform
::
errors
::
InvalidArgument
(
"Attr(axis) must be inside [-rank, rank), where rank = %d, "
"but received axis is:%d."
,
rank
,
axis
));
PADDLE_ENFORCE_LT
(
axis
,
rank
,
platform
::
errors
::
InvalidArgument
(
"Attr(axis) must be inside [-rank, rank), where rank = %d, "
"but received axis is:%d."
,
rank
,
axis
));
if
(
axis
<
0
)
axis
+=
rank
;
PADDLE_ENFORCE_EQ
(
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
size
(),
static_cast
<
size_t
>
(
dy_dim
[
axis
]),
platform
::
errors
::
InvalidArgument
(
"Number of Outputs(X@Grad) is equal to dy dim at axis, but"
" received outputs size is:%d, dy dims is:%d."
,
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
size
(),
static_cast
<
size_t
>
(
dy_dim
[
axis
])));
auto
vec
=
phi
::
vectorize
<
int
>
(
dy_dim
);
vec
.
erase
(
vec
.
begin
()
+
axis
);
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
std
::
vector
<
framework
::
DDim
>
(
dy_dim
[
axis
],
phi
::
make_ddim
(
vec
)));
}
};
template
<
typename
T
>
...
...
@@ -127,8 +90,10 @@ class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
DECLARE_INFER_SHAPE_FUNCTOR
(
stack
,
StackInferMetaFunctor
,
PD_INFER_META
(
phi
::
StackInferMeta
));
DECLARE_INFER_SHAPE_FUNCTOR
(
stack_grad
,
StackGradInferMetaFunctor
,
PD_INFER_META
(
phi
::
StackGradInferMeta
));
REGISTER_OPERATOR
(
stack
,
ops
::
StackOp
,
ops
::
StackOpMaker
,
ops
::
StackGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
StackGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
StackInferMetaFunctor
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
,
StackGradInferMetaFunctor
);
paddle/phi/infermeta/backward.cc
浏览文件 @
b2390438
...
...
@@ -541,9 +541,11 @@ void StackGradInferMeta(const MetaTensor& out_grad,
vec
.
erase
(
vec
.
begin
()
+
axis
);
for
(
size_t
i
=
0
;
i
<
x_grad
.
size
();
++
i
)
{
if
(
x_grad
[
i
])
{
x_grad
[
i
]
->
set_dims
(
phi
::
make_ddim
(
vec
));
x_grad
[
i
]
->
set_dtype
(
out_grad
.
dtype
());
}
}
}
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录