diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 40ca3e85868fbbba19d81336aed1a8cbee58ec54..cd0fc03852a4d4581b037e940b5a687f229658c6 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -178,6 +178,10 @@ struct Argument {
   // Scales for variables to be quantized
   DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
+
+  // A set of op types to enable their bfloat16 kernels
+  DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types, Bfloat16EnabledOpTypes,
+                      std::unordered_set<std::string>);
 #endif
 
   // Passed from config.
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 4abe293c930e23d4896adb3af25ad0532d95c12c..07f3831110342256867f7597b8eed04b918c431e 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -125,6 +125,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(use_mkldnn_);
   CP_MEMBER(mkldnn_enabled_op_types_);
   CP_MEMBER(mkldnn_cache_capacity_);
+  // Bfloat16 related.
+  CP_MEMBER(use_mkldnn_bfloat16_);
+  CP_MEMBER(bfloat16_enabled_op_types_);
   // Quantization related.
   CP_MEMBER(use_mkldnn_quantizer_);
   CP_MEMBER(mkldnn_quantizer_config_);
@@ -417,6 +420,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << use_mkldnn_quantizer_;
 
   ss << use_mkldnn_bfloat16_;
+  for (auto &item : bfloat16_enabled_op_types_) ss << item;
+  ss << ";";
 
   ss << model_from_memory_;
   ss << with_profile_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 98bee2d4bb471a6d8c4c7bf6b07159582dd69280..5dae7368a8e7dff74bede444dc6e147b6ceacddc 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -501,6 +501,10 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetQuantizeExcludedOpIds(
         config_.mkldnn_quantizer_config()->excluded_op_ids());
   }
+  if (config_.use_mkldnn_bfloat16_) {
+    LOG(INFO) << "Bfloat16 is enabled";
+    argument_.SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
+  }
 #endif
 
   auto passes = config_.pass_builder()->AllPasses();
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index b1244e4e3dfdd5e6a627054250e6def2a7c35a89..7ad3aaf1f9d08984e7fe3b91320c6d1e7f28a6ef 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -414,6 +414,14 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   bool mkldnn_bfloat16_enabled() const { return use_mkldnn_bfloat16_; }
 
+  /// \brief Specify the operator type list to use Bfloat16 acceleration.
+  ///
+  /// \param op_list The operator type list.
+  ///
+  void SetBfloat16Op(std::unordered_set<std::string> op_list) {
+    bfloat16_enabled_op_types_ = op_list;
+  }
+
   ///
   /// \brief A boolean state telling whether the thread local CUDA stream is
   /// enabled.
@@ -606,6 +614,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool use_mkldnn_quantizer_{false};
   std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_;
   bool use_mkldnn_bfloat16_{false};
+  std::unordered_set<std::string> bfloat16_enabled_op_types_;
 
   // If the config is already used on a predictor, it becomes invalid.
   // Any config can only be used with one predictor.
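For orientation, the new `SetBfloat16Op` setter complements the existing `EnableMkldnnBfloat16()` switch: bfloat16 has to be enabled first, and the op-type set then limits which kernels are placed in bfloat16. A minimal C++ usage sketch follows; the model path and the op list are illustrative placeholders, not part of this patch.

```cpp
// Minimal sketch: enable oneDNN bfloat16 inference and restrict it to a
// chosen set of operator types. Model path and op list are placeholders.
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

void ConfigureBf16(paddle::AnalysisConfig *cfg) {
  cfg->SetModel("/path/to/model");  // hypothetical model directory
  cfg->DisableGpu();
  cfg->EnableMKLDNN();
  cfg->EnableMkldnnBfloat16();     // existing bfloat16 switch
  cfg->SetBfloat16Op({"conv2d"});  // setter added in this patch
}
```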
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 5d6970fc4e3856a1945dfcc407b2d16b5032d3df..1907bb93ccbfbcdb127e8b28de26fb499ab170b4 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -21,6 +21,12 @@ function(download_int8_data install_dir data_file)
     endif()
 endfunction()
 
+function(download_bfloat16_data install_dir data_file)
+    if (NOT EXISTS ${install_dir}/${data_file})
+      inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/bfloat16 ${data_file})
+    endif()
+endfunction()
+
 function(download_GRU_data install_dir data_file)
     if (NOT EXISTS ${install_dir}/${data_file})
       inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/gru ${data_file})
@@ -69,6 +75,16 @@ function(inference_analysis_api_int8_test_run_custom_warmup_batch_size TARGET_NA
   inference_analysis_api_int8_test_run(${TARGET_NAME} ${test_binary} ${model_dir} ${data_path})
 endfunction()
 
+function(inference_analysis_api_bfloat16_test_run TARGET_NAME test_binary model_dir data_path)
+  inference_analysis_test_run(${TARGET_NAME}
+          COMMAND ${test_binary}
+          ARGS --infer_model=${model_dir}/model
+               --infer_data=${data_path}
+               --batch_size=50
+               --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+               --iterations=2)
+endfunction()
+
 function(inference_analysis_api_object_dection_int8_test_run TARGET_NAME test_binary model_dir data_path)
   inference_analysis_test_run(${TARGET_NAME}
           COMMAND ${test_binary}
@@ -346,6 +362,16 @@ if(WITH_MKLDNN)
   download_int8_data(${INT8_GOOGLENET_MODEL_DIR} "GoogleNet_int8_model.tar.gz" )
   inference_analysis_api_int8_test_run_custom_warmup_batch_size(test_analyzer_int8_googlenet ${INT8_IMG_CLASS_TEST_APP} ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} 10)
 
+  ### BFLOAT16 tests
+
+  # build test binary to be used in subsequent tests
+  set(BF16_IMG_CLASS_TEST_APP "test_analyzer_bfloat16_image_classification")
+  set(BF16_IMG_CLASS_TEST_APP_SRC "analyzer_bfloat16_image_classification_tester.cc")
+  inference_analysis_api_test_build(${BF16_IMG_CLASS_TEST_APP} ${BF16_IMG_CLASS_TEST_APP_SRC})
+
+  # resnet50 bfloat16
+  inference_analysis_api_bfloat16_test_run(test_analyzer_bfloat16_resnet50 ${BF16_IMG_CLASS_TEST_APP} ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH})
+
   ### Object detection models
   set(PASCALVOC_DATA_PATH "${INT8_DATA_DIR}/pascalvoc_val_head_300.bin")
   set(INT8_OBJ_DETECT_TEST_APP "test_analyzer_int8_object_detection")
diff --git a/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3621477148fffd343a67047247be846bb6ee652a
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/analyzer_bfloat16_image_classification_tester.cc
@@ -0,0 +1,49 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <fstream>
+#include <iostream>
+
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void SetConfig(AnalysisConfig *cfg) {
+  cfg->SetModel(FLAGS_infer_model);
+  cfg->DisableGpu();
+  cfg->SwitchIrOptim();
+  cfg->SwitchSpecifyInputNames();
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_num_threads);
+  cfg->EnableMKLDNN();
+}
+
+TEST(Analyzer_int8_image_classification, bfloat16) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  AnalysisConfig q_cfg;
+  SetConfig(&q_cfg);
+
+  // read data from file and prepare batches with test data
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInputs(&input_slots_all);
+  q_cfg.SwitchIrDebug();
+  q_cfg.EnableMkldnnBfloat16();
+  q_cfg.SetBfloat16Op({"conv2d"});
+  CompareBFloat16AndAnalysis(&cfg, &q_cfg, input_slots_all);
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
index 5f2c879fe0a0c755d192a6be34ac6a1173412b06..6bfa8a821ae8cf6ef4b1fa33d8ae790700795e2b 100644
--- a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
@@ -30,123 +30,6 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->EnableMKLDNN();
 }
 
-template <typename T>
-class TensorReader {
- public:
-  TensorReader(std::ifstream &file, size_t beginning_offset,
-               std::vector<int> shape, std::string name)
-      : file_(file), position_(beginning_offset), shape_(shape), name_(name) {
-    numel_ = std::accumulate(shape_.begin(), shape_.end(), size_t{1},
-                             std::multiplies<size_t>());
-  }
-
-  PaddleTensor NextBatch() {
-    PaddleTensor tensor;
-    tensor.name = name_;
-    tensor.shape = shape_;
-    tensor.dtype = GetPaddleDType<T>();
-    tensor.data.Resize(numel_ * sizeof(T));
-
-    file_.seekg(position_);
-    file_.read(static_cast<char *>(tensor.data.data()), numel_ * sizeof(T));
-    position_ = file_.tellg();
-
-    if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
-    if (file_.fail())
-      throw std::runtime_error(name_ + ": failed reading file.");
-
-    return tensor;
-  }
-
- protected:
-  std::ifstream &file_;
-  size_t position_;
-  std::vector<int> shape_;
-  std::string name_;
-  size_t numel_;
-};
-
-std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
-    const std::vector<std::vector<PaddleTensor>> &test_data,
-    int num_images = FLAGS_warmup_batch_size) {
-  int test_data_batch_size = test_data[0][0].shape[0];
-  auto iterations = test_data.size();
-  auto all_test_data_size = iterations * test_data_batch_size;
-  PADDLE_ENFORCE_LE(static_cast<size_t>(num_images), all_test_data_size,
-                    platform::errors::InvalidArgument(
-                        "The requested quantization warmup data size must be "
-                        "lower or equal to the test data size. But received"
-                        "warmup size is %d and test data size is %d. Please "
Please " - "use --warmup_batch_size parameter to set smaller " - "warmup batch size.", - num_images, all_test_data_size)); - - PaddleTensor images; - images.name = "image"; - images.shape = {num_images, 3, 224, 224}; - images.dtype = PaddleDType::FLOAT32; - images.data.Resize(sizeof(float) * num_images * 3 * 224 * 224); - - PaddleTensor labels; - labels.name = "label"; - labels.shape = {num_images, 1}; - labels.dtype = PaddleDType::INT64; - labels.data.Resize(sizeof(int64_t) * num_images); - - for (int i = 0; i < num_images; i++) { - auto batch = i / test_data_batch_size; - auto element_in_batch = i % test_data_batch_size; - std::copy_n(static_cast(test_data[batch][0].data.data()) + - element_in_batch * 3 * 224 * 224, - 3 * 224 * 224, - static_cast(images.data.data()) + i * 3 * 224 * 224); - - std::copy_n(static_cast(test_data[batch][1].data.data()) + - element_in_batch, - 1, static_cast(labels.data.data()) + i); - } - - auto warmup_data = std::make_shared>(2); - (*warmup_data)[0] = std::move(images); - (*warmup_data)[1] = std::move(labels); - return warmup_data; -} - -void SetInput(std::vector> *inputs, - int32_t batch_size = FLAGS_batch_size) { - std::ifstream file(FLAGS_infer_data, std::ios::binary); - if (!file) { - FAIL() << "Couldn't open file: " << FLAGS_infer_data; - } - - int64_t total_images{0}; - file.read(reinterpret_cast(&total_images), sizeof(total_images)); - LOG(INFO) << "Total images in file: " << total_images; - - std::vector image_batch_shape{batch_size, 3, 224, 224}; - std::vector label_batch_shape{batch_size, 1}; - auto images_offset_in_file = static_cast(file.tellg()); - auto labels_offset_in_file = - images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224; - - TensorReader image_reader(file, images_offset_in_file, - image_batch_shape, "image"); - TensorReader label_reader(file, labels_offset_in_file, - label_batch_shape, "label"); - - auto iterations_max = total_images / batch_size; - auto iterations = iterations_max; - if (FLAGS_iterations > 0 && FLAGS_iterations < iterations_max) { - iterations = FLAGS_iterations; - } - for (auto i = 0; i < iterations; i++) { - auto images = image_reader.NextBatch(); - auto labels = label_reader.NextBatch(); - inputs->emplace_back( - std::vector{std::move(images), std::move(labels)}); - } -} - TEST(Analyzer_int8_image_classification, quantization) { AnalysisConfig cfg; SetConfig(&cfg); @@ -156,13 +39,13 @@ TEST(Analyzer_int8_image_classification, quantization) { // read data from file and prepare batches with test data std::vector> input_slots_all; - SetInput(&input_slots_all); + SetInputs(&input_slots_all); if (FLAGS_enable_int8) { // prepare warmup batch from input data read earlier // warmup batch size can be different than batch size std::shared_ptr> warmup_data = - GetWarmupData(input_slots_all); + paddle::inference::GetWarmupData(input_slots_all); // configure quantizer q_cfg.EnableMkldnnQuantizer(); diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 252bca2d5522e18960feaf6b9aba3d2a7f2a089a..db22ba59073bc5025355331f700c0acf0f3918fd 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -17,10 +17,12 @@ #include #include +#include #include #include #include // NOLINT #include +#include #include #ifdef WITH_GPERFTOOLS #include @@ -48,6 +50,7 @@ DEFINE_bool(ernie_large, false, "Test ernie large"); DEFINE_bool(with_accuracy_layer, true, "Calculate the accuracy while label is in the 
input"); DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction"); +DEFINE_bool(enable_bf16, true, "Enable BF16 type prediction"); DEFINE_bool(enable_int8, true, "Enable INT8 type prediction"); DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup"); // setting iterations to 0 means processing the whole dataset @@ -124,6 +127,123 @@ class Barrier { std::size_t _count; }; +template +class TensorReader { + public: + TensorReader(std::ifstream &file, size_t beginning_offset, + std::vector shape, std::string name) + : file_(file), position_(beginning_offset), shape_(shape), name_(name) { + numel_ = std::accumulate(shape_.begin(), shape_.end(), size_t{1}, + std::multiplies()); + } + + PaddleTensor NextBatch() { + PaddleTensor tensor; + tensor.name = name_; + tensor.shape = shape_; + tensor.dtype = GetPaddleDType(); + tensor.data.Resize(numel_ * sizeof(T)); + + file_.seekg(position_); + file_.read(static_cast(tensor.data.data()), numel_ * sizeof(T)); + position_ = file_.tellg(); + + if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream"; + if (file_.fail()) + throw std::runtime_error(name_ + ": failed reading file."); + + return tensor; + } + + protected: + std::ifstream &file_; + size_t position_; + std::vector shape_; + std::string name_; + size_t numel_; +}; + +std::shared_ptr> GetWarmupData( + const std::vector> &test_data, + int num_images = FLAGS_warmup_batch_size) { + int test_data_batch_size = test_data[0][0].shape[0]; + auto iterations = test_data.size(); + auto all_test_data_size = iterations * test_data_batch_size; + PADDLE_ENFORCE_LE(static_cast(num_images), all_test_data_size, + platform::errors::InvalidArgument( + "The requested quantization warmup data size must be " + "lower or equal to the test data size. But received" + "warmup size is %d and test data size is %d. 
Please " + "use --warmup_batch_size parameter to set smaller " + "warmup batch size.", + num_images, all_test_data_size)); + + PaddleTensor images; + images.name = "image"; + images.shape = {num_images, 3, 224, 224}; + images.dtype = PaddleDType::FLOAT32; + images.data.Resize(sizeof(float) * num_images * 3 * 224 * 224); + + PaddleTensor labels; + labels.name = "label"; + labels.shape = {num_images, 1}; + labels.dtype = PaddleDType::INT64; + labels.data.Resize(sizeof(int64_t) * num_images); + + for (int i = 0; i < num_images; i++) { + auto batch = i / test_data_batch_size; + auto element_in_batch = i % test_data_batch_size; + std::copy_n(static_cast(test_data[batch][0].data.data()) + + element_in_batch * 3 * 224 * 224, + 3 * 224 * 224, + static_cast(images.data.data()) + i * 3 * 224 * 224); + + std::copy_n(static_cast(test_data[batch][1].data.data()) + + element_in_batch, + 1, static_cast(labels.data.data()) + i); + } + + auto warmup_data = std::make_shared>(2); + (*warmup_data)[0] = std::move(images); + (*warmup_data)[1] = std::move(labels); + return warmup_data; +} + +void SetInputs(std::vector> *inputs, + int32_t batch_size = FLAGS_batch_size) { + std::ifstream file(FLAGS_infer_data, std::ios::binary); + if (!file) { + FAIL() << "Couldn't open file: " << FLAGS_infer_data; + } + + int64_t total_images{0}; + file.read(reinterpret_cast(&total_images), sizeof(total_images)); + LOG(INFO) << "Total images in file: " << total_images; + + std::vector image_batch_shape{batch_size, 3, 224, 224}; + std::vector label_batch_shape{batch_size, 1}; + auto images_offset_in_file = static_cast(file.tellg()); + auto labels_offset_in_file = + images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224; + + TensorReader image_reader(file, images_offset_in_file, + image_batch_shape, "image"); + TensorReader label_reader(file, labels_offset_in_file, + label_batch_shape, "label"); + + auto iterations_max = total_images / batch_size; + auto iterations = iterations_max; + if (FLAGS_iterations > 0 && FLAGS_iterations < iterations_max) { + iterations = FLAGS_iterations; + } + for (auto i = 0; i < iterations; i++) { + auto images = image_reader.NextBatch(); + auto labels = label_reader.NextBatch(); + inputs->emplace_back( + std::vector{std::move(images), std::move(labels)}); + } +} + // Compare result between two PaddleTensor void CompareResult(const std::vector &outputs, const std::vector &ref_outputs) { @@ -555,10 +675,10 @@ void SummarizePerformance(const char *title, float sample) { << " ms"; } -void SummarizePerformance(float sample_latency_fp32, - float sample_latency_int8) { - if (FLAGS_enable_fp32) SummarizePerformance("FP32", sample_latency_fp32); - if (FLAGS_enable_int8) SummarizePerformance("INT8", sample_latency_int8); +void SummarizePerformance(const char *title_fp32, float sample_latency_fp32, + const char *title, float sample_latency) { + SummarizePerformance(title_fp32, sample_latency_fp32); + SummarizePerformance(title, sample_latency); } float CompareAccuracyOne( @@ -708,11 +828,51 @@ void CompareQuantizedAndAnalysis( TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true, VarType::INT8, &sample_latency_int8); } - SummarizePerformance(sample_latency_fp32, sample_latency_int8); + SummarizePerformance("FP32", sample_latency_fp32, "INT8", + sample_latency_int8); CompareAccuracy(quantized_outputs, analysis_outputs, compared_idx); } +void CompareBFloat16AndAnalysis( + const AnalysisConfig *config, const AnalysisConfig *qconfig, + const std::vector> &inputs, + const int compared_idx = 1) 
+  PADDLE_ENFORCE_EQ(
+      inputs[0][0].shape[0], FLAGS_batch_size,
+      platform::errors::InvalidArgument(
+          "Input data has to be packed batch by batch. The batchsize is set to "
+          "%d, but the real input is packed with batchsize = %d",
+          FLAGS_batch_size, inputs[0][0].shape[0]));
+  LOG(INFO) << "FP32 & BF16 prediction run: batch_size " << FLAGS_batch_size;
+
+  LOG(INFO) << "--- FP32 prediction start ---";
+  auto *cfg = reinterpret_cast<const PaddlePredictor::Config *>(config);
+  PrintConfig(cfg, true);
+  std::vector<std::vector<PaddleTensor>> analysis_outputs;
+  float sample_latency_fp32{-1};
+
+  if (FLAGS_enable_fp32) {
+    TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true, VarType::FP32,
+                            &sample_latency_fp32);
+  }
+
+  LOG(INFO) << "--- BF16 prediction start ---";
+  auto *qcfg = reinterpret_cast<const PaddlePredictor::Config *>(qconfig);
+  PrintConfig(qcfg, true);
+  std::vector<std::vector<PaddleTensor>> bf16_outputs;
+  float sample_latency_bf16{-1};
+
+  if (FLAGS_enable_bf16) {
+    TestOneThreadPrediction(qcfg, inputs, &bf16_outputs, true, VarType::FP32,
+                            &sample_latency_bf16);
+  }
+  SummarizePerformance("FP32", sample_latency_fp32, "BF16",
+                       sample_latency_bf16);
+
+  CompareAccuracy(bf16_outputs, analysis_outputs, compared_idx);
+}
+
 void CompareAnalysisAndAnalysis(
     const AnalysisConfig *config1, const AnalysisConfig *config2,
     const std::vector<std::vector<PaddleTensor>> &inputs,
@@ -749,7 +909,8 @@ void CompareAnalysisAndAnalysis(
     TestOneThreadPrediction(cfg2, inputs, &int8_outputs, true, VarType::INT8,
                             &sample_latency_int8);
   }
-  SummarizePerformance(sample_latency_fp32, sample_latency_int8);
+  SummarizePerformance("FP32", sample_latency_fp32, "INT8",
+                       sample_latency_int8);
   if (with_accuracy_layer) {
     CompareAccuracy(int8_outputs, analysis_outputs, compared_idx);
   }
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index c8e5048421ccadd0d2c9a53b434dcb32beef6b28..ac615a2320daa06587b7a64328996330ec8236a3 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -502,6 +502,7 @@ void BindAnalysisConfig(py::module *m) {
            py::return_value_policy::reference)
       .def("set_mkldnn_cache_capacity", &AnalysisConfig::SetMkldnnCacheCapacity,
            py::arg("capacity") = 0)
+      .def("set_bfloat16_op", &AnalysisConfig::SetBfloat16Op)
 #endif
       .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
       .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
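From Python, the new binding is reached as `config.set_bfloat16_op({...})`, mirroring the C++ setter bound above. As a side note, the `SetInputs`/`TensorReader` helpers moved into tester_helper.h read the same binary data file as the INT8 image-classification tests: an int64 image count, then all images as float32 3x224x224 tensors, then all labels as int64. A hedged sketch of a writer for that layout follows; the file name and the all-zero contents are illustrative only.

```cpp
// Illustrative writer for the layout read by SetInputs/TensorReader:
// [int64 count][count * 3*224*224 float32 images][count int64 labels].
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

void WriteDummyImagenetBin(const std::string &path, int64_t count) {
  std::ofstream out(path, std::ios::binary);
  out.write(reinterpret_cast<const char *>(&count), sizeof(count));
  std::vector<float> image(3 * 224 * 224, 0.0f);  // one all-zero image
  for (int64_t i = 0; i < count; ++i) {
    out.write(reinterpret_cast<const char *>(image.data()),
              image.size() * sizeof(float));
  }
  std::vector<int64_t> labels(count, 0);  // all-zero labels
  out.write(reinterpret_cast<const char *>(labels.data()),
            labels.size() * sizeof(int64_t));
}
```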