Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
cdf3a4c2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cdf3a4c2
编写于
9月 21, 2018
作者:
C
chengduo
提交者:
GitHub
9月 21, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix concat_op InferShape (#13513)
* add ShareLoDs * refine * add Is EmptyVarName * refine Sharedlod
上级
6537b175
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
31 addition
and
2 deletion
+31
-2
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+5
-0
paddle/fluid/framework/shape_inference.cc
paddle/fluid/framework/shape_inference.cc
+10
-0
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+2
-0
paddle/fluid/operators/concat_op.cc
paddle/fluid/operators/concat_op.cc
+14
-2
未找到文件。
paddle/fluid/framework/op_desc.cc
浏览文件 @
cdf3a4c2
...
...
@@ -54,6 +54,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
size_t
j
=
0
)
const
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
PADDLE_ENFORCE
(
Inputs
(
in
)[
i
]
!=
framework
::
kEmptyVarName
,
"The %s[%d] is @EMPTY@"
,
in
,
i
);
PADDLE_ENFORCE
(
Outputs
(
out
)[
j
]
!=
framework
::
kEmptyVarName
,
"The %s[%d] is @EMPTY@"
,
out
,
j
);
auto
*
in_var
=
block_
.
FindVarRecursive
(
Inputs
(
in
)[
i
]);
auto
*
out_var
=
block_
.
FindVarRecursive
(
Outputs
(
out
)[
j
]);
if
(
in_var
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
...
...
@@ -63,6 +67,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_EQ
(
in_var
->
GetType
(),
proto
::
VarType
::
LOD_TENSOR
,
"The %d-th output of Output(%s) must be LoDTensor."
,
j
,
out
);
out_var
->
SetLoDLevel
(
in_var
->
GetLoDLevel
());
}
...
...
paddle/fluid/framework/shape_inference.cc
浏览文件 @
cdf3a4c2
...
...
@@ -46,6 +46,16 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return
this
->
GetRepeatedDims
(
arg_names
[
0
]);
}
void
InferShapeContext
::
ShareLoDs
(
const
std
::
string
&
in
,
const
std
::
string
&
out
)
const
{
PADDLE_ENFORCE_EQ
(
Inputs
(
in
).
size
(),
Outputs
(
out
).
size
(),
"The number of arguments in %s and %s is not equal."
,
in
,
out
);
for
(
size_t
i
=
0
;
i
<
in
.
size
();
++
i
)
{
ShareLoD
(
in
,
out
,
i
,
i
);
}
}
DDim
InferShapeContext
::
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
...
...
paddle/fluid/framework/shape_inference.h
浏览文件 @
cdf3a4c2
...
...
@@ -56,6 +56,8 @@ class InferShapeContext {
virtual
const
std
::
vector
<
std
::
string
>
&
Outputs
(
const
std
::
string
&
name
)
const
=
0
;
void
ShareLoDs
(
const
std
::
string
&
in
,
const
std
::
string
&
out
)
const
;
virtual
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
=
0
;
...
...
paddle/fluid/operators/concat_op.cc
浏览文件 @
cdf3a4c2
...
...
@@ -94,8 +94,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputsDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
auto
in_x
=
"X"
;
auto
out_x_g_n
=
framework
::
GradVarName
(
in_x
);
ctx
->
SetOutputsDim
(
out_x_g_n
,
ctx
->
GetInputsDim
(
in_x
));
auto
&
in_names
=
ctx
->
Inputs
(
in_x
);
auto
&
out_names
=
ctx
->
Outputs
(
out_x_g_n
);
PADDLE_ENFORCE_EQ
(
in_names
.
size
(),
out_names
.
size
(),
"The number of arguments in %s[%d] and %s[%d] is not equal."
,
in_x
,
in_names
.
size
(),
out_x_g_n
,
out_names
.
size
());
for
(
size_t
i
=
0
;
i
<
in_names
.
size
();
++
i
)
{
if
(
out_names
[
i
]
!=
framework
::
kEmptyVarName
)
{
ctx
->
ShareLoD
(
in_x
,
out_x_g_n
,
i
,
i
);
}
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录