From c65ef07c273e74cba5a54ae1a5a9e41fcbea28f4 Mon Sep 17 00:00:00 2001
From: lzy <569782149@qq.com>
Date: Fri, 18 Aug 2023 14:24:37 +0800
Subject: [PATCH] [Inference] Make share_external_data support bf16 and bool;
 fix while_op cache_inference_while_scope when using fleet_executor. (#56055)

* 1. make share_external_data support bf16 and bool; 2. don't drop kid scopes
  when cache_inference_while_scope is enabled

* fix FLAGS_cache_inference_while_scope

* add unit test

* add unit test

* skip unit test when cudnn_version < 8100

* skip test share_external_data_bf16 when CUDA_ARCH < 80
---
 .../distributed/fleet_executor/carrier.cc     | 20 +++--
 .../fluid/inference/api/analysis_predictor.cc | 14 +++-
 .../inference/api/details/zero_copy_tensor.cc | 30 ++++++++
 .../inference/api/paddle_infer_contrib.cc     | 20 ++++-
 paddle/fluid/inference/api/paddle_tensor.h    |  1 +
 .../fluid/operators/controlflow/while_op.cc   |  6 +-
 paddle/fluid/pybind/inference_api.cc          | 24 +++++-
 paddle/phi/core/flags.cc                      |  5 ++
 test/legacy_test/test_inference_api.py        | 77 +++++++++++++++++++
 9 files changed, 178 insertions(+), 19 deletions(-)

diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc
index 82d99a38352..1dc29493af9 100644
--- a/paddle/fluid/distributed/fleet_executor/carrier.cc
+++ b/paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -35,6 +35,7 @@ PADDLE_DEFINE_EXPORTED_bool(
     "Use standalone executor to run ops. Temporary FLAGS, will be removed "
     "after all fleet executor cases are modified to run ops with standalone "
     "executor.");
+PHI_DECLARE_bool(cache_inference_while_scope);
 
 namespace paddle {
 namespace distributed {
@@ -194,14 +195,17 @@ void Carrier::Start() {
   // TODO(wangxi): async step
   Wait();
   dev_ctx_->Wait();
-  for (auto* micro_scope : microbatch_scopes_) {
-    // By default, we should delete all kid scopes after run executor because
-    // some operators may create local scope when running, such as while_op.
-    // But when while_op also create a local executor to run it's sub block,
-    // the sub scopes it created should not be dropped immediately, because
-    // while_grad_op will use some variables created during while_op run, so
-    // we need to keep the kids and wait for the outer executor to drop them.
-    micro_scope->DropKids();
+  if (!FLAGS_cache_inference_while_scope) {
+    // Don't drop kid scopes when FLAGS_cache_inference_while_scope is on.
+    for (auto* micro_scope : microbatch_scopes_) {
+      // By default, we should delete all kid scopes after run executor because
+      // some operators may create local scope when running, such as while_op.
+      // But when while_op also creates a local executor to run its sub block,
+      // the sub scopes it created should not be dropped immediately, because
+      // while_grad_op will use some variables created during while_op run, so
+      // we need to keep the kids and wait for the outer executor to drop them.
+      micro_scope->DropKids();
+    }
   }
 }
 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 2b0fe1dacbe..e7d5028c14f 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -226,6 +226,8 @@ bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
     input_ptr = t->mutable_data<int32_t>(ddim, place);
   } else if (pt.dtype == PaddleDType::FLOAT16) {
     input_ptr = t->mutable_data<float16>(ddim, place);
+  } else if (pt.dtype == PaddleDType::BFLOAT16) {
+    input_ptr = t->mutable_data<bfloat16>(ddim, place);
   } else {
     LOG(ERROR) << "unsupported feed type " << pt.dtype;
     return false;
@@ -1318,9 +1320,13 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
     } else if (type == framework::proto::VarType::FP16) {
       GetFetchOne<float16>(fetch, output);
       output->dtype = PaddleDType::FLOAT16;
+    } else if (type == framework::proto::VarType::BF16) {
+      GetFetchOne<bfloat16>(fetch, output);
+      output->dtype = PaddleDType::BFLOAT16;
     } else {
-      LOG(ERROR) << "unknown type, only support float32, float16, int64 and "
-                    "int32 now.";
+      LOG(ERROR)
+          << "unknown type, only support float32, float16, bfloat16, int64 and "
+             "int32 now.";
     }
   }
   return true;
@@ -1881,6 +1887,8 @@ AnalysisPredictor::GetInputTypes() {
       input_type[name] = paddle_infer::DataType::FLOAT32;
     } else if (dtype == paddle::framework::proto::VarType::FP16) {
       input_type[name] = paddle_infer::DataType::FLOAT16;
+    } else if (dtype == paddle::framework::proto::VarType::BF16) {
+      input_type[name] = paddle_infer::DataType::BFLOAT16;
     } else if (dtype == paddle::framework::proto::VarType::INT64) {
       input_type[name] = paddle_infer::DataType::INT64;
     } else if (dtype == paddle::framework::proto::VarType::INT32) {
@@ -1938,6 +1946,8 @@ AnalysisPredictor::GetOutputTypes() {
       output_type[name] = paddle_infer::DataType::FLOAT32;
     } else if (dtype == paddle::framework::proto::VarType::FP16) {
       output_type[name] = paddle_infer::DataType::FLOAT16;
+    } else if (dtype == paddle::framework::proto::VarType::BF16) {
+      output_type[name] = paddle_infer::DataType::BFLOAT16;
     } else if (dtype == paddle::framework::proto::VarType::INT64) {
       output_type[name] = paddle_infer::DataType::INT64;
     } else if (dtype == paddle::framework::proto::VarType::INT32) {
diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
index f8ed9dbc243..193e244f86e 100644
--- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc
+++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -22,6 +22,7 @@
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/allocator.h"
 #ifdef PADDLE_WITH_ONNXRUNTIME
 #include "onnxruntime_c_api.h"  // NOLINT
@@ -31,6 +32,7 @@ namespace paddle_infer {
 
 using float16 = paddle::platform::float16;
+using bfloat16 = phi::dtype::bfloat16;
 
 void Tensor::Reshape(const std::vector<int> &shape) {
 #ifdef PADDLE_WITH_ONNXRUNTIME
@@ -173,6 +175,8 @@ DataType Tensor::type() const {
     return DataType::FLOAT32;
   } else if (type == paddle::framework::proto::VarType::FP16) {
     return DataType::FLOAT16;
+  } else if (type == paddle::framework::proto::VarType::BF16) {
+    return DataType::BFLOAT16;
   } else if (type == paddle::framework::proto::VarType::INT64) {
     return DataType::INT64;
   } else if (type == paddle::framework::proto::VarType::INT32) {
@@ -284,6 +288,11 @@ struct DataTypeInfo<float16> {
   phi::DataType TYPE = phi::DataType::FLOAT16;
 };
+template <>
+struct DataTypeInfo<bfloat16> {
+  phi::DataType TYPE = phi::DataType::BFLOAT16;
+};
+
 template <>
 struct DataTypeInfo<int64_t> {
   phi::DataType TYPE = phi::DataType::INT64;
 };
@@ -502,6 +511,7 @@ template PD_INFER_DECL void Tensor::CopyFromCpu(const int32_t *data);
 template PD_INFER_DECL void Tensor::CopyFromCpu(const uint8_t *data);
 template PD_INFER_DECL void Tensor::CopyFromCpu(const int8_t *data);
 template PD_INFER_DECL void Tensor::CopyFromCpu(const float16 *data);
+template PD_INFER_DECL void Tensor::CopyFromCpu(const bfloat16 *data);
 template PD_INFER_DECL void Tensor::CopyFromCpu(const bool *data);
 
 template PD_INFER_DECL void Tensor::ShareExternalData(
@@ -539,6 +549,11 @@ template PD_INFER_DECL void Tensor::ShareExternalData(
     const std::vector<int> &shape,
     PlaceType place,
     DataLayout layout);
+template PD_INFER_DECL void Tensor::ShareExternalData(
+    const bfloat16 *data,
+    const std::vector<int> &shape,
+    PlaceType place,
+    DataLayout layout);
 template PD_INFER_DECL void Tensor::ShareExternalData(
     const bool *data,
     const std::vector<int> &shape,
@@ -552,6 +567,7 @@ template PD_INFER_DECL void Tensor::CopyToCpu(int32_t *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu(uint8_t *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu(int8_t *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu(float16 *data) const;
+template PD_INFER_DECL void Tensor::CopyToCpu(bfloat16 *data) const;
 template PD_INFER_DECL void Tensor::CopyToCpu(bool *data) const;
 
 template PD_INFER_DECL void Tensor::CopyToCpuImpl(
@@ -570,6 +586,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuImpl(
     int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
 template PD_INFER_DECL void Tensor::CopyToCpuImpl(
     float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
+template PD_INFER_DECL void Tensor::CopyToCpuImpl(
+    bfloat16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
 template PD_INFER_DECL void Tensor::CopyToCpuImpl(bool *data,
                                                   void *exec_stream,
                                                   CallbackFunc cb,
@@ -589,6 +607,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync(
     int8_t *data, void *exec_stream) const;
 template PD_INFER_DECL void Tensor::CopyToCpuAsync(
     float16 *data, void *exec_stream) const;
+template PD_INFER_DECL void Tensor::CopyToCpuAsync(
+    bfloat16 *data, void *exec_stream) const;
 template PD_INFER_DECL void Tensor::CopyToCpuAsync(
     bool *data, void *exec_stream) const;
 
@@ -606,6 +626,8 @@ template PD_INFER_DECL void Tensor::CopyToCpuAsync(
     int8_t *data, CallbackFunc cb, void *cb_params) const;
 template PD_INFER_DECL void Tensor::CopyToCpuAsync(
     float16 *data, CallbackFunc cb, void *cb_params) const;
+template PD_INFER_DECL void Tensor::CopyToCpuAsync(
+    bfloat16 *data, CallbackFunc cb, void *cb_params) const;
 template PD_INFER_DECL void Tensor::CopyToCpuAsync(bool *data,
                                                    CallbackFunc cb,
                                                    void *cb_params) const;
@@ -624,6 +646,8 @@ template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                     int *size) const;
 template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                       int *size) const;
+template PD_INFER_DECL bfloat16 *Tensor::data<bfloat16>(PlaceType *place,
+                                                        int *size) const;
 template PD_INFER_DECL bool *Tensor::data<bool>(PlaceType *place,
                                                 int *size) const;
 
@@ -634,6 +658,8 @@ template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
 template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
 template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
 template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
+template PD_INFER_DECL bfloat16 *Tensor::mutable_data<bfloat16>(
+    PlaceType place);
 template PD_INFER_DECL bool *Tensor::mutable_data<bool>(PlaceType place);
 
 Tensor::Tensor(void *scope, const void *device_contexts)
@@ -923,6 +949,8 @@ template void InternalUtils::CopyFromCpuWithIoStream(
     paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
 template void InternalUtils::CopyFromCpuWithIoStream(
     paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);
+template void InternalUtils::CopyFromCpuWithIoStream(
+    paddle_infer::Tensor *t, const bfloat16 *data, cudaStream_t stream);
 template void InternalUtils::CopyFromCpuWithIoStream(
     paddle_infer::Tensor *t, const bool *data, cudaStream_t stream);
 
@@ -940,6 +968,8 @@ template void InternalUtils::CopyToCpuWithIoStream(
     paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
 template void InternalUtils::CopyToCpuWithIoStream(
     paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);
+template void InternalUtils::CopyToCpuWithIoStream(
+    paddle_infer::Tensor *t, bfloat16 *data, cudaStream_t stream);
 template void InternalUtils::CopyToCpuWithIoStream(
     paddle_infer::Tensor *t, bool *data, cudaStream_t stream);
 
diff --git a/paddle/fluid/inference/api/paddle_infer_contrib.cc b/paddle/fluid/inference/api/paddle_infer_contrib.cc
index 11786b05c30..ae9b39408c2 100644
--- a/paddle/fluid/inference/api/paddle_infer_contrib.cc
+++ b/paddle/fluid/inference/api/paddle_infer_contrib.cc
@@ -108,9 +108,17 @@ void TensorUtils::CopyTensorImpl(Tensor* p_dst,
             cb,
             cb_params);
         break;
+      case PaddleDType::BFLOAT16:
+        src.CopyToCpuImpl(
+            dst.mutable_data<paddle::platform::bfloat16>(PlaceType::kCPU),
+            exec_stream,
+            cb,
+            cb_params);
+        break;
       default:
         PADDLE_THROW(paddle::platform::errors::Unimplemented(
-            "Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, FLOAT32 and "
+            "Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, BFLOAT16, FLOAT32 "
+            "and "
             "FLOAT64 is supported in Tensor. Others not implements"));
     }
     // gpu => gpu or cpu => gpu
@@ -172,9 +180,17 @@ void TensorUtils::CopyTensorImpl(Tensor* p_dst,
             src.data<paddle::platform::float16>(&src_place, &data_size));
         data_len = data_size * 2;
         break;
+      case PaddleDType::BFLOAT16:
+        dst_data = static_cast<void*>(
+            dst.mutable_data<paddle::platform::bfloat16>(PlaceType::kGPU));
+        src_data = static_cast<const void*>(
+            src.data<paddle::platform::bfloat16>(&src_place, &data_size));
+        data_len = data_size * 2;
+        break;
       default:
         PADDLE_THROW(paddle::platform::errors::Unimplemented(
-            "Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, FLOAT32 and "
+            "Only INT32, INT64, UINT8, INT8, BOOL, FLOAT16, BFLOAT16, FLOAT32 "
+            "and "
             "FLOAT64 is supported in Tensor. Others not implements"));
     }
 
diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h
index 3f62540d3e3..9bbb494f91e 100644
--- a/paddle/fluid/inference/api/paddle_tensor.h
+++ b/paddle/fluid/inference/api/paddle_tensor.h
@@ -64,6 +64,7 @@ enum DataType {
   FLOAT16,
   BOOL,
   FLOAT64,
+  BFLOAT16,
   // TODO(Inference): support more data types if needed.
 };
 
diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc
index 87f8e58e384..8cf7dc24c1d 100644
--- a/paddle/fluid/operators/controlflow/while_op.cc
+++ b/paddle/fluid/operators/controlflow/while_op.cc
@@ -24,11 +24,7 @@
 #endif
 #include "paddle/fluid/platform/flags.h"
 
-PADDLE_DEFINE_EXPORTED_bool(
-    cache_inference_while_scope,
-    false,
-    "Cache the scope of the while op to avoid repeated creation of the scope "
-    "for each iteration and improve inference performance.");
+PHI_DECLARE_bool(cache_inference_while_scope);
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 540b2dfa6be..6af01ded06d 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -271,6 +271,16 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
         static_cast<paddle::platform::float16 *>(input_tensor.data()),
         shape,
         ToPaddleInferPlace(input_tensor.place().GetType()));
+  } else if (input_tensor.dtype() == phi::DataType::BFLOAT16) {
+    tensor.ShareExternalData(
+        static_cast<phi::dtype::bfloat16 *>(input_tensor.data()),
+        shape,
+        ToPaddleInferPlace(input_tensor.place().GetType()));
+  } else if (input_tensor.dtype() == phi::DataType::BOOL) {
+    tensor.ShareExternalData(
+        static_cast<bool *>(input_tensor.data()),
+        shape,
+        ToPaddleInferPlace(input_tensor.place().GetType()));
   } else if (input_tensor.dtype() == phi::DataType::INT32) {
     tensor.ShareExternalData(
         static_cast<int32_t *>(input_tensor.data()),
@@ -284,7 +294,7 @@ void PaddleInferShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Unsupported data type. Now share_external_data only supports INT32, "
-        "INT64, FLOAT64, FLOAT32 and FLOAT16."));
+        "INT64, FLOAT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
   }
 }
 
@@ -311,6 +321,16 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
             paddle_tensor.data<paddle::platform::float16>()),
         shape,
         ToPaddleInferPlace(paddle_tensor.place().GetType()));
+  } else if (paddle_tensor.dtype() == phi::DataType::BFLOAT16) {
+    tensor.ShareExternalData(
+        static_cast<phi::dtype::bfloat16 *>(paddle_tensor.data()),
+        shape,
+        ToPaddleInferPlace(paddle_tensor.place().GetType()));
+  } else if (paddle_tensor.dtype() == phi::DataType::BOOL) {
+    tensor.ShareExternalData(
+        static_cast<bool *>(paddle_tensor.data<bool>()),
+        shape,
+        ToPaddleInferPlace(paddle_tensor.place().GetType()));
   } else if (paddle_tensor.dtype() == phi::DataType::INT32) {
     tensor.ShareExternalData(
         static_cast<int32_t *>(paddle_tensor.data<int32_t>()),
@@ -324,7 +344,7 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Unsupported data type. Now share_external_data only supports INT32, "
-        "INT64, FLOAT32 and FLOAT16."));
+        "INT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
   }
 }
 
diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc
index 3cb37db9af3..2470981f661 100644
--- a/paddle/phi/core/flags.cc
+++ b/paddle/phi/core/flags.cc
@@ -1073,6 +1073,11 @@ PHI_DEFINE_EXPORTED_bool(
     gpugraph_enable_hbm_table_collision_stat,
     false,
     "enable hash collisions stat for hbm table, default false");
+PHI_DEFINE_EXPORTED_bool(
+    cache_inference_while_scope,
+    false,
+    "Cache the scope of the while op to avoid repeated creation of the scope "
+    "for each iteration and improve inference performance.");
 PHI_DEFINE_EXPORTED_double(gpugraph_hbm_table_load_factor,
                            0.75,
                            "the load factor of hbm table, default 0.75");
diff --git a/test/legacy_test/test_inference_api.py b/test/legacy_test/test_inference_api.py
index fb1a04142c6..1fb190618e7 100644
--- a/test/legacy_test/test_inference_api.py
+++ b/test/legacy_test/test_inference_api.py
@@ -21,6 +21,7 @@ import numpy as np
 from paddle import fluid
 from paddle.fluid.core import PaddleDType, PaddleTensor
+from paddle.framework import core
 from paddle.inference import (
     Config,
     create_predictor,
@@ -101,6 +102,36 @@ def get_sample_model():
     return serialized_program, serialized_params
 
 
+def get_sample_model_cuda(data_type):
+    place = fluid.CUDAPlace(0)
+    exe = fluid.Executor(place)
+
+    main_program = fluid.Program()
+    startup_program = fluid.Program()
+    with fluid.program_guard(main_program, startup_program):
+        data = paddle.static.data(
+            name="data", shape=[-1, 6, 64, 64], dtype=data_type
+        )
+        data_float = paddle.cast(data, "bfloat16")
+        res = paddle.static.nn.conv2d(
+            input=data_float,
+            num_filters=3,
+            filter_size=3,
+            groups=1,
+            padding=0,
+            bias_attr=False,
+            act=None,
+        )
+    exe.run(startup_program)
+    serialized_program = paddle.static.serialize_program(
+        data, res, program=main_program
+    )
+    serialized_params = paddle.static.serialize_persistables(
+        data, res, executor=exe, program=main_program
+    )
+    return serialized_program, serialized_params
+
+
 class TestInferenceBaseAPI(unittest.TestCase):
     def get_config(self, model, params):
         config = Config()
@@ -171,5 +202,51 @@ class TestInferenceBaseAPI(unittest.TestCase):
         test_paddle_tensor()
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or paddle.get_cudnn_version() < 8100
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "share_external_data_bf16 requires cudnn >= 8.1 and CUDA_ARCH >= 8",
+)
+class TestInferenceShareExternalDataAPI(unittest.TestCase):
+    def get_config(self, model, params):
+        config = Config()
+        config.set_model_buffer(model, len(model), params, len(params))
+        config.enable_use_gpu(100, 0)
+        return config
+
+    def test_share_external_data_cuda(self):
+        def test_paddle_tensor_bf16():
+            paddle.set_default_dtype("bfloat16")
+            program, params = get_sample_model_cuda("bfloat16")
+            paddle.disable_static()
+            config = self.get_config(program, params)
+            predictor = create_predictor(config)
+            in_names = predictor.get_input_names()
+            in_handle = predictor.get_input_handle(in_names[0])
+            in_data = paddle.to_tensor(np.ones((1, 6, 32, 32)), "bfloat16")
+            in_handle.share_external_data(in_data)
+            predictor.run()
+            paddle.set_default_dtype("float32")
+            paddle.enable_static()
+
+        def test_paddle_tensor_bool():
+            paddle.set_default_dtype("bfloat16")
+            program, params = get_sample_model_cuda("bool")
+            paddle.disable_static()
+            config = self.get_config(program, params)
+            predictor = create_predictor(config)
+            in_names = predictor.get_input_names()
+            in_handle = predictor.get_input_handle(in_names[0])
+            in_data = paddle.to_tensor(np.ones((1, 6, 32, 32)), "bool")
+            in_handle.share_external_data(in_data)
+            predictor.run()
+            paddle.set_default_dtype("float32")
+            paddle.enable_static()
+
+        test_paddle_tensor_bf16()
+        test_paddle_tensor_bool()
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
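
Usage sketch: the bf16/bool path this patch enables from Python, mirroring the
new unit test above. It assumes a CUDA build of Paddle with cuDNN >= 8.1 and a
GPU of compute capability >= 8.0 (the same conditions the test skips on); the
"model.pdmodel" / "model.pdiparams" file names are hypothetical placeholders.

    import numpy as np
    import paddle
    from paddle.inference import Config, create_predictor

    paddle.disable_static()

    # Hypothetical model files; any GPU inference model is used the same way.
    config = Config("model.pdmodel", "model.pdiparams")
    config.enable_use_gpu(100, 0)
    predictor = create_predictor(config)

    in_name = predictor.get_input_names()[0]
    in_handle = predictor.get_input_handle(in_name)

    # A bfloat16 (or bool) paddle.Tensor can now be handed to the predictor
    # zero-copy instead of being converted and copied in via copy_from_cpu().
    x = paddle.to_tensor(np.ones((1, 6, 32, 32)), "bfloat16")
    in_handle.share_external_data(x)
    predictor.run()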
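
On the flag side: the definition of cache_inference_while_scope moves from
while_op.cc into paddle/phi/core/flags.cc so that both while_op and the fleet
executor Carrier can reference it, and Carrier::Start() now skips DropKids()
when the flag is on, so the cached while-op scope survives across runs. A
hedged sketch of enabling it follows; exported FLAGS_* options are normally
picked up from the environment at startup, and the paddle.set_flags variant in
the comment is an assumption that may not hold for every exported flag.

    import os

    # Set before importing paddle so the exported flag is read at startup.
    os.environ["FLAGS_cache_inference_while_scope"] = "1"

    import paddle  # noqa: E402

    # Assumption: the flag may also be toggled at runtime via
    # paddle.set_flags({"FLAGS_cache_inference_while_scope": True})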