Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0d28ee29
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
0d28ee29
编写于
4月 21, 2022
作者:
W
Wilber
提交者:
GitHub
4月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
infer add io stream. (#42031)
* infer add io stream. * add macro
上级
f2f1de7b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
173 addition
and
1 deletion
+173
-1
cmake/external/lite.cmake
cmake/external/lite.cmake
+1
-1
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+18
-0
paddle/fluid/inference/api/analysis_predictor.h
paddle/fluid/inference/api/analysis_predictor.h
+4
-0
paddle/fluid/inference/api/details/zero_copy_tensor.cc
paddle/fluid/inference/api/details/zero_copy_tensor.cc
+133
-0
paddle/fluid/inference/api/paddle_api.h
paddle/fluid/inference/api/paddle_api.h
+12
-0
paddle/fluid/inference/api/paddle_tensor.h
paddle/fluid/inference/api/paddle_tensor.h
+5
-0
未找到文件。
cmake/external/lite.cmake
浏览文件 @
0d28ee29
...
...
@@ -50,7 +50,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set
(
LITE_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/lite
)
if
(
NOT LITE_GIT_TAG
)
set
(
LITE_GIT_TAG
4ab64daecc11fbf74fffdc6a4733f388472e7d5d
)
set
(
LITE_GIT_TAG
81ef66554099800c143a0feff6e0a491b3b0d12e
)
endif
()
if
(
NOT CUDA_ARCH_NAME
)
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
0d28ee29
...
...
@@ -1931,11 +1931,29 @@ bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p,
#endif
return
false
;
}
void
InternalUtils
::
UpdateConfigInterleaved
(
paddle_infer
::
Config
*
c
,
bool
with_interleaved
)
{
#ifdef PADDLE_WITH_CUDA
c
->
trt_with_interleaved_
=
with_interleaved
;
#endif
}
void
InternalUtils
::
SyncStream
(
paddle_infer
::
Predictor
*
p
)
{
#ifdef PADDLE_WITH_CUDA
auto
*
pred
=
dynamic_cast
<
paddle
::
AnalysisPredictor
*>
(
p
->
predictor_
.
get
());
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
reinterpret_cast
<
paddle
::
platform
::
CUDADeviceContext
*>
(
pool
.
Get
(
pred
->
place_
));
cudaStreamSynchronize
(
dev_ctx
->
stream
());
#endif
}
void
InternalUtils
::
SyncStream
(
cudaStream_t
stream
)
{
#ifdef PADDLE_WITH_CUDA
cudaStreamSynchronize
(
stream
);
#endif
}
}
// namespace experimental
}
// namespace paddle_infer
paddle/fluid/inference/api/analysis_predictor.h
浏览文件 @
0d28ee29
...
...
@@ -38,6 +38,9 @@
namespace
paddle_infer
{
using
float16
=
paddle
::
platform
::
float16
;
namespace
experimental
{
class
InternalUtils
;
};
}
///
/// \file analysis_predictor.h
...
...
@@ -492,6 +495,7 @@ class AnalysisPredictor : public PaddlePredictor {
std
::
shared_ptr
<
distributed
::
FleetExecutor
>
fleet_exe_
;
std
::
shared_ptr
<
distributed
::
TaskNode
>
task_node_
;
#endif
friend
class
paddle_infer
::
experimental
::
InternalUtils
;
};
}
// namespace paddle
paddle/fluid/inference/api/details/zero_copy_tensor.cc
浏览文件 @
0d28ee29
...
...
@@ -714,4 +714,137 @@ template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template
void
Tensor
::
ORTCopyToCpu
<
float16
>(
float16
*
data
)
const
;
#endif
namespace
experimental
{
template
<
typename
T
>
void
InternalUtils
::
CopyFromCpuWithIoStream
(
paddle_infer
::
Tensor
*
t
,
const
T
*
data
,
cudaStream_t
stream
)
{
if
(
t
->
tensor_
==
nullptr
)
{
PADDLE_ENFORCE_EQ
(
t
->
name_
.
empty
(),
false
,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."
));
auto
*
scope
=
static_cast
<
paddle
::
framework
::
Scope
*>
(
t
->
scope_
);
auto
*
var
=
scope
->
FindVar
(
t
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"No tensor called [%s] in the runtime scope"
,
t
->
name_
));
auto
*
tensor
=
var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
t
->
tensor_
=
tensor
;
}
auto
*
tensor
=
static_cast
<
paddle
::
framework
::
LoDTensor
*>
(
t
->
tensor_
);
PADDLE_ENFORCE_GE
(
tensor
->
numel
(),
0
,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."
));
size_t
ele_size
=
tensor
->
numel
()
*
sizeof
(
T
);
if
(
t
->
place_
==
PlaceType
::
kCPU
)
{
auto
*
t_data
=
tensor
->
mutable_data
<
T
>
(
paddle
::
platform
::
CPUPlace
());
std
::
memcpy
(
static_cast
<
void
*>
(
t_data
),
data
,
ele_size
);
}
else
if
(
t
->
place_
==
PlaceType
::
kGPU
)
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle
::
platform
::
CUDAPlace
gpu_place
(
t
->
device_
);
auto
*
t_data
=
tensor
->
mutable_data
<
T
>
(
gpu_place
);
paddle
::
memory
::
Copy
(
gpu_place
,
static_cast
<
void
*>
(
t_data
),
paddle
::
platform
::
CPUPlace
(),
data
,
ele_size
,
stream
);
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Unavailable
(
"Can not create tensor with CUDA place because paddle is not compiled "
"with CUDA."
));
#endif
}
else
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"CopyFromCpuWithIoStream only supports CPU and GPU now."
));
}
}
template
<
typename
T
>
void
InternalUtils
::
CopyToCpuWithIoStream
(
paddle_infer
::
Tensor
*
t
,
T
*
data
,
cudaStream_t
stream
)
{
if
(
t
->
tensor_
==
nullptr
)
{
PADDLE_ENFORCE_EQ
(
t
->
name_
.
empty
(),
false
,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."
));
auto
*
scope
=
static_cast
<
paddle
::
framework
::
Scope
*>
(
t
->
scope_
);
auto
*
var
=
scope
->
FindVar
(
t
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
paddle
::
platform
::
errors
::
PreconditionNotMet
(
"No tensor called [%s] in the runtime scope"
,
t
->
name_
));
auto
*
tensor
=
var
->
GetMutable
<
paddle
::
framework
::
LoDTensor
>
();
t
->
tensor_
=
tensor
;
}
auto
*
tensor
=
static_cast
<
paddle
::
framework
::
LoDTensor
*>
(
t
->
tensor_
);
auto
ele_num
=
tensor
->
numel
();
auto
*
t_data
=
tensor
->
data
<
T
>
();
auto
t_place
=
tensor
->
place
();
paddle
::
framework
::
Tensor
out
;
auto
mem_allocation
=
std
::
make_shared
<
paddle
::
memory
::
allocation
::
Allocation
>
(
static_cast
<
void
*>
(
data
),
ele_num
*
sizeof
(
T
),
paddle
::
platform
::
CPUPlace
());
out
.
ResetHolder
(
mem_allocation
);
if
(
paddle
::
platform
::
is_cpu_place
(
t_place
))
{
#ifdef PADDLE_WITH_MKLDNN
if
(
tensor
->
layout
()
==
paddle
::
framework
::
DataLayout
::
kMKLDNN
)
paddle
::
framework
::
innerTransDataLayoutFromMKLDNN
(
tensor
->
layout
(),
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
()
.
get_cur_paddle_data_layout
(),
*
tensor
,
&
out
,
paddle
::
platform
::
CPUPlace
(),
true
);
else
std
::
memcpy
(
static_cast
<
void
*>
(
data
),
t_data
,
ele_num
*
sizeof
(
T
));
#else
std
::
memcpy
(
static_cast
<
void
*>
(
data
),
t_data
,
ele_num
*
sizeof
(
T
));
#endif
}
else
if
(
t
->
place_
==
PlaceType
::
kGPU
)
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle
::
memory
::
Copy
(
paddle
::
platform
::
CPUPlace
(),
static_cast
<
void
*>
(
data
),
t_place
,
t_data
,
ele_num
*
sizeof
(
T
),
stream
);
#else
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Unavailable
(
"Can not create tensor with CUDA place because paddle is not compiled "
"with CUDA."
));
#endif
}
else
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"CopyToCpuWithIoStream only supports CPU and GPU now."
));
}
}
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
float
>(
paddle_infer
::
Tensor
*
t
,
const
float
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
int64_t
>(
paddle_infer
::
Tensor
*
t
,
const
int64_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
int32_t
>(
paddle_infer
::
Tensor
*
t
,
const
int32_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
uint8_t
>(
paddle_infer
::
Tensor
*
t
,
const
uint8_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
int8_t
>(
paddle_infer
::
Tensor
*
t
,
const
int8_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyFromCpuWithIoStream
<
float16
>(
paddle_infer
::
Tensor
*
t
,
const
float16
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
float
>(
paddle_infer
::
Tensor
*
t
,
float
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
int64_t
>(
paddle_infer
::
Tensor
*
t
,
int64_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
int32_t
>(
paddle_infer
::
Tensor
*
t
,
int32_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
uint8_t
>(
paddle_infer
::
Tensor
*
t
,
uint8_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
int8_t
>(
paddle_infer
::
Tensor
*
t
,
int8_t
*
data
,
cudaStream_t
stream
);
template
void
InternalUtils
::
CopyToCpuWithIoStream
<
float16
>(
paddle_infer
::
Tensor
*
t
,
float16
*
data
,
cudaStream_t
stream
);
}
// namespace experimental
}
// namespace paddle_infer
paddle/fluid/inference/api/paddle_api.h
浏览文件 @
0d28ee29
...
...
@@ -420,8 +420,10 @@ using hipStream_t = struct ihipStream_t*;
namespace
paddle_infer
{
class
Predictor
;
class
Tensor
;
using
Config
=
paddle
::
AnalysisConfig
;
namespace
experimental
{
// Unstable interface, may be modified or deleted in the future.
class
PD_INFER_DECL
InternalUtils
{
public:
// Note: Can only be used under thread_local semantics.
...
...
@@ -429,8 +431,18 @@ class PD_INFER_DECL InternalUtils {
cudaStream_t
stream
);
static
bool
RunWithExternalStream
(
paddle_infer
::
Predictor
*
pred
,
hipStream_t
stream
);
static
void
UpdateConfigInterleaved
(
paddle_infer
::
Config
*
c
,
bool
with_interleaved
);
static
void
SyncStream
(
paddle_infer
::
Predictor
*
pred
);
static
void
SyncStream
(
cudaStream_t
stream
);
template
<
typename
T
>
static
void
CopyFromCpuWithIoStream
(
paddle_infer
::
Tensor
*
t
,
const
T
*
data
,
cudaStream_t
stream
);
template
<
typename
T
>
static
void
CopyToCpuWithIoStream
(
paddle_infer
::
Tensor
*
t
,
T
*
data
,
cudaStream_t
stream
);
};
}
// namespace experimental
}
// namespace paddle_infer
paddle/fluid/inference/api/paddle_tensor.h
浏览文件 @
0d28ee29
...
...
@@ -39,6 +39,10 @@ namespace contrib {
class
TensorUtils
;
}
namespace
experimental
{
class
InternalUtils
;
};
/// \brief Paddle data type.
enum
DataType
{
FLOAT32
,
...
...
@@ -198,6 +202,7 @@ class PD_INFER_DECL Tensor {
#endif
friend
class
paddle_infer
::
contrib
::
TensorUtils
;
friend
class
paddle_infer
::
experimental
::
InternalUtils
;
#if defined(PADDLE_WITH_TESTING) && defined(PADDLE_WITH_INFERENCE_API_TEST)
friend
class
paddle_infer
::
InferApiTesterUtils
;
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录