Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d4b007af
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d4b007af
编写于
3月 08, 2022
作者:
Y
YuanRisheng
提交者:
GitHub
3月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add share dims (#40238)
上级
c39aa18e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
41 addition
and
17 deletion
+41
-17
paddle/fluid/framework/infershape_utils.cc
paddle/fluid/framework/infershape_utils.cc
+11
-9
paddle/phi/core/meta_tensor.cc
paddle/phi/core/meta_tensor.cc
+28
-7
paddle/phi/core/meta_tensor.h
paddle/phi/core/meta_tensor.h
+2
-1
未找到文件。
paddle/fluid/framework/infershape_utils.cc
浏览文件 @
d4b007af
...
...
@@ -232,16 +232,8 @@ class CompatMetaTensor : public phi::MetaTensor {
}
}
void
share_
meta
(
const
MetaTensor
&
meta_tensor
)
override
{
void
share_
dims
(
const
MetaTensor
&
meta_tensor
)
override
{
set_dims
(
meta_tensor
.
dims
());
set_dtype
(
meta_tensor
.
dtype
());
// VarDesc doesn't contains layout, so we cannot share layout
// set_layout(meta_tensor.layout());
// special case 1: share lod of LoDTensor
share_lod
(
meta_tensor
);
// special case 2: share height and rows of SelectedRows in runtime
if
(
is_runtime_
)
{
auto
*
var
=
BOOST_GET
(
Variable
*
,
var_
);
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
...
...
@@ -254,6 +246,16 @@ class CompatMetaTensor : public phi::MetaTensor {
}
}
void
share_meta
(
const
MetaTensor
&
meta_tensor
)
override
{
set_dtype
(
meta_tensor
.
dtype
());
// VarDesc doesn't contains layout, so we cannot share layout
// set_layout(meta_tensor.layout());
// special case 1: share lod of LoDTensor
share_lod
(
meta_tensor
);
share_dims
(
meta_tensor
);
}
private:
const
LoD
&
GetRuntimeLoD
()
const
{
auto
*
var
=
BOOST_GET_CONST
(
Variable
*
,
var_
);
...
...
paddle/phi/core/meta_tensor.cc
浏览文件 @
d4b007af
...
...
@@ -98,13 +98,9 @@ const LoD& MetaTensor::lod() const {
}
void
MetaTensor
::
share_meta
(
const
MetaTensor
&
meta_tensor
)
{
if
(
phi
::
DenseTensor
::
classof
(
tensor_
))
{
set_dims
(
meta_tensor
.
dims
());
set_dtype
(
meta_tensor
.
dtype
());
set_layout
(
meta_tensor
.
layout
());
share_lod
(
meta_tensor
);
}
else
if
(
phi
::
SelectedRows
::
classof
(
tensor_
))
{
set_dims
(
meta_tensor
.
dims
());
if
(
phi
::
DenseTensor
::
classof
(
tensor_
)
||
phi
::
SelectedRows
::
classof
(
tensor_
))
{
share_dims
(
meta_tensor
);
set_dtype
(
meta_tensor
.
dtype
());
set_layout
(
meta_tensor
.
layout
());
share_lod
(
meta_tensor
);
...
...
@@ -114,4 +110,29 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
}
}
TensorBase
*
MetaTensor
::
get_tensor
()
const
{
return
tensor_
;
}
void
MetaTensor
::
share_dims
(
const
MetaTensor
&
meta_tensor
)
{
bool
is_dense_tensor
=
phi
::
DenseTensor
::
classof
(
tensor_
);
bool
is_selected_rows
=
phi
::
SelectedRows
::
classof
(
tensor_
);
if
(
is_dense_tensor
||
is_selected_rows
)
{
set_dims
(
meta_tensor
.
dims
());
if
(
is_selected_rows
)
{
const
auto
in_tensor_base
=
meta_tensor
.
get_tensor
();
PADDLE_ENFORCE_EQ
(
phi
::
SelectedRows
::
classof
(
in_tensor_base
),
true
,
errors
::
InvalidArgument
(
"The input MetaTensor is SelectedRows, but "
"the output MetaTensor is not this type."
));
auto
*
selected_rows_out
=
static_cast
<
SelectedRows
*>
(
tensor_
);
auto
*
selected_rows_in
=
static_cast
<
SelectedRows
*>
(
in_tensor_base
);
selected_rows_out
->
set_rows
(
selected_rows_in
->
rows
());
selected_rows_out
->
set_height
(
selected_rows_in
->
height
());
}
}
else
{
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"Unsupported sharing dims for `%s`."
,
tensor_
->
type_info
().
name
()));
}
}
}
// namespace phi
paddle/phi/core/meta_tensor.h
浏览文件 @
d4b007af
...
...
@@ -60,12 +60,13 @@ class MetaTensor {
virtual
void
share_lod
(
const
MetaTensor
&
meta_tensor
);
virtual
void
share_meta
(
const
MetaTensor
&
meta_tensor
);
virtual
void
share_dims
(
const
MetaTensor
&
meta_tensor
);
private:
// Because the lod in compiletime and runtime is different,
// so `LoD` cannot in public methods
const
LoD
&
lod
()
const
;
TensorBase
*
get_tensor
()
const
;
TensorBase
*
tensor_
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录