Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ef53e1b4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ef53e1b4
编写于
9月 08, 2022
作者:
T
TeFeng Chen
提交者:
GitHub
9月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cinn_launch op: fix dtype of tensor is always mutable_data<float> (#45835)
上级
2b0857be
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
8 addition
and
5 deletion
+8
-5
paddle/fluid/operators/cinn/cinn_launch_context.cc
paddle/fluid/operators/cinn/cinn_launch_context.cc
+8
-5
未找到文件。
paddle/fluid/operators/cinn/cinn_launch_context.cc
浏览文件 @
ef53e1b4
...
...
@@ -270,8 +270,9 @@ void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
[
this
,
var_name
](
void
*
ctx
,
cinn_buffer_t
*
buffer
)
{
auto
*
tensor
=
cached_scope_
->
GetVar
(
var_name
)
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
framework
::
DDim
(
buffer
->
dims
,
buffer
->
dimensions
));
buffer
->
memory
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
mutable_data
<
float
>
(
*
cached_place_
));
buffer
->
memory
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
mutable_data
(
*
cached_place_
,
framework
::
paddle2cinn
::
TransToPaddleDataType
(
buffer
->
type
)));
return
0
;
});
...
...
@@ -295,8 +296,9 @@ void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) {
auto
*
tensor
=
cached_temp_scope_
->
Var
(
var_name
)
->
GetMutable
<
LoDTensor
>
();
tensor
->
Resize
(
framework
::
DDim
(
buffer
->
dims
,
buffer
->
dimensions
));
buffer
->
memory
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
mutable_data
<
float
>
(
*
cached_place_
));
buffer
->
memory
=
reinterpret_cast
<
uint8_t
*>
(
tensor
->
mutable_data
(
*
cached_place_
,
framework
::
paddle2cinn
::
TransToPaddleDataType
(
buffer
->
type
)));
return
0
;
});
...
...
@@ -437,7 +439,8 @@ ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place,
auto
*
buffer
=
GetCinnBufferOfVar
(
var_name
);
auto
dim
=
framework
::
DDim
(
buffer
->
dims
,
buffer
->
dimensions
);
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
var
->
GetMutable
<
LoDTensor
>
()
->
mutable_data
<
float
>
(
place
);
var
->
GetMutable
<
LoDTensor
>
()
->
mutable_data
(
place
,
framework
::
paddle2cinn
::
TransToPaddleDataType
(
buffer
->
type
));
}
return
parallel_executor_
.
get
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录