Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b2390438
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b2390438
编写于
4月 13, 2022
作者:
Z
zyfncg
提交者:
GitHub
4月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix problem of infermeta with vector output (#41646)
* remove stack_grad infershape * fix bug of output with null * fix bug
上级
5f2c5b9e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
76 addition
and
67 deletion
+76
-67
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+13
-3
paddle/fluid/framework/new_executor/new_executor_defs.cc
paddle/fluid/framework/new_executor/new_executor_defs.cc
+11
-6
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+2
-1
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+14
-5
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+11
-5
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+2
-1
paddle/fluid/imperative/infer_shape_context.h
paddle/fluid/imperative/infer_shape_context.h
+15
-5
paddle/fluid/operators/stack_op.cc
paddle/fluid/operators/stack_op.cc
+4
-39
paddle/phi/infermeta/backward.cc
paddle/phi/infermeta/backward.cc
+4
-2
未找到文件。
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
b2390438
...
...
@@ -597,7 +597,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
for
(
auto
&
out_name
:
output_names
)
{
if
(
ctx
->
HasOutputs
(
out_name
))
{
if
(
ctx
->
HasOutputs
(
out_name
,
true
))
{
auto
output_var
=
ctx
->
GetOutputVarPtrs
(
out_name
);
if
(
output_var
.
size
()
==
1
)
{
infer_meta_context
.
EmplaceBackOutput
(
std
::
make_shared
<
CompatMetaTensor
>
(
...
...
@@ -606,8 +606,18 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
paddle
::
SmallVector
<
std
::
shared_ptr
<
phi
::
MetaTensor
>>
outputs
;
outputs
.
reserve
(
output_var
.
size
());
for
(
const
auto
&
out
:
output_var
)
{
outputs
.
emplace_back
(
std
::
make_shared
<
CompatMetaTensor
>
(
out
,
ctx
->
IsRuntime
()));
if
(
ctx
->
IsRuntime
())
{
if
(
BOOST_GET_CONST
(
Variable
*
,
out
))
{
outputs
.
emplace_back
(
std
::
make_shared
<
CompatMetaTensor
>
(
out
,
ctx
->
IsRuntime
()));
continue
;
}
}
else
if
(
BOOST_GET_CONST
(
VarDesc
*
,
out
))
{
outputs
.
emplace_back
(
std
::
make_shared
<
CompatMetaTensor
>
(
out
,
ctx
->
IsRuntime
()));
continue
;
}
outputs
.
emplace_back
(
nullptr
);
}
infer_meta_context
.
EmplaceBackOutputs
(
std
::
move
(
outputs
));
}
...
...
paddle/fluid/framework/new_executor/new_executor_defs.cc
浏览文件 @
b2390438
...
...
@@ -93,19 +93,24 @@ bool InterpretercoreInferShapeContext::HasInputs(
return
true
;
}
bool
InterpretercoreInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
)
const
{
bool
InterpretercoreInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
)
const
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
if
(
allow_null
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
!=
nullptr
)
return
true
;
}
return
false
;
}
else
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
return
false
;
}
return
true
;
}
return
true
;
}
AttrReader
InterpretercoreInferShapeContext
::
Attrs
()
const
{
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
b2390438
...
...
@@ -58,7 +58,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
;
AttrReader
Attrs
()
const
override
;
...
...
paddle/fluid/framework/op_desc.cc
浏览文件 @
b2390438
...
...
@@ -39,7 +39,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
;
AttrReader
Attrs
()
const
override
;
...
...
@@ -882,7 +883,8 @@ bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
return
true
;
}
bool
CompileTimeInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
)
const
{
bool
CompileTimeInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
)
const
{
if
(
op_
.
Outputs
().
find
(
name
)
==
op_
.
Outputs
().
end
())
{
return
false
;
}
...
...
@@ -890,10 +892,17 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
if
(
output_names
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
output_names
)
{
if
(
!
block_
.
HasVarRecursive
(
output
))
return
false
;
if
(
allow_null
)
{
for
(
auto
&
output
:
output_names
)
{
if
(
block_
.
HasVarRecursive
(
output
))
return
true
;
}
return
false
;
}
else
{
for
(
auto
&
output
:
output_names
)
{
if
(
!
block_
.
HasVarRecursive
(
output
))
return
false
;
}
return
true
;
}
return
true
;
}
AttrReader
CompileTimeInferShapeContext
::
Attrs
()
const
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
b2390438
...
...
@@ -718,18 +718,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
if
(
allow_null
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
!=
nullptr
)
return
true
;
}
return
false
;
}
else
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
return
false
;
}
return
true
;
}
return
true
;
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
b2390438
...
...
@@ -69,7 +69,8 @@ class InferShapeContext {
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
=
0
;
virtual
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
=
0
;
...
...
paddle/fluid/imperative/infer_shape_context.h
浏览文件 @
b2390438
...
...
@@ -95,17 +95,27 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutputs
(
const
std
::
string
&
name
,
bool
allow_null
=
false
)
const
override
{
auto
it
=
var_map_out_
->
find
(
name
);
if
(
it
==
var_map_out_
->
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
if
(
allow_null
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
!=
nullptr
)
{
return
true
;
}
}
return
false
;
}
else
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
return
true
;
}
framework
::
AttrReader
Attrs
()
const
override
{
...
...
paddle/fluid/operators/stack_op.cc
浏览文件 @
b2390438
...
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
plat
=
paddle
::
platform
;
...
...
@@ -68,44 +69,6 @@ Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inp
class
StackOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Y@Grad) not exist."
));
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
auto
dy_dim
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
dy_dim
.
size
();
PADDLE_ENFORCE_GE
(
axis
,
-
rank
,
platform
::
errors
::
InvalidArgument
(
"Attr(axis) must be inside [-rank, rank), where rank = %d, "
"but received axis is:%d."
,
rank
,
axis
));
PADDLE_ENFORCE_LT
(
axis
,
rank
,
platform
::
errors
::
InvalidArgument
(
"Attr(axis) must be inside [-rank, rank), where rank = %d, "
"but received axis is:%d."
,
rank
,
axis
));
if
(
axis
<
0
)
axis
+=
rank
;
PADDLE_ENFORCE_EQ
(
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
size
(),
static_cast
<
size_t
>
(
dy_dim
[
axis
]),
platform
::
errors
::
InvalidArgument
(
"Number of Outputs(X@Grad) is equal to dy dim at axis, but"
" received outputs size is:%d, dy dims is:%d."
,
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
size
(),
static_cast
<
size_t
>
(
dy_dim
[
axis
])));
auto
vec
=
phi
::
vectorize
<
int
>
(
dy_dim
);
vec
.
erase
(
vec
.
begin
()
+
axis
);
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
std
::
vector
<
framework
::
DDim
>
(
dy_dim
[
axis
],
phi
::
make_ddim
(
vec
)));
}
};
template
<
typename
T
>
...
...
@@ -127,8 +90,10 @@ class StackGradOpMaker : public framework::SingleGradOpMaker<T> {
DECLARE_INFER_SHAPE_FUNCTOR
(
stack
,
StackInferMetaFunctor
,
PD_INFER_META
(
phi
::
StackInferMeta
));
DECLARE_INFER_SHAPE_FUNCTOR
(
stack_grad
,
StackGradInferMetaFunctor
,
PD_INFER_META
(
phi
::
StackGradInferMeta
));
REGISTER_OPERATOR
(
stack
,
ops
::
StackOp
,
ops
::
StackOpMaker
,
ops
::
StackGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
StackGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
StackInferMetaFunctor
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
,
StackGradInferMetaFunctor
);
paddle/phi/infermeta/backward.cc
浏览文件 @
b2390438
...
...
@@ -541,8 +541,10 @@ void StackGradInferMeta(const MetaTensor& out_grad,
vec
.
erase
(
vec
.
begin
()
+
axis
);
for
(
size_t
i
=
0
;
i
<
x_grad
.
size
();
++
i
)
{
x_grad
[
i
]
->
set_dims
(
phi
::
make_ddim
(
vec
));
x_grad
[
i
]
->
set_dtype
(
out_grad
.
dtype
());
if
(
x_grad
[
i
])
{
x_grad
[
i
]
->
set_dims
(
phi
::
make_ddim
(
vec
));
x_grad
[
i
]
->
set_dtype
(
out_grad
.
dtype
());
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录