From 24f55aedaa991e8ab00d301e568b17c0fae7200a Mon Sep 17 00:00:00 2001 From: Allen Guo Date: Wed, 23 Feb 2022 11:12:15 +0800 Subject: [PATCH] [IPU] update inference demos (#39792) * update inference part * restore white space --- paddle/fluid/inference/CMakeLists.txt | 5 +- paddle/fluid/inference/analysis/argument.h | 8 +- .../analysis/passes/ir_graph_build_pass.cc | 18 +- paddle/fluid/inference/api/analysis_config.cc | 37 +++- .../fluid/inference/api/analysis_predictor.cc | 18 +- .../inference/api/paddle_analysis_config.h | 47 +++-- paddle/fluid/inference/api/paddle_tensor.h | 2 +- .../fluid/inference/tests/api/CMakeLists.txt | 25 ++- .../tests/api/analyzer_ernie_tester.h | 3 +- .../tests/api/ipu_ernie_fp16_test.cc | 184 ++++++++++++++++ .../inference/tests/api/ipu_ernie_test.cc | 196 ++++++++++++++++++ .../tests/api/ipu_multi_model_profile.cc | 105 ++++++++++ .../tests/api/ipu_resnet50_fp16_test.cc | 86 ++++++++ .../inference/tests/api/ipu_resnet50_test.cc | 10 +- .../tests/api/ipu_word2vec_sample.cc | 81 ++++++++ .../fluid/inference/tests/api/tester_helper.h | 52 +++++ 16 files changed, 823 insertions(+), 54 deletions(-) create mode 100644 paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc create mode 100644 paddle/fluid/inference/tests/api/ipu_ernie_test.cc create mode 100644 paddle/fluid/inference/tests/api/ipu_multi_model_profile.cc create mode 100644 paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc create mode 100644 paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index d731bfe139b..887bd52bae5 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -48,11 +48,10 @@ set(STATIC_INFERENCE_API paddle_inference_api analysis_predictor #TODO(wilber, T8T9): Do we still need to support windows gpu static library? 
if(WIN32 AND WITH_GPU) cc_library(paddle_inference DEPS ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API} ${utils_modules}) +elseif(WITH_IPU) + cc_library(paddle_inference DEPS ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API} ${utils_modules} paddle_ipu) else() create_static_lib(paddle_inference ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API} ${utils_modules}) - if(WITH_IPU) - target_link_libraries(paddle_inference -Wl,--allow-multiple-definition popart_canonicalization_utils) - endif() endif() if(NOT APPLE) diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index f474ccd260e..a5c32164bf1 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -278,10 +278,14 @@ struct Argument { // ipu related DECL_ARGUMENT_FIELD(use_ipu, UseIpu, bool); DECL_ARGUMENT_FIELD(ipu_device_num, IpuDeviceNum, int); + DECL_ARGUMENT_FIELD(ipu_micro_batch_size, IpuMicroBatchSize, int); DECL_ARGUMENT_FIELD(ipu_enable_pipelining, IpuEnablePipelining, bool); DECL_ARGUMENT_FIELD(ipu_batches_per_step, IpuBatchesPerStep, int); - DECL_ARGUMENT_FIELD(ipu_batch_size, IpuBatchSize, int); - DECL_ARGUMENT_FIELD(ipu_need_avg_shard, IpuNeedAvgShard, bool); + DECL_ARGUMENT_FIELD(ipu_enable_fp16, IpuEnableFp16, bool); + DECL_ARGUMENT_FIELD(ipu_replica_num, IpuReplicaNum, int); + DECL_ARGUMENT_FIELD(ipu_available_memory_proportion, + IpuAvailableMemoryProportion, float); + DECL_ARGUMENT_FIELD(ipu_enable_half_partial, IpuEnableHalfPartial, bool); // npu related DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool); diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index fe6a27f8072..321716b1c8a 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -72,17 +72,21 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { if (argument->use_ipu()) { argument->main_graph().SetNotOwned("num_ipus", &argument->ipu_device_num()); - argument->main_graph().SetNotOwned("need_avg_shard", - &argument->ipu_need_avg_shard()); + argument->main_graph().SetNotOwned("micro_batch_size", + &argument->ipu_micro_batch_size()); argument->main_graph().SetNotOwned("enable_pipelining", &argument->ipu_enable_pipelining()); argument->main_graph().SetNotOwned("batches_per_step", &argument->ipu_batches_per_step()); - argument->main_graph().SetNotOwned("batch_size", - &argument->ipu_batch_size()); - } else { - PADDLE_THROW( - platform::errors::Unimplemented("Please compile with WITH_IPU")); + argument->main_graph().SetNotOwned("enable_fp16", + &argument->ipu_enable_fp16()); + argument->main_graph().SetNotOwned("replica_num", + &argument->ipu_replica_num()); + argument->main_graph().SetNotOwned( + "available_memory_proportion", + &argument->ipu_available_memory_proportion()); + argument->main_graph().SetNotOwned("enable_half_partial", + &argument->ipu_enable_half_partial()); } } #endif diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 57e49733b32..fd2ccffae3b 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -142,17 +142,28 @@ void AnalysisConfig::EnableNpu(int device_id) { Update(); } -void AnalysisConfig::EnableIpu(int device_num, bool ipu_enable_pipelining, - int ipu_batches_per_step, int ipu_batch_size, - bool ipu_need_avg_shard) { + +void 
AnalysisConfig::EnableIpu(int ipu_device_num, int ipu_micro_batch_size, + bool ipu_enable_pipelining, + int ipu_batches_per_step) { enable_ir_optim_ = true; use_ipu_ = true; - ipu_device_num_ = device_num; + ipu_device_num_ = ipu_device_num; + ipu_micro_batch_size_ = ipu_micro_batch_size; ipu_enable_pipelining_ = ipu_enable_pipelining; ipu_batches_per_step_ = ipu_batches_per_step; - ipu_batch_size_ = ipu_batch_size; - ipu_need_avg_shard_ = ipu_need_avg_shard; + + Update(); +} + +void AnalysisConfig::SetIpuConfig(bool ipu_enable_fp16, int ipu_replica_num, + float ipu_available_memory_proportion, + bool ipu_enable_half_partial) { + ipu_enable_fp16_ = ipu_enable_fp16; + ipu_replica_num_ = ipu_replica_num; + ipu_available_memory_proportion_ = ipu_available_memory_proportion; + ipu_enable_half_partial_ = ipu_enable_half_partial; Update(); } @@ -255,10 +266,13 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { // ipu related CP_MEMBER(use_ipu_); CP_MEMBER(ipu_device_num_); + CP_MEMBER(ipu_micro_batch_size_); CP_MEMBER(ipu_enable_pipelining_); CP_MEMBER(ipu_batches_per_step_); - CP_MEMBER(ipu_batch_size_); - CP_MEMBER(ipu_need_avg_shard_); + CP_MEMBER(ipu_enable_fp16_); + CP_MEMBER(ipu_replica_num_); + CP_MEMBER(ipu_available_memory_proportion_); + CP_MEMBER(ipu_enable_half_partial_); if (use_gpu_) { PADDLE_ENFORCE_EQ(use_xpu_, false, @@ -684,10 +698,13 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << use_ipu_; ss << ipu_device_num_; + ss << ipu_micro_batch_size_; ss << ipu_enable_pipelining_; ss << ipu_batches_per_step_; - ss << ipu_batch_size_; - ss << ipu_need_avg_shard_; + ss << ipu_enable_fp16_; + ss << ipu_replica_num_; + ss << ipu_available_memory_proportion_; + ss << ipu_enable_half_partial_; return ss.str(); } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 6c005e4b2d6..cd6e3a3c759 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -93,6 +93,8 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t, input_ptr = t->mutable_data(ddim, place); } else if (pt.dtype == PaddleDType::INT32) { input_ptr = t->mutable_data(ddim, place); + } else if (pt.dtype == PaddleDType::FLOAT16) { + input_ptr = t->mutable_data(ddim, place); } else { LOG(ERROR) << "unsupported feed type " << pt.dtype; return false; @@ -563,8 +565,12 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, } else if (type == framework::proto::VarType::INT32) { GetFetchOne(fetch, output); output->dtype = PaddleDType::INT32; + } else if (type == framework::proto::VarType::FP16) { + GetFetchOne(fetch, output); + output->dtype = PaddleDType::FLOAT16; } else { - LOG(ERROR) << "unknown type, only support float32, int64 and int32 now."; + LOG(ERROR) << "unknown type, only support float32, float16, int64 and " + "int32 now."; } } return true; @@ -662,12 +668,18 @@ void AnalysisPredictor::PrepareArgument() { LOG(INFO) << "Lite subgraph engine is enabled"; } +#ifdef PADDLE_WITH_IPU argument_.SetUseIpu(config_.use_ipu_); argument_.SetIpuDeviceNum(config_.ipu_device_num()); + argument_.SetIpuMicroBatchSize(config_.ipu_micro_batch_size_); argument_.SetIpuEnablePipelining(config_.ipu_enable_pipelining_); argument_.SetIpuBatchesPerStep(config_.ipu_batches_per_step_); - argument_.SetIpuBatchSize(config_.ipu_batch_size_); - argument_.SetIpuNeedAvgShard(config_.ipu_need_avg_shard_); + argument_.SetIpuEnableFp16(config_.ipu_enable_fp16_); + 
argument_.SetIpuReplicaNum(config_.ipu_replica_num_); + argument_.SetIpuAvailableMemoryProportion( + config_.ipu_available_memory_proportion_); + argument_.SetIpuEnableHalfPartial(config_.ipu_enable_half_partial_); +#endif argument_.SetUseNpu(config_.use_npu_); argument_.SetNPUDeviceId(config_.npu_device_id()); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 4b13ca073bc..180c028c6a6 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -234,20 +234,30 @@ struct PD_INFER_DECL AnalysisConfig { /// /// \brief Turn on IPU. /// - /// \param device_num The number of IPUs. - /// \param ipu_enable_pipelining Enable data pipelining between subgraphs, - /// each subgraph is settled on an IPU. (This feature requires the number of - /// IPUs > 1.) - /// \param ipu_batches_per_step The number of micro_batch_size per run. (This - /// feature requires to enable pipelining.) - /// \param ipu_batch_size The micro_batch_size which is the batch_size in the - /// graph. - /// \param ipu_need_avg_shard Enable the auto graph sharding. (This feature - /// requires the number of IPUs > 1.) - /// - void EnableIpu(int device_num = 1, bool ipu_enable_pipelining = false, - int ipu_batches_per_step = 1, int ipu_batch_size = 1, - bool ipu_need_avg_shard = false); + /// \param ipu_device_num the number of IPUs. + /// \param ipu_micro_batch_size the batch size in the graph, only work with + /// mutable input shapes. + /// \param ipu_enable_pipelining enable pipelining. + /// \param ipu_batches_per_step the number of batches per run in pipelining. + /// + void EnableIpu(int ipu_device_num = 1, int ipu_micro_batch_size = 1, + bool ipu_enable_pipelining = false, + int ipu_batches_per_step = 1); + + /// + /// \brief Set IPU config. + /// + /// \param ipu_enable_fp16 enable fp16. + /// \param ipu_replica_num the number of graph replication. + /// \param ipu_available_memory_proportion the available memory proportion for + /// matmul/conv. + /// \param ipu_enable_half_partial enable fp16 partial for matmul, only work + /// with fp16. + /// + void SetIpuConfig(bool ipu_enable_fp16 = false, int ipu_replica_num = 1, + float ipu_available_memory_proportion = 1.0, + bool ipu_enable_half_partial = false); + /// /// \brief Set XPU device id. /// @@ -876,11 +886,14 @@ struct PD_INFER_DECL AnalysisConfig { // ipu related. bool use_ipu_{false}; int ipu_device_num_{1}; - + int ipu_micro_batch_size_{1}; bool ipu_enable_pipelining_{false}; int ipu_batches_per_step_{1}; - int ipu_batch_size_{1}; - bool ipu_need_avg_shard_{false}; + + bool ipu_enable_fp16_{false}; + int ipu_replica_num_{1}; + float ipu_available_memory_proportion_{1.0}; + bool ipu_enable_half_partial_{false}; // If the config is already used on a predictor, it becomes invalid. // Any config can only be used with one predictor. diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h index 24a72a0b9da..81eecbb2c14 100644 --- a/paddle/fluid/inference/api/paddle_tensor.h +++ b/paddle/fluid/inference/api/paddle_tensor.h @@ -45,7 +45,7 @@ enum DataType { // TODO(Superjomn) support more data types if needed. }; -enum class PlaceType { kUNK = -1, kCPU, kGPU, kXPU, kNPU }; +enum class PlaceType { kUNK = -1, kCPU, kGPU, kXPU, kNPU, kIPU }; /// \brief Represents an n-dimensional array of values. /// The Tensor is used to store the input or output of the network. 
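For reference, a minimal usage sketch of how the reworked EnableIpu / SetIpuConfig pair documented above is meant to be used together. It mirrors the ipu_word2vec_sample.cc demo added later in this patch; the model path and the concrete option values are placeholders for illustration, not anything defined by the patch:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("/path/to/inference_model");  // placeholder model directory
  // ipu_device_num = 1, ipu_micro_batch_size = 4,
  // ipu_enable_pipelining = false, ipu_batches_per_step = 1
  config.EnableIpu(1, 4, false, 1);
  // ipu_enable_fp16 = false, ipu_replica_num = 1,
  // ipu_available_memory_proportion = 1.0, ipu_enable_half_partial = false
  config.SetIpuConfig(false, 1, 1.0, false);
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}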
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 9dafd0d17c7..85fe931cf93 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -758,11 +758,30 @@ if(ON_INFER OR WITH_GPU) set_tests_properties(test_analyzer_transformer_profile PROPERTIES TIMEOUT 120) endif() -# IPU if (WITH_IPU) - #resnet50 + #word2vec sample + set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec/word2vec.inference.model") + inference_analysis_test(ipu_word2vec_sample SRCS ipu_word2vec_sample.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${WORD2VEC_INSTALL_DIR}) + + # ERNIE + set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie") + inference_analysis_api_test(ipu_ernie_test ${ERNIE_INSTALL_DIR} ipu_ernie_test.cc + ARGS --warmup=true --repeat=10) + inference_analysis_api_test(ipu_ernie_fp16_test ${ERNIE_INSTALL_DIR} ipu_ernie_fp16_test.cc + ARGS --warmup=true --repeat=10) + + # Resnet50 set(RESNET50_MODEL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/resnet50") inference_analysis_test(ipu_resnet50_test SRCS ipu_resnet50_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${RESNET50_MODEL_DIR} --warmup=true --repeat=1000) + ARGS --infer_model=${RESNET50_MODEL_DIR} --warmup=true --repeat=10) + inference_analysis_test(ipu_resnet50_fp16_test SRCS ipu_resnet50_fp16_test.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${RESNET50_MODEL_DIR} --warmup=true --repeat=10) + + # Only support Resnet50 and Ernie currently + inference_analysis_api_test(ipu_multi_model_profile SRCS ipu_multi_model_profile.cc + ARGS --model_name="Resnet50" --infer_model=${RESNET50_MODEL_DIR} --warmup=true --repeat=10) endif() diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h index 2582a1cb09e..fffcd38d95a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h +++ b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h @@ -150,8 +150,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false, void SetIpuConfig(AnalysisConfig *cfg, int batch_size = 1) { cfg->SetModel(FLAGS_infer_model); - // num_ipu, enable_pipelining, batches_per_step, batch_size, need_avg_shard - cfg->EnableIpu(4, false, 1, batch_size, true); + cfg->EnableIpu(4, batch_size, false, 1); } } // namespace inference diff --git a/paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc b/paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc new file mode 100644 index 00000000000..fa775bd9a9c --- /dev/null +++ b/paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc @@ -0,0 +1,184 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; + +template +void GetValueFromStream(std::stringstream *ss, T *t) { + (*ss) >> (*t); +} + +template <> +void GetValueFromStream(std::stringstream *ss, std::string *t) { + *t = ss->str(); +} + +// Split string to vector +template +void Split(const std::string &line, char sep, std::vector *v) { + std::stringstream ss; + T t; + for (auto c : line) { + if (c != sep) { + ss << c; + } else { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } + } + + if (!ss.str().empty()) { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } +} + +// Parse tensor from string +template +bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { + std::vector data; + Split(field, ':', &data); + if (data.size() < 2) return false; + + std::string shape_str = data[0]; + + std::vector shape; + Split(shape_str, ' ', &shape); + + std::string mat_str = data[1]; + + std::vector mat; + Split(mat_str, ' ', &mat); + + tensor->shape = shape; + auto size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * + sizeof(T); + tensor->data.Resize(size); + std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); + tensor->dtype = GetPaddleDType(); + + return true; +} + +// Parse input tensors from string +bool ParseLine(const std::string &line, + std::vector *tensors) { + std::vector fields; + Split(line, ';', &fields); + + tensors->clear(); + tensors->reserve(4); + + int i = 0; + auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_"; + for (; i < 3; i++) { + paddle::PaddleTensor temp; + ParseTensor(fields[i], &temp); + temp.name = input_name + std::to_string(i); + tensors->push_back(temp); + } + + // input_mask + paddle::PaddleTensor input_mask; + ParseTensor(fields[i], &input_mask); + // fp32 to fp16 + ConvertFP32toFP16(input_mask); + input_mask.name = input_name + std::to_string(i); + tensors->push_back(input_mask); + + return true; +} + +bool LoadInputData(std::vector> *inputs, + int batch_size = 1) { + if (FLAGS_infer_data.empty()) { + LOG(ERROR) << "please set input data path"; + return false; + } + + std::ifstream fin(FLAGS_infer_data); + std::string line; + int sample = 0; + + // The unit-test dataset only have 10 samples, each sample have 5 feeds. 
+ while (std::getline(fin, line)) { + std::vector feed_data; + ParseLine(line, &feed_data); + inputs->push_back(std::move(feed_data)); + sample++; + if (!FLAGS_test_all_data && sample == batch_size) break; + } + LOG(INFO) << "number of samples: " << sample; + return true; +} + +void SetConfig(AnalysisConfig *cfg, int batch_size = 1) { + cfg->SetModel(FLAGS_infer_model); + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining + cfg->EnableIpu(1, batch_size, false); + // ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion, + // ipu_enable_half_partial + cfg->SetIpuConfig(true, 1, 1.0, true); +} + +// Compare results +TEST(Analyzer_Ernie_ipu, compare_results) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + auto predictor = CreateTestPredictor( + reinterpret_cast(&cfg), + FLAGS_use_analysis); + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + + auto output = outputs.front(); + ConvertFP16toFP32(output); + auto outputs_size = 1; + for (auto dim : output.shape) { + outputs_size *= dim; + } + float *fp32_data = reinterpret_cast(output.data.data()); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], fp32_data[j], 5e-3); + } + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/ipu_ernie_test.cc b/paddle/fluid/inference/tests/api/ipu_ernie_test.cc new file mode 100644 index 00000000000..e36917c9acd --- /dev/null +++ b/paddle/fluid/inference/tests/api/ipu_ernie_test.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; + +template +void GetValueFromStream(std::stringstream *ss, T *t) { + (*ss) >> (*t); +} + +template <> +void GetValueFromStream(std::stringstream *ss, std::string *t) { + *t = ss->str(); +} + +// Split string to vector +template +void Split(const std::string &line, char sep, std::vector *v) { + std::stringstream ss; + T t; + for (auto c : line) { + if (c != sep) { + ss << c; + } else { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } + } + + if (!ss.str().empty()) { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } +} + +// Parse tensor from string +template +bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { + std::vector data; + Split(field, ':', &data); + if (data.size() < 2) return false; + + std::string shape_str = data[0]; + + std::vector shape; + Split(shape_str, ' ', &shape); + + std::string mat_str = data[1]; + + std::vector mat; + Split(mat_str, ' ', &mat); + + tensor->shape = shape; + auto size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * + sizeof(T); + tensor->data.Resize(size); + std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); + tensor->dtype = GetPaddleDType(); + + return true; +} + +// Parse input tensors from string +bool ParseLine(const std::string &line, + std::vector *tensors) { + std::vector fields; + Split(line, ';', &fields); + + tensors->clear(); + tensors->reserve(4); + + int i = 0; + auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_"; + for (; i < 3; i++) { + paddle::PaddleTensor temp; + ParseTensor(fields[i], &temp); + temp.name = input_name + std::to_string(i); + tensors->push_back(temp); + } + + // input_mask + paddle::PaddleTensor input_mask; + ParseTensor(fields[i], &input_mask); + input_mask.name = input_name + std::to_string(i); + tensors->push_back(input_mask); + + return true; +} + +bool LoadInputData(std::vector> *inputs, + int batch_size = 1) { + if (FLAGS_infer_data.empty()) { + LOG(ERROR) << "please set input data path"; + return false; + } + + std::ifstream fin(FLAGS_infer_data); + std::string line; + int sample = 0; + + // The unit-test dataset only have 10 samples, each sample have 5 feeds. 
+ while (std::getline(fin, line)) { + std::vector feed_data; + ParseLine(line, &feed_data); + inputs->push_back(std::move(feed_data)); + sample++; + if (!FLAGS_test_all_data && sample == batch_size) break; + } + LOG(INFO) << "number of samples: " << sample; + return true; +} + +void SetConfig(AnalysisConfig *cfg, int batch_size = 1) { + cfg->SetModel(FLAGS_infer_model); + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining + cfg->EnableIpu(1, batch_size, false); +} + +void profile() { + AnalysisConfig config; + SetConfig(&config); + + std::vector> outputs; + std::vector> inputs; + LoadInputData(&inputs); + TestPrediction(reinterpret_cast(&config), + inputs, &outputs, FLAGS_num_threads); +} + +// Compare Deterministic result +TEST(Analyzer_Ernie_ipu, compare_determine) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + CompareDeterministic(reinterpret_cast(&cfg), + input_slots_all); +} + +// Compare results +TEST(Analyzer_Ernie_ipu, compare_results) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + auto predictor = CreateTestPredictor( + reinterpret_cast(&cfg), + FLAGS_use_analysis); + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + auto outputs_size = outputs.front().data.length() / (sizeof(float)); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], + static_cast(outputs[0].data.data())[j], + FLAGS_accuracy); + } + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/ipu_multi_model_profile.cc b/paddle/fluid/inference/tests/api/ipu_multi_model_profile.cc new file mode 100644 index 00000000000..a225feae4a2 --- /dev/null +++ b/paddle/fluid/inference/tests/api/ipu_multi_model_profile.cc @@ -0,0 +1,105 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +#include "gflags/gflags.h" +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +void ErnieInputData(const int &total_batch_size, const bool enable_fp16, + std::vector *inputs) { + const int input_num = total_batch_size * 128 * 1; + std::vector placeholder_012(input_num, 1); + std::vector placeholder_3(input_num, 1); + + for (int i = 0; i < 4; i++) { + PaddleTensor in; + in.name = "placeholder_" + std::to_string(i); + in.shape = {total_batch_size, 128, 1}; + if (i < 3) { + in.data = PaddleBuf(static_cast(placeholder_012.data()), + input_num * sizeof(int64_t)); + in.dtype = PaddleDType::INT64; + } else { + in.data = PaddleBuf(static_cast(placeholder_3.data()), + input_num * sizeof(float)); + in.dtype = PaddleDType::FLOAT32; + if (enable_fp16) { + ConvertFP32toFP16(in); + } + } + inputs->push_back(std::move(in)); + } +} + +void Resnet50InputData(const int &total_batch_size, const bool enable_fp16, + std::vector *inputs) { + const int input_num = total_batch_size * 3 * 318 * 318; + std::vector input(input_num, 1); + PaddleTensor in; + in.shape = {total_batch_size, 3, 318, 318}; + in.data = + PaddleBuf(static_cast(input.data()), input_num * sizeof(float)); + in.dtype = PaddleDType::FLOAT32; + if (enable_fp16) { + ConvertFP32toFP16(in); + } + inputs->push_back(std::move(in)); +} + +// performance profile +TEST(Analyzer_ipu_fp16, performance_profile) { + AnalysisConfig config; + std::vector inputs; + std::vector> outputs; + + int total_batch_size = FLAGS_ipu_micro_batch_size * FLAGS_ipu_replica_num; + if (FLAGS_ipu_enable_pipelining) { + // if device_num > 1 and pipelining is enabled, the total batch size = + // micro_batch_size * device_num(batches_per_step) * replica_num + total_batch_size = FLAGS_ipu_micro_batch_size * FLAGS_ipu_batches_per_step * + FLAGS_ipu_replica_num; + } + + if (FLAGS_model_name == "Resnet50") { + config.SetModel(FLAGS_infer_model + "/model/model", + FLAGS_infer_model + "/model/params"); + Resnet50InputData(total_batch_size, FLAGS_ipu_enable_fp16, &inputs); + } else if (FLAGS_model_name == "Ernie") { + config.SetModel(FLAGS_infer_model + "/model/"); + ErnieInputData(total_batch_size, FLAGS_ipu_enable_fp16, &inputs); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only support Resnet50 and Ernie Currently")); + } + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining, + // ipu_batches_per_step + config.EnableIpu(FLAGS_ipu_device_num, FLAGS_ipu_micro_batch_size, + FLAGS_ipu_enable_pipelining, FLAGS_ipu_batches_per_step); + // ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion, + // ipu_enable_half_partial + config.SetIpuConfig(FLAGS_ipu_enable_fp16, FLAGS_ipu_replica_num, + FLAGS_ipu_available_memory_proportion, + FLAGS_ipu_enable_half_partial); + + TestPrediction(reinterpret_cast(&config), + {inputs}, &outputs, 1); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc b/paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc new file mode 100644 index 00000000000..1d69069da07 --- /dev/null +++ b/paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "gflags/gflags.h" +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +// Compare results with 1 batch +TEST(Analyzer_Resnet50_ipu, compare_results_1_batch) { + std::string model_dir = FLAGS_infer_model + "/" + "model"; + AnalysisConfig config; + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining + config.EnableIpu(1, 1, false); + // ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion, + // ipu_enable_half_partial + config.SetIpuConfig(true, 1, 1.0, true); + config.SetModel(model_dir + "/model", model_dir + "/params"); + + std::vector inputs; + auto predictor = CreatePaddlePredictor(config); + const int batch = 1; + const int channel = 3; + const int height = 318; + const int width = 318; + const int input_num = batch * channel * height * width; + std::vector input(input_num, 1); + + PaddleTensor in; + in.shape = {batch, channel, height, width}; + in.data = + PaddleBuf(static_cast(input.data()), input_num * sizeof(float)); + in.dtype = PaddleDType::FLOAT32; + ConvertFP32toFP16(in); + inputs.emplace_back(in); + + std::vector outputs; + + ASSERT_TRUE(predictor->Run(inputs, &outputs)); + + const std::vector truth_values = { + 127.779f, 738.165f, 1013.22f, -438.17f, 366.401f, 927.659f, + 736.222f, -633.684f, -329.927f, -430.155f, -633.062f, -146.548f, + -1324.28f, -1349.36f, -242.675f, 117.448f, -801.723f, -391.514f, + -404.818f, 454.16f, 515.48f, -133.031f, 69.293f, 590.096f, + -1434.69f, -1070.89f, 307.074f, 400.525f, -316.12f, -587.125f, + -161.056f, 800.363f, -96.4708f, 748.706f, 868.174f, -447.938f, + 112.737f, 1127.2f, 47.4355f, 677.72f, 593.186f, -336.4f, + 551.362f, 397.823f, 78.3979f, -715.398f, 405.969f, 404.256f, + 246.019f, -8.42969f, 131.365f, -648.051f}; + + const size_t expected_size = 1; + EXPECT_EQ(outputs.size(), expected_size); + + auto output = outputs.front(); + ConvertFP16toFP32(output); + auto outputs_size = 1; + for (auto dim : output.shape) { + outputs_size *= dim; + } + float* fp32_data = reinterpret_cast(output.data.data()); + + for (size_t j = 0; j < outputs_size; j += 10) { + EXPECT_NEAR((fp32_data[j] - truth_values[j / 10]) / truth_values[j / 10], + 0., 9e-2); + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc b/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc index f5e755ab466..5fde8e6a5e1 100644 --- a/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc +++ b/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc @@ -33,9 +33,8 @@ static std::vector truth_values = { TEST(Analyzer_Resnet50_ipu, compare_results_1_batch) { std::string model_dir = FLAGS_infer_model + "/" + "model"; AnalysisConfig config; - // num_ipu, enable_pipelining, batches_per_step, batch_size, - // need_avg_shard - config.EnableIpu(1, false); + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining + config.EnableIpu(1, 1, false); config.SetModel(model_dir + "/model", model_dir + "/params"); std::vector inputs; @@ -72,9 +71,8 @@ TEST(Analyzer_Resnet50_ipu, compare_results_1_batch) { 
TEST(Analyzer_Resnet50_ipu, compare_results_2_batch) { std::string model_dir = FLAGS_infer_model + "/" + "model"; AnalysisConfig config; - // num_ipu, enable_pipelining, batches_per_step, batch_size, - // need_avg_shard - config.EnableIpu(2, false, 1, 2, 1); + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining + config.EnableIpu(1, 2, false); config.SetModel(model_dir + "/model", model_dir + "/params"); std::vector inputs; diff --git a/paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc b/paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc new file mode 100644 index 00000000000..d38c5c34163 --- /dev/null +++ b/paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file contains a simple demo for how to take a model for inference with + * IPUs. + * Model: wget -q + * http://paddle-inference-dist.bj.bcebos.com/word2vec.inference.model.tar.gz + */ + +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +DEFINE_string(infer_model, "", "Directory of the inference model."); + +using paddle_infer::Config; +using paddle_infer::Predictor; +using paddle_infer::CreatePredictor; + +void inference(std::string model_path, bool use_ipu, + std::vector *out_data) { + //# 1. Create Predictor with a config. + Config config; + config.SetModel(FLAGS_infer_model); + if (use_ipu) { + // ipu_device_num, ipu_micro_batch_size + config.EnableIpu(1, 4); + } + auto predictor = CreatePredictor(config); + + //# 2. Prepare input/output tensor. + auto input_names = predictor->GetInputNames(); + std::vector data{1, 2, 3, 4}; + // For simplicity, we set all the slots with the same data. + for (auto input_name : input_names) { + auto input_tensor = predictor->GetInputHandle(input_name); + input_tensor->Reshape({4, 1}); + input_tensor->CopyFromCpu(data.data()); + } + + //# 3. Run + predictor->Run(); + + //# 4. Get output. 
+ auto output_names = predictor->GetOutputNames(); + auto output_tensor = predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + out_data->resize(out_num); + output_tensor->CopyToCpu(out_data->data()); +} + +int main(int argc, char *argv[]) { + ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); + std::vector ipu_result; + std::vector cpu_result; + inference(FLAGS_infer_model, true, &ipu_result); + inference(FLAGS_infer_model, false, &cpu_result); + for (size_t i = 0; i < ipu_result.size(); i++) { + CHECK_NEAR(ipu_result[i], cpu_result[i], 1e-6); + } + LOG(INFO) << "Finished"; +} diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 77fab0a86f8..637fa16e31b 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -76,10 +76,23 @@ DEFINE_int32(cpu_num_threads, 1, "Number of threads for each paddle instance."); DEFINE_bool(fuse_multi_gru, false, "Running the inference program with multi_gru_fuse_pass"); +// ipu related +DEFINE_int32(ipu_micro_batch_size, 1, "micro batch size"); +DEFINE_int32(ipu_device_num, 1, "device num"); +DEFINE_bool(ipu_enable_pipelining, false, "enable pipelining"); +DEFINE_int32(ipu_batches_per_step, 1, + "the number of batches per run in pipelining"); +DEFINE_bool(ipu_enable_fp16, false, "enable fp16"); +DEFINE_int32(ipu_replica_num, 1, "replica num"); +DEFINE_double(ipu_available_memory_proportion, 1.0, + "available memory proportion"); +DEFINE_bool(ipu_enable_half_partial, false, "enable half partial"); + namespace paddle { namespace inference { using paddle::framework::proto::VarType; +using float16 = paddle::platform::float16; template constexpr paddle::PaddleDType GetPaddleDType(); @@ -1060,5 +1073,44 @@ static bool CompareTensor(const framework::LoDTensor &a, return true; } +void ConvertFP32toFP16(paddle::PaddleTensor &tensor // NOLINT + ) { + int num = 1; + for (auto dim : tensor.shape) { + num *= dim; + } + PADDLE_ENFORCE_EQ( + tensor.dtype, PaddleDType::FLOAT32, + platform::errors::InvalidArgument( + "The tensor dtype is not float32, only support float32 as input")); + float *fp32_data = reinterpret_cast(tensor.data.data()); + float16 *fp16_data = new float16[num]; + for (int i = 0; i < num; i++) { + fp16_data[i] = float16(fp32_data[i]); + } + tensor.data = + PaddleBuf(static_cast(fp16_data), num * sizeof(float16)); + tensor.dtype = PaddleDType::FLOAT16; +} + +void ConvertFP16toFP32(paddle::PaddleTensor &tensor // NOLINT + ) { + int num = 1; + for (auto dim : tensor.shape) { + num *= dim; + } + PADDLE_ENFORCE_EQ( + tensor.dtype, PaddleDType::FLOAT16, + platform::errors::InvalidArgument( + "The tensor dtype is not float16, only support float16 as input")); + float16 *fp16_data = reinterpret_cast(tensor.data.data()); + float *fp32_data = new float[num]; + for (int i = 0; i < num; i++) { + fp32_data[i] = static_cast(fp16_data[i]); + } + tensor.data = PaddleBuf(static_cast(fp32_data), num * sizeof(float)); + tensor.dtype = PaddleDType::FLOAT32; +} + } // namespace inference } // namespace paddle -- GitLab
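As a closing reference, a minimal sketch of how the ConvertFP32toFP16 / ConvertFP16toFP32 helpers added to tester_helper.h wrap a run in the FP16 tests above: the input is built as FP32, converted in place before feeding, and the FP16 output is converted back before its values are read. The wrapper name RunFp16Sample, the model paths, and the input shape are placeholders for illustration, not part of the patch:

#include "paddle/fluid/inference/tests/api/tester_helper.h"

namespace paddle {
namespace inference {

void RunFp16Sample() {  // hypothetical wrapper, for illustration only
  AnalysisConfig config;
  // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining
  config.EnableIpu(1, 1, false);
  // ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion,
  // ipu_enable_half_partial
  config.SetIpuConfig(true, 1, 1.0, true);
  config.SetModel("/path/to/model", "/path/to/params");  // placeholder paths
  auto predictor = CreatePaddlePredictor(config);

  // Feed: build an FP32 tensor, then convert it in place to FP16.
  const int num = 1 * 3 * 318 * 318;  // placeholder input size
  std::vector<float> input(num, 1.0f);
  PaddleTensor in;
  in.shape = {1, 3, 318, 318};
  in.data = PaddleBuf(static_cast<void *>(input.data()), num * sizeof(float));
  in.dtype = PaddleDType::FLOAT32;
  ConvertFP32toFP16(in);

  std::vector<PaddleTensor> outputs;
  predictor->Run({in}, &outputs);

  // Fetch: convert the FP16 output back to FP32 before reading the values.
  auto output = outputs.front();
  ConvertFP16toFP32(output);
  float *fp32_data = reinterpret_cast<float *>(output.data.data());
  LOG(INFO) << "first output value: " << fp32_data[0];
}

}  // namespace inference
}  // namespace paddle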