Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Commit e426cdae (unverified)
fix inference output with lod (#13557)

Authored by Yan Chunwei on Sep 25, 2018; committed via GitHub on Sep 25, 2018.
Parent commit: bc1fa4fd
3 changed files with 29 additions and 68 deletions:

- paddle/fluid/inference/api/api_impl.cc (+15, -51)
- paddle/fluid/inference/api/helper.h (+11, -8)
- paddle/fluid/inference/tests/api/tester_helper.h (+3, -9)
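The gist of the change: `NativePaddlePredictor::GetFetchOne` previously tried to un-batch a fetched LoDTensor itself, padding every sequence up to the longest one; it now returns the raw batched tensor and forwards the LoD (level-of-detail) index so callers can do the splitting. A minimal standalone sketch (not part of the diff), using the example numbers from the removed comment, of what a one-level LoD encodes:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // From the removed comment: image[0] yields 145 detection boxes and
  // image[1] yields 176, each box being 6 floats, so the batched output
  // tensor has shape {321, 6} and LoD {{0, 145, 321}}.
  std::vector<std::vector<std::size_t>> lod = {{0, 145, 321}};
  for (std::size_t j = 1; j < lod[0].size(); j++) {
    std::cout << "image[" << j - 1 << "] -> rows [" << lod[0][j - 1] << ", "
              << lod[0][j] << ") of the batch output\n";
  }
  return 0;
}
```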
paddle/fluid/inference/api/api_impl.cc

```diff
@@ -22,6 +22,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/inference/api/api_impl.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/timer.h"
 #include "paddle/fluid/platform/profiler.h"
 
@@ -215,57 +216,20 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
 template <typename T>
 void NativePaddlePredictor::GetFetchOne(const framework::LoDTensor &fetch,
                                         PaddleTensor *output) {
-  std::vector<int> shape;
-  auto dims_i = fetch.dims();
-  auto lod = fetch.lod();
-  const T *output_ptr = fetch.data<T>();
-  auto num = fetch.numel();
-  std::vector<T> data;
-  if (0 == lod.size()) {
-    std::copy(output_ptr, output_ptr + num, std::back_inserter(data));
-    for (int j = 0; j < dims_i.size(); ++j) {
-      shape.push_back(dims_i[j]);
-    }
-  } else {
-    // for batch detection
-    // image[0] -> output[0] shape {145, 6}
-    // image[1] -> output[1] shape {176, 6}
-    // then,
-    // the batch output shape {321, 6}
-    // the lod {{0, 145, 321}}
-    // so we should append output[0] to {176, 6}
-    size_t max_dim = 0;
-    for (size_t j = 1; j < lod[0].size(); j++) {
-      max_dim = std::max(max_dim, lod[0][j] - lod[0][j - 1]);
-    }
-    size_t common_dim = lod[0].back() == 0 ? 0 : num / lod[0].back();
-    if (max_dim > 0) {
-      data.resize((lod[0].size() - 1) * max_dim * common_dim, 0);
-    }
-    for (size_t j = 1; j < lod[0].size(); j++) {
-      size_t start = lod[0][j - 1] * common_dim;
-      size_t end = lod[0][j] * common_dim;
-      if (end > start) {
-        std::copy(output_ptr + start, output_ptr + end,
-                  data.begin() + (j - 1) * max_dim * common_dim);
-      }
-    }
-    shape.push_back(lod[0].size() - 1);
-    shape.push_back(max_dim);
-    for (int j = 1; j < dims_i.size(); ++j) {
-      shape.push_back(dims_i[j]);
-    }
-  }
-  output->shape = shape;
-  auto &buffer = output->data;
-  if (buffer.empty() || buffer.length() < sizeof(T) * data.size()) {
-    buffer.Resize(sizeof(T) * data.size());
-  }
-  std::memcpy(buffer.data(), data.data(), sizeof(T) * data.size());
-  // copy LoD
-  for (const auto &level : fetch.lod()) {
-    output->lod.emplace_back(level);
-  }
+  // set shape.
+  auto shape = framework::vectorize(fetch.dims());
+  output->shape.assign(shape.begin(), shape.end());
+  // set data.
+  const T *data = fetch.data<T>();
+  int num_elems = inference::VecReduceToInt(shape);
+  output->data.Resize(num_elems * sizeof(T));
+  // The fetched tensor output by fetch op, should always in CPU memory, so just
+  // copy.
+  memcpy(output->data.data(), data, num_elems * sizeof(T));
+  // set lod
+  output->lod.clear();
+  for (auto &level : fetch.lod()) {
+    output->lod.emplace_back(level.begin(), level.end());
+  }
 }
```
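After this change, a caller that feeds a batch gets one concatenated tensor plus its LoD, instead of a zero-padded one. A hypothetical consumer sketch (the `SplitByLod` helper and the flat-buffer layout are illustrative assumptions, not part of the Paddle API):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: slice a flat row-major buffer of shape {rows, width}
// into per-sequence chunks using one LoD level such as {0, 145, 321}.
std::vector<std::vector<float>> SplitByLod(const float *data, int width,
                                           const std::vector<std::size_t> &level) {
  std::vector<std::vector<float>> pieces;
  for (std::size_t j = 1; j < level.size(); j++) {
    // Rows [level[j-1], level[j]) belong to sequence j-1.
    pieces.emplace_back(data + level[j - 1] * width, data + level[j] * width);
  }
  return pieces;
}
```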
paddle/fluid/inference/api/helper.h

```diff
@@ -74,13 +74,17 @@ template <>
 std::string to_string<std::vector<std::vector<float>>>(
     const std::vector<std::vector<std::vector<float>>> &vec);
 
+template <typename T>
+int VecReduceToInt(const std::vector<T> &v) {
+  return std::accumulate(v.begin(), v.end(), 1,
+                         [](T a, T b) { return a * b; });
+}
+
 template <typename T>
 static void TensorAssignData(PaddleTensor *tensor,
                              const std::vector<std::vector<T>> &data) {
   // Assign buffer
-  int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1,
-                            [](int a, int b) { return a * b; });
-  tensor->data.Resize(sizeof(T) * dim);
+  int num_elems = VecReduceToInt(tensor->shape);
+  tensor->data.Resize(sizeof(T) * num_elems);
   int c = 0;
   for (const auto &f : data) {
     for (T v : f) {
@@ -89,7 +93,7 @@ static void TensorAssignData(PaddleTensor *tensor,
   }
 }
 
-std::string DescribeTensor(const PaddleTensor &tensor) {
+static std::string DescribeTensor(const PaddleTensor &tensor) {
   std::stringstream os;
   os << "Tensor [" << tensor.name << "]\n";
   os << " - type: ";
@@ -113,8 +117,7 @@ std::string DescribeTensor(const PaddleTensor &tensor) {
   os << "\n";
   os << " - data: ";
-  int dim = std::accumulate(tensor.shape.begin(), tensor.shape.end(), 1,
-                            [](int a, int b) { return a * b; });
+  int dim = VecReduceToInt(tensor.shape);
   for (int i = 0; i < dim; i++) {
     os << static_cast<float *>(tensor.data.data())[i] << " ";
   }
@@ -122,8 +125,8 @@ std::string DescribeTensor(const PaddleTensor &tensor) {
   return os.str();
 }
 
-void PrintTime(int batch_size, int repeat, int num_threads, int tid,
-               double latency, int epoch = 1) {
+static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
+                      double latency, int epoch = 1) {
   LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
             << ", threads: " << num_threads << ", thread id: " << tid
             << ", latency: " << latency << "ms ======";
```
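`VecReduceToInt` is just a multiplicative fold over a shape vector; as written, an empty vector yields 1 and any zero-sized dimension yields 0. A self-contained copy of the idiom for illustration:

```cpp
#include <numeric>
#include <vector>

template <typename T>
int VecReduceToInt(const std::vector<T> &v) {
  // Fold the dimensions into a total element count.
  return std::accumulate(v.begin(), v.end(), 1,
                         [](T a, T b) { return a * b; });
}

int main() {
  std::vector<int> shape = {321, 6};
  // 321 * 6 = 1926 elements in the batched output tensor.
  return VecReduceToInt(shape) == 1926 ? 0 : 1;
}
```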
paddle/fluid/inference/tests/api/tester_helper.h

```diff
@@ -47,11 +47,8 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   for (size_t i = 0; i < outputs.size(); i++) {
     auto &out = outputs[i];
     auto &ref_out = ref_outputs[i];
-    size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
-                                  [](int a, int b) { return a * b; });
-    size_t ref_size =
-        std::accumulate(ref_out.shape.begin(), ref_out.shape.end(), 1,
-                        [](int a, int b) { return a * b; });
+    size_t size = VecReduceToInt(out.shape);
+    size_t ref_size = VecReduceToInt(ref_out.shape);
     EXPECT_GT(size, 0);
     EXPECT_EQ(size, ref_size);
     EXPECT_EQ(out.dtype, ref_out.dtype);
@@ -87,10 +84,7 @@ std::unique_ptr<PaddlePredictor> CreateTestPredictor(
   }
 }
 
-size_t GetSize(const PaddleTensor &out) {
-  return std::accumulate(out.shape.begin(), out.shape.end(), 1,
-                         [](int a, int b) { return a * b; });
-}
+size_t GetSize(const PaddleTensor &out) { return VecReduceToInt(out.shape); }
 
 std::unordered_map<std::string, int> GetFuseStatis(AnalysisConfig config,
                                                    int *num_ops) {
```
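The test helpers now derive tensor sizes through the same `VecReduceToInt` reduction, so the size check in `CompareResult` and `GetSize` cannot drift apart. A minimal assert-based sketch of that comparison pattern, with a stand-in struct instead of `PaddleTensor` and plain asserts instead of gtest's `EXPECT_*` macros:

```cpp
#include <cassert>
#include <vector>

struct FakeTensor {  // stand-in for PaddleTensor; only the shape matters here
  std::vector<int> shape;
};

static int NumElems(const std::vector<int> &shape) {
  int n = 1;
  for (int d : shape) n *= d;
  return n;
}

int main() {
  FakeTensor out{{321, 6}};
  FakeTensor ref_out{{321, 6}};
  int size = NumElems(out.shape);
  int ref_size = NumElems(ref_out.shape);
  assert(size > 0);          // mirrors EXPECT_GT(size, 0)
  assert(size == ref_size);  // mirrors EXPECT_EQ(size, ref_size)
  return 0;
}
```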