Commit 1ff1c1e0 (unverified)
add share external data interface (#39809)

Authored Mar 02, 2022 by JingZhuangzhuang; committed via GitHub on Mar 02, 2022
Parent: e4dba69a

Showing 3 changed files with 182 additions and 0 deletions (+182 -0)
paddle/fluid/inference/api/analysis_predictor_tester.cc   +82 -0
paddle/fluid/inference/api/details/zero_copy_tensor.cc    +87 -0
paddle/fluid/inference/api/paddle_tensor.h                +13 -0
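Taken together, the three files add a zero-copy path: callers can bind buffers they own to predictor tensors instead of copying through CopyFromCpu. A minimal caller-side sketch of the new interface (the input name "x" and the shape are placeholders that must match the actual model, e.g. via predictor->GetInputNames(); the include path may vary with the install layout):

#include <vector>
#include "paddle_inference_api.h"

void RunWithSharedInput(paddle_infer::Predictor *predictor) {
  std::vector<float> input(1 * 3 * 224 * 224, 0.f);  // caller-owned buffer
  auto handle = predictor->GetInputHandle("x");      // "x" is hypothetical
  // Zero-copy: the tensor aliases `input`; keep it alive until Run() returns.
  handle->ShareExternalData<float>(input.data(), {1, 3, 224, 224},
                                   paddle_infer::PlaceType::kCPU);
  predictor->Run();
}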
paddle/fluid/inference/api/analysis_predictor_tester.cc @ 1ff1c1e0

@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#if defined(PADDLE_WITH_CUDA)
#include <cuda_runtime.h>
#endif
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <thread>  // NOLINT
@@ -405,4 +408,83 @@ TEST(Predictor, Run) {
  predictor->TryShrinkMemory();
}

TEST(Tensor, CpuShareExternalData) {
  Config config;
  config.SetModel(FLAGS_dirname);
  auto predictor = CreatePredictor(config);
  auto w0 = predictor->GetInputHandle("firstw");
  auto w1 = predictor->GetInputHandle("secondw");
  auto w2 = predictor->GetInputHandle("thirdw");
  auto w3 = predictor->GetInputHandle("forthw");

  std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
  w0->ShareExternalData<int64_t>(input_data[0].data(), {4, 1}, PlaceType::kCPU);
  w1->ShareExternalData<int64_t>(input_data[1].data(), {4, 1}, PlaceType::kCPU);
  w2->ShareExternalData<int64_t>(input_data[2].data(), {4, 1}, PlaceType::kCPU);
  w3->ShareExternalData<int64_t>(input_data[3].data(), {4, 1}, PlaceType::kCPU);

  auto out = predictor->GetOutputHandle("fc_1.tmp_2");
  auto out_shape = out->shape();
  std::vector<float> out_data;
  out_data.resize(std::accumulate(out_shape.begin(), out_shape.end(), 1,
                                  std::multiplies<int>()));
  out->ShareExternalData<float>(out_data.data(), out_shape, PlaceType::kCPU);

  predictor->Run();

  PlaceType place;
  int size = 0;
  out->data<float>(&place, &size);
  LOG(INFO) << "output size: " << size / sizeof(float);
  predictor->TryShrinkMemory();
}
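Because the output tensor now aliases out_data, the results can be read straight from the caller's vector once Run() returns; no CopyToCpu round trip is needed. A short illustrative continuation of the test above (the values are model-dependent, so it only logs them):

// out_data is the same memory the predictor wrote to, so it can be read
// in place after Run().
for (size_t i = 0; i < out_data.size() && i < 4; ++i) {
  LOG(INFO) << "out_data[" << i << "] = " << out_data[i];
}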
#if defined(PADDLE_WITH_CUDA)
TEST(Tensor, GpuShareExternalData) {
  Config config;
  config.SetModel(FLAGS_dirname);
  config.EnableUseGpu(100, 0);
  auto predictor = CreatePredictor(config);
  auto w0 = predictor->GetInputHandle("firstw");
  auto w1 = predictor->GetInputHandle("secondw");
  auto w2 = predictor->GetInputHandle("thirdw");
  auto w3 = predictor->GetInputHandle("forthw");

  std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
  std::vector<int64_t *> input_gpu(4, nullptr);
  for (size_t i = 0; i < 4; ++i) {
    cudaMalloc(reinterpret_cast<void **>(&input_gpu[i]), 4 * sizeof(int64_t));
    cudaMemcpy(input_gpu[i], input_data[i].data(), 4 * sizeof(int64_t),
               cudaMemcpyHostToDevice);
  }

  w0->ShareExternalData<int64_t>(input_gpu[0], {4, 1}, PlaceType::kGPU);
  w1->ShareExternalData<int64_t>(input_gpu[1], {4, 1}, PlaceType::kGPU);
  w2->ShareExternalData<int64_t>(input_gpu[2], {4, 1}, PlaceType::kGPU);
  w3->ShareExternalData<int64_t>(input_gpu[3], {4, 1}, PlaceType::kGPU);

  auto out = predictor->GetOutputHandle("fc_1.tmp_2");
  auto out_shape = out->shape();
  float *out_data;
  // out_size is already a byte count; cudaMalloc needs the address of the
  // pointer it fills in.
  auto out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1,
                                  std::multiplies<int>()) *
                  sizeof(float);
  cudaMalloc(reinterpret_cast<void **>(&out_data), out_size);
  out->ShareExternalData<float>(out_data, out_shape, PlaceType::kGPU);

  predictor->Run();

  PlaceType place;
  int size = 0;
  out->data<float>(&place, &size);
  LOG(INFO) << "output size: " << size / sizeof(float);
  predictor->TryShrinkMemory();
}
#endif

}  // namespace paddle_infer
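One caveat worth spelling out: the GPU test never frees the device buffers it allocates, which is tolerable in a short-lived test process but not in real callers. Ownership of memory passed to ShareExternalData stays with the caller, so a matching cudaFree belongs after the predictor is done with the shared data; a sketch:

// Caller-owned device memory must be released by the caller, and only after
// the predictor no longer touches the shared buffers.
for (auto *ptr : input_gpu) cudaFree(ptr);
cudaFree(out_data);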
paddle/fluid/inference/api/details/zero_copy_tensor.cc @ 1ff1c1e0

@@ -21,6 +21,7 @@
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"

namespace paddle_infer {

@@ -205,6 +206,73 @@ void Tensor::CopyFromCpu(const T *data) {
  }
}
template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
};

template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16;
};

template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64;
};

template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8;
};

template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8;
};

template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32;
};

paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout, DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return paddle::experimental::DataLayout::NCHW;
}
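DataTypeInfo is a small type-trait table: it maps a C++ element type to the runtime paddle::experimental::DataType at compile time, and because the primary template is declared but never defined, instantiating ShareExternalData with an unsupported element type fails to compile rather than misbehaving at runtime. A standalone sketch of the same pattern (names are illustrative, not Paddle's):

#include <cstdint>

enum class DType { kFloat32, kInt64 };

template <typename T>
struct DTypeOf;  // primary template left undefined on purpose

template <>
struct DTypeOf<float> { static constexpr DType value = DType::kFloat32; };
template <>
struct DTypeOf<int64_t> { static constexpr DType value = DType::kInt64; };

static_assert(DTypeOf<float>::value == DType::kFloat32,
              "resolved at compile time");
// DTypeOf<double>::value would be a compile-time error: no definition exists.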
template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
                               PlaceType place, DataLayout layout) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
  phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape),
                            LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}
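The implementation computes a byte count from the shape and wraps the caller's pointer in a phi::Allocation without copying, so the buffer's lifetime remains the caller's responsibility. A self-contained worked example of the size arithmetic, mirroring the std::accumulate call above:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // For shape {4, 1} and T = int64_t, as in the tests:
  std::vector<int> shape{4, 1};
  size_t size = std::accumulate(shape.begin(), shape.end(), 1,
                                std::multiplies<int>()) *
                sizeof(int64_t);
  return size == 4 * 1 * 8 ? 0 : 1;  // 32 bytes handed to phi::Allocation
}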
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle_infer::Strings);
  PADDLE_ENFORCE_GE(tensor->size(), 0,
...
@@ -334,6 +402,25 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);

template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);

template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
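ShareExternalData is a template defined in this .cc file rather than in the header, so the explicit instantiations above are what give other translation units linkable symbols for each supported element type. A minimal illustration of that pattern (names are made up for the sketch):

#include <cstdint>

// In a header, only the declaration would be visible:
template <typename T>
void Share(const T *data);

// In the .cc file, the definition plus explicit instantiations:
template <typename T>
void Share(const T *data) { /* ... */ }

template void Share<float>(const float *data);
template void Share<int64_t>(const int64_t *data);
// Without these two lines, a caller in another file using Share<float>
// would compile but fail to link.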
paddle/fluid/inference/api/paddle_tensor.h @ 1ff1c1e0

@@ -47,6 +47,8 @@ enum DataType {
enum class PlaceType { kUNK = -1, kCPU, kGPU, kXPU, kNPU, kIPU };

enum class DataLayout { kUNK = -1, kAny, kNHWC, kNCHW };

/// \brief Represents an n-dimensional array of values.
/// The Tensor is used to store the input or output of the network.
/// Zero copy means that the tensor supports direct copy of host or device data
@@ -92,6 +94,17 @@ class PD_INFER_DECL Tensor {
  template <typename T>
  void CopyFromCpu(const T *data);

  /// \brief Share the data with tensor data.
  /// It's usually used to set the tensor data.
  /// \param data The pointer of the data, from which the tensor will share.
  /// \param shape The shape of data.
  /// \param place The place of data.
  /// \param layout The layout of data. Only NCHW is supported now.
  template <typename T>
  void ShareExternalData(const T *data, const std::vector<int> &shape,
                         PlaceType place,
                         DataLayout layout = DataLayout::kNCHW);

  /// \brief Experimental interface.
  /// It's usually used to set the input tensor data with Strings data type.
  /// \param data The pointer of the data, from which the tensor will copy.
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录