From a5c56d83a1b16482dcaae1db6e0543b1cf355f3f Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Mon, 22 Feb 2021 18:57:28 +0800 Subject: [PATCH] update trt int8 calibrator to IEntropyCalibratorV2 (#31060) * update trt int8 calibrator to IEntropyCalibratorV2 * add delete opt_cache for trt_split_converter_test --- .../inference/tensorrt/trt_int8_calibrator.h | 2 +- ...c_shape_ernie_serialize_deserialize_test.h | 19 +------------------ .../tests/api/trt_split_converter_test.cc | 3 +++ .../inference/tests/api/trt_test_helper.h | 16 ++++++++++++++++ 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h index b4b7ee50dc3..15ae67fa10f 100644 --- a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h +++ b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h @@ -34,7 +34,7 @@ namespace tensorrt { class TensorRTEngine; -struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { public: TRTInt8Calibrator(const std::unordered_map& buffers, int batch_size, std::string engine_name, diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h index 40955275f56..86a5223cafe 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include #include #include @@ -27,22 +26,6 @@ limitations under the License. 
*/ namespace paddle { namespace inference { -static int DeleteCache(std::string path) { - DIR* dir = opendir(path.c_str()); - if (dir == NULL) return 0; - struct dirent* ptr; - while ((ptr = readdir(dir)) != NULL) { - if (std::strcmp(ptr->d_name, ".") == 0 || - std::strcmp(ptr->d_name, "..") == 0) { - continue; - } else if (ptr->d_type == 8) { - std::string file_rm = path + "/" + ptr->d_name; - return remove(file_rm.c_str()); - } - } - return 0; -} - static void run(const AnalysisConfig& config, std::vector* out_data) { auto predictor = CreatePaddlePredictor(config); auto input_names = predictor->GetInputNames(); @@ -111,7 +94,7 @@ static void trt_ernie(bool with_fp16, std::vector result) { // Delete serialization cache to perform serialization first rather than // deserialization. std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache"; - DeleteCache(opt_cache_dir); + delete_cache_files(opt_cache_dir); SetConfig(&config, model_dir, true /* use_gpu */); diff --git a/paddle/fluid/inference/tests/api/trt_split_converter_test.cc b/paddle/fluid/inference/tests/api/trt_split_converter_test.cc index 9ae0527bd97..c00b36b520b 100644 --- a/paddle/fluid/inference/tests/api/trt_split_converter_test.cc +++ b/paddle/fluid/inference/tests/api/trt_split_converter_test.cc @@ -23,6 +23,9 @@ namespace inference { TEST(TensorRT, split_converter) { std::string model_dir = FLAGS_infer_model + "/split_converter"; + std::string opt_cache_dir = model_dir + "/_opt_cache"; + delete_cache_files(opt_cache_dir); + AnalysisConfig config; int batch_size = 4; config.EnableUseGpu(100, 0); diff --git a/paddle/fluid/inference/tests/api/trt_test_helper.h b/paddle/fluid/inference/tests/api/trt_test_helper.h index ee3ba63bb2c..1abde733581 100644 --- a/paddle/fluid/inference/tests/api/trt_test_helper.h +++ b/paddle/fluid/inference/tests/api/trt_test_helper.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include @@ -134,5 +135,20 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) { } } +void delete_cache_files(std::string path) { + DIR* dir = opendir(path.c_str()); + if (dir == NULL) return; + struct dirent* ptr; + while ((ptr = readdir(dir)) != NULL) { + if (std::strcmp(ptr->d_name, ".") == 0 || + std::strcmp(ptr->d_name, "..") == 0) { + continue; + } else if (ptr->d_type == 8) { + std::string file_rm = path + "/" + ptr->d_name; + remove(file_rm.c_str()); + } + } +} + } // namespace inference } // namespace paddle -- GitLab