Unverified commit c65ef07c, authored by lzy and committed by GitHub

[Inference] Make share_external_data support bf16 and bool; fix while_op...

[Inference] Make share_external_data support bf16 and bool; fix while_op cache_inference_while_scope when using fleet_executor. (#56055)

* 1. Make share_external_data support bf16 and bool; 2. do not call DropKids when cache_inference_while_scope is enabled (see the usage sketch after this list).

* fix FLAGS_cache_inference_while_scope

* add unit test

* add unit test

* skip unit test when cudnn_version < 8100

* skip test share_external_data_bf16 when CUDA_ARCH < 80
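A minimal usage sketch of the new path from the Python inference API, mirroring the unit test added in this commit. This is a sketch only: `config` is assumed to be an already-built, GPU-enabled `paddle.inference.Config` whose model takes a bf16 input.

```python
import numpy as np
import paddle
from paddle.inference import create_predictor

# Assumption: `config` is a GPU-enabled paddle.inference.Config
# (e.g. built with config.enable_use_gpu(100, 0)) for a model with a bf16 input.
predictor = create_predictor(config)
in_name = predictor.get_input_names()[0]
in_handle = predictor.get_input_handle(in_name)

# bf16 (and bool) tensors can now be shared directly, without casting to
# float32 on the Python side first; the unit test added below additionally
# switches paddle.set_default_dtype("bfloat16") around this call.
in_data = paddle.to_tensor(np.ones((1, 6, 32, 32)), "bfloat16")
in_handle.share_external_data(in_data)
predictor.run()
```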
Parent 7f5c14bc
@@ -35,6 +35,7 @@ PADDLE_DEFINE_EXPORTED_bool(
    "Use standalone executor to run ops. Temporary FLAGS, will be removed "
    "after all fleet executor cases are modified to run ops with standalone "
    "executor.");
PHI_DECLARE_bool(cache_inference_while_scope);
namespace paddle {
namespace distributed {
@@ -194,14 +195,17 @@ void Carrier::Start() {
  // TODO(wangxi): async step
  Wait();
  dev_ctx_->Wait();
  if (!FLAGS_cache_inference_while_scope) {
    // don't drop_kids when cache_inference_while_scope
    for (auto* micro_scope : microbatch_scopes_) {
      // By default, we should delete all kid scopes after run executor because
      // some operators may create local scope when running, such as while_op.
      // But when while_op also create a local executor to run it's sub block,
      // the sub scopes it created should not be dropped immediately, because
      // while_grad_op will use some variables created during while_op run, so
      // we need to keep the kids and wait for the outer executor to drop them.
      micro_scope->DropKids();
    }
  }
}
...
@@ -226,6 +226,8 @@ bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
input_ptr = t->mutable_data<int32_t>(ddim, place);
} else if (pt.dtype == PaddleDType::FLOAT16) {
input_ptr = t->mutable_data<float16>(ddim, place);
} else if (pt.dtype == PaddleDType::BFLOAT16) {
input_ptr = t->mutable_data<bfloat16>(ddim, place);
} else {
LOG(ERROR) << "unsupported feed type " << pt.dtype;
return false;
@@ -1318,9 +1320,13 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
} else if (type == framework::proto::VarType::FP16) {
GetFetchOne<float16>(fetch, output);
output->dtype = PaddleDType::FLOAT16;
} else if (type == framework::proto::VarType::BF16) {
GetFetchOne<bfloat16>(fetch, output);
output->dtype = PaddleDType::BFLOAT16;
} else {
LOG(ERROR)
    << "unknown type, only support float32, float16, bfloat16, int64 and "
       "int32 now.";
}
}
return true;
@@ -1881,6 +1887,8 @@ AnalysisPredictor::GetInputTypes() {
input_type[name] = paddle_infer::DataType::FLOAT32;
} else if (dtype == paddle::framework::proto::VarType::FP16) {
input_type[name] = paddle_infer::DataType::FLOAT16;
} else if (dtype == paddle::framework::proto::VarType::BF16) {
input_type[name] = paddle_infer::DataType::BFLOAT16;
} else if (dtype == paddle::framework::proto::VarType::INT64) {
input_type[name] = paddle_infer::DataType::INT64;
} else if (dtype == paddle::framework::proto::VarType::INT32) {
@@ -1938,6 +1946,8 @@ AnalysisPredictor::GetOutputTypes() {
output_type[name] = paddle_infer::DataType::FLOAT32;
} else if (dtype == paddle::framework::proto::VarType::FP16) {
output_type[name] = paddle_infer::DataType::FLOAT16;
} else if (dtype == paddle::framework::proto::VarType::BF16) {
output_type[name] = paddle_infer::DataType::BFLOAT16;
} else if (dtype == paddle::framework::proto::VarType::INT64) {
output_type[name] = paddle_infer::DataType::INT64;
} else if (dtype == paddle::framework::proto::VarType::INT32) {
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h" // NOLINT
@@ -31,6 +32,7 @@
namespace paddle_infer {
using float16 = paddle::platform::float16;
using bfloat16 = phi::dtype::bfloat16;
void Tensor::Reshape(const std::vector<int> &shape) {
#ifdef PADDLE_WITH_ONNXRUNTIME
@@ -173,6 +175,8 @@ DataType Tensor::type() const {
return DataType::FLOAT32;
} else if (type == paddle::framework::proto::VarType::FP16) {
return DataType::FLOAT16;
} else if (type == paddle::framework::proto::VarType::BF16) {
return DataType::BFLOAT16;
} else if (type == paddle::framework::proto::VarType::INT64) {
return DataType::INT64;
} else if (type == paddle::framework::proto::VarType::INT32) {
@@ -284,6 +288,11 @@ struct DataTypeInfo<float16> {
phi::DataType TYPE = phi::DataType::FLOAT16;
};
template <>
struct DataTypeInfo<bfloat16> {
phi::DataType TYPE = phi::DataType::BFLOAT16;
};
template <>
struct DataTypeInfo<int64_t> {
phi::DataType TYPE = phi::DataType::INT64;
@@ -502,6 +511,7 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<bfloat16>(const bfloat16 *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<bool>(const bool *data);
template PD_INFER_DECL void Tensor::ShareExternalData<double>(
@@ -539,6 +549,11 @@ template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
const std::vector<int> &shape,
PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<bfloat16>(
const bfloat16 *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<bool>(
const bool *data,
const std::vector<int> &shape,
@@ -552,6 +567,7 @@ template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<bfloat16>(bfloat16 *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<bool>(bool *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<double>(
@@ -570,6 +586,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<bfloat16>(
bfloat16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<bool>(bool *data,
void *exec_stream,
CallbackFunc cb,
@@ -589,6 +607,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
float16 *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bfloat16>(
bfloat16 *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(
bool *data, void *exec_stream) const;
@@ -606,6 +626,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
float16 *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bfloat16>(
bfloat16 *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(bool *data,
CallbackFunc cb,
void *cb_params) const;
@@ -624,6 +646,8 @@ template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
int *size) const;
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
int *size) const;
template PD_INFER_DECL bfloat16 *Tensor::data<bfloat16>(PlaceType *place,
int *size) const;
template PD_INFER_DECL bool *Tensor::data<bool>(PlaceType *place,
int *size) const;
@@ -634,6 +658,8 @@ template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
template PD_INFER_DECL bfloat16 *Tensor::mutable_data<bfloat16>(
PlaceType place);
template PD_INFER_DECL bool *Tensor::mutable_data<bool>(PlaceType place);
Tensor::Tensor(void *scope, const void *device_contexts)
@@ -923,6 +949,8 @@ template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<bfloat16>(
paddle_infer::Tensor *t, const bfloat16 *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<bool>(
paddle_infer::Tensor *t, const bool *data, cudaStream_t stream);
@@ -940,6 +968,8 @@ template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<bfloat16>(
paddle_infer::Tensor *t, bfloat16 *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<bool>(
paddle_infer::Tensor *t, bool *data, cudaStream_t stream);
...
@@ -108,9 +108,17 @@ void TensorUtils::CopyTensorImpl(Tensor* p_dst,
cb,
cb_params);
break;
case PaddleDType::BFLOAT16:
src.CopyToCpuImpl(
dst.mutable_data<paddle::platform::bfloat16>(PlaceType::kCPU),
exec_stream,
cb,
cb_params);
break;
default:
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, BFLOAT16, FLOAT32 "
"and "
"FLOAT64 is supported in Tensor. Others not implements"));
}
// gpu => gpu or cpu => gpu
@@ -172,9 +180,17 @@ void TensorUtils::CopyTensorImpl(Tensor* p_dst,
src.data<paddle::platform::float16>(&src_place, &data_size));
data_len = data_size * 2;
break;
case PaddleDType::BFLOAT16:
dst_data = static_cast<void*>(
dst.mutable_data<paddle::platform::bfloat16>(PlaceType::kGPU));
src_data = static_cast<void*>(
src.data<paddle::platform::bfloat16>(&src_place, &data_size));
data_len = data_size * 2;
break;
default:
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, BFLOAT16, FLOAT32 "
"and "
"FLOAT64 is supported in Tensor. Others not implements"));
}
...
@@ -64,6 +64,7 @@ enum DataType {
FLOAT16,
BOOL,
FLOAT64,
BFLOAT16,
// TODO(Inference): support more data types if needed.
};
...
@@ -24,11 +24,7 @@
#endif
#include "paddle/fluid/platform/flags.h"
PHI_DECLARE_bool(cache_inference_while_scope);
namespace paddle {
namespace framework {
...
@@ -271,6 +271,16 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor, // NOLINT
static_cast<phi::dtype::float16 *>(input_tensor.data()),
shape,
ToPaddleInferPlace(input_tensor.place().GetType()));
} else if (input_tensor.dtype() == phi::DataType::BFLOAT16) {
tensor.ShareExternalData(
static_cast<bfloat16 *>(input_tensor.data()),
shape,
ToPaddleInferPlace(input_tensor.place().GetType()));
} else if (input_tensor.dtype() == phi::DataType::BOOL) {
tensor.ShareExternalData(
static_cast<bool *>(input_tensor.data()),
shape,
ToPaddleInferPlace(input_tensor.place().GetType()));
} else if (input_tensor.dtype() == phi::DataType::INT32) {
tensor.ShareExternalData(
static_cast<int32_t *>(input_tensor.data()),
@@ -284,7 +294,7 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor, // NOLINT
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now share_external_data only supports INT32, "
"INT64, FLOAT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
}
}
@@ -311,6 +321,16 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor, // NOLINT
paddle_tensor.data<paddle::platform::float16>()),
shape,
ToPaddleInferPlace(paddle_tensor.place().GetType()));
} else if (paddle_tensor.dtype() == phi::DataType::BFLOAT16) {
tensor.ShareExternalData(
static_cast<bfloat16 *>(paddle_tensor.data<bfloat16>()),
shape,
ToPaddleInferPlace(paddle_tensor.place().GetType()));
} else if (paddle_tensor.dtype() == phi::DataType::BOOL) {
tensor.ShareExternalData(
static_cast<bool *>(paddle_tensor.data<bool>()),
shape,
ToPaddleInferPlace(paddle_tensor.place().GetType()));
} else if (paddle_tensor.dtype() == phi::DataType::INT32) {
tensor.ShareExternalData(
static_cast<int32_t *>(paddle_tensor.data<int32_t>()),
@@ -324,7 +344,7 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor, // NOLINT
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now share_external_data only supports INT32, "
"INT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
}
}
...
@@ -1073,6 +1073,11 @@ PHI_DEFINE_EXPORTED_bool(
gpugraph_enable_hbm_table_collision_stat,
false,
"enable hash collisions stat for hbm table, default false");
PHI_DEFINE_EXPORTED_bool(
cache_inference_while_scope,
false,
"Cache the scope of the while op to avoid repeated creation of the scope "
"for each iteration and improve inference performance.");
PHI_DEFINE_EXPORTED_double(gpugraph_hbm_table_load_factor,
0.75,
"the load factor of hbm table, default 0.75");
...
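The flag relocated above stays usable from Python. Below is a hedged sketch of enabling the cached while-op scope before building a predictor; whether `paddle.set_flags` accepts this particular exported flag at runtime is an assumption, so exporting the environment variable before launch is the conservative route.

```python
import os

# Assumption: like other exported FLAGS_*, this one can be set through the
# environment before the process starts, or via paddle.set_flags at runtime.
os.environ["FLAGS_cache_inference_while_scope"] = "1"

import paddle

paddle.set_flags({"FLAGS_cache_inference_while_scope": True})
# Then build Config / create_predictor as usual: while_op sub-scopes are
# cached across runs instead of being dropped after every Carrier::Start().
```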
@@ -21,6 +21,7 @@ import numpy as np
from paddle import fluid
from paddle.fluid.core import PaddleDType, PaddleTensor
from paddle.framework import core
from paddle.inference import (
    Config,
    create_predictor,
@@ -101,6 +102,36 @@ def get_sample_model():
    return serialized_program, serialized_params


def get_sample_model_cuda(data_type):
    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    main_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(main_program, startup_program):
        data = paddle.static.data(
            name="data", shape=[-1, 6, 64, 64], dtype=data_type
        )
        data_float = paddle.cast(data, "bfloat16")
        res = paddle.static.nn.conv2d(
            input=data_float,
            num_filters=3,
            filter_size=3,
            groups=1,
            padding=0,
            bias_attr=False,
            act=None,
        )
    exe.run(startup_program)
    serialized_program = paddle.static.serialize_program(
        data, res, program=main_program
    )
    serialized_params = paddle.static.serialize_persistables(
        data, res, executor=exe, program=main_program
    )
    return serialized_program, serialized_params
class TestInferenceBaseAPI(unittest.TestCase):
    def get_config(self, model, params):
        config = Config()
@@ -171,5 +202,51 @@ class TestInferenceBaseAPI(unittest.TestCase):
        test_paddle_tensor()


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.get_cudnn_version() < 8100
    or paddle.device.cuda.get_device_capability()[0] < 8,
    "share_external_data_bf16 requires cudnn >= 8.1 and CUDA_ARCH >= 8",
)
class TestInferenceShareExternalDataAPI(unittest.TestCase):
    def get_config(self, model, params):
        config = Config()
        config.set_model_buffer(model, len(model), params, len(params))
        config.enable_use_gpu(100, 0)
        return config

    def test_share_external_data_cuda(self):
        def test_paddle_tensor_bf16():
            paddle.set_default_dtype("bfloat16")
            program, params = get_sample_model_cuda("bfloat16")
            paddle.disable_static()
            config = self.get_config(program, params)
            predictor = create_predictor(config)
            in_names = predictor.get_input_names()
            in_handle = predictor.get_input_handle(in_names[0])
            in_data = paddle.to_tensor(np.ones((1, 6, 32, 32)), "bfloat16")
            in_handle.share_external_data(in_data)
            predictor.run()
            paddle.set_default_dtype("float32")
            paddle.enable_static()

        def test_paddle_tensor_bool():
            paddle.set_default_dtype("bfloat16")
            program, params = get_sample_model_cuda("bool")
            paddle.disable_static()
            config = self.get_config(program, params)
            predictor = create_predictor(config)
            in_names = predictor.get_input_names()
            in_handle = predictor.get_input_handle(in_names[0])
            in_data = paddle.to_tensor(np.ones((1, 6, 32, 32)), "bool")
            in_handle.share_external_data(in_data)
            predictor.run()
            paddle.set_default_dtype("float32")
            paddle.enable_static()

        test_paddle_tensor_bf16()
        test_paddle_tensor_bool()


if __name__ == '__main__':
    unittest.main()