Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fba6a10d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fba6a10d
编写于
1月 02, 2018
作者:
Q
QI JUN
提交者:
GitHub
1月 02, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug in TransDataLayout (#7137)
上级
06888bb0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
17 addition
and
8 deletion
+17
-8
paddle/framework/data_transform.cc
paddle/framework/data_transform.cc
+10
-1
paddle/framework/data_transform_test.cc
paddle/framework/data_transform_test.cc
+7
-7
未找到文件。
paddle/framework/data_transform.cc
浏览文件 @
fba6a10d
...
@@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx,
...
@@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx,
auto
*
dst
=
out
->
GetMutable
<
Tensor
>
();
auto
*
dst
=
out
->
GetMutable
<
Tensor
>
();
PADDLE_ENFORCE
(
arity
(
src
.
dims
())
==
4
,
"Input Arity Only Suppport 4!"
);
PADDLE_ENFORCE
(
arity
(
src
.
dims
())
==
4
,
"Input Arity Only Suppport 4!"
);
dst
->
Resize
(
src
.
dims
());
auto
src_dim
=
src
.
dims
();
dst
->
Resize
(
src_dim
);
auto
place
=
kernel_pair
.
second
.
place_
;
auto
place
=
kernel_pair
.
second
.
place_
;
CopyFrom
(
src
,
place
,
*
ctx
,
dst
);
CopyFrom
(
src
,
place
,
*
ctx
,
dst
);
const
std
::
vector
<
int
>
axis
=
{
0
,
2
,
3
,
1
};
const
std
::
vector
<
int
>
axis
=
{
0
,
2
,
3
,
1
};
std
::
vector
<
int64_t
>
dst_dim
;
dst_dim
.
resize
(
axis
.
size
());
for
(
size_t
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
dst_dim
[
i
]
=
src_dim
[
axis
[
i
]];
}
dst
->
Resize
(
make_ddim
(
dst_dim
));
auto
src_type
=
kernel_pair
.
first
.
data_type_
;
auto
src_type
=
kernel_pair
.
first
.
data_type_
;
framework
::
VisitDataType
(
src_type
,
CastDataLayout
(
src
,
dst
,
ctx
,
axis
));
framework
::
VisitDataType
(
src_type
,
CastDataLayout
(
src
,
dst
,
ctx
,
axis
));
...
...
paddle/framework/data_transform_test.cc
浏览文件 @
fba6a10d
...
@@ -32,18 +32,18 @@ using namespace platform;
...
@@ -32,18 +32,18 @@ using namespace platform;
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
*/
*/
std
::
array
<
proto
::
DataType
,
2
>
kDataType
=
{
proto
::
DataType
::
FP32
,
std
::
array
<
proto
::
DataType
,
2
>
kDataType
=
{
proto
::
DataType
::
FP64
};
{
proto
::
DataType
::
FP32
,
proto
::
DataType
::
FP64
}
};
std
::
array
<
Place
,
2
>
kPlace
=
{
CPUPlace
(),
CUDAPlace
(
0
)
};
std
::
array
<
Place
,
2
>
kPlace
=
{
{
CPUPlace
(),
CUDAPlace
(
0
)}
};
std
::
array
<
DataLayout
,
2
>
kDataLayout
=
{
std
::
array
<
DataLayout
,
2
>
kDataLayout
=
{
{
DataLayout
::
kNHWC
,
DataLayout
::
kNCHW
,
DataLayout
::
kNHWC
,
DataLayout
::
kNCHW
,
};
}
}
;
std
::
array
<
LibraryType
,
2
>
kLibraryType
=
{
std
::
array
<
LibraryType
,
2
>
kLibraryType
=
{
{
LibraryType
::
kPlain
,
LibraryType
::
kMKLDNN
,
LibraryType
::
kPlain
,
LibraryType
::
kMKLDNN
,
};
}
}
;
OpKernelType
GenFromBit
(
const
std
::
vector
<
bool
>
bits
)
{
OpKernelType
GenFromBit
(
const
std
::
vector
<
bool
>
bits
)
{
return
OpKernelType
(
kDataType
[
bits
[
0
]],
kPlace
[
bits
[
1
]],
kDataLayout
[
bits
[
2
]],
return
OpKernelType
(
kDataType
[
bits
[
0
]],
kPlace
[
bits
[
1
]],
kDataLayout
[
bits
[
2
]],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录