未验证 提交 1d2bd35e 编写于 作者: S Shang Zhizhou 提交者: GitHub

update merge pr #31060(update trt int8 calibrator to IEntropyCalibratorV2) (#31121)

上级 a0fa0d9e
...@@ -34,7 +34,7 @@ namespace tensorrt { ...@@ -34,7 +34,7 @@ namespace tensorrt {
class TensorRTEngine; class TensorRTEngine;
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 {
public: public:
TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers, TRTInt8Calibrator(const std::unordered_map<std::string, size_t>& buffers,
int batch_size, std::string engine_name, int batch_size, std::string engine_name,
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <dirent.h>
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -27,22 +26,6 @@ limitations under the License. */ ...@@ -27,22 +26,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
static int DeleteCache(std::string path) {
DIR* dir = opendir(path.c_str());
if (dir == NULL) return 0;
struct dirent* ptr;
while ((ptr = readdir(dir)) != NULL) {
if (std::strcmp(ptr->d_name, ".") == 0 ||
std::strcmp(ptr->d_name, "..") == 0) {
continue;
} else if (ptr->d_type == 8) {
std::string file_rm = path + "/" + ptr->d_name;
return remove(file_rm.c_str());
}
}
return 0;
}
static void run(const AnalysisConfig& config, std::vector<float>* out_data) { static void run(const AnalysisConfig& config, std::vector<float>* out_data) {
auto predictor = CreatePaddlePredictor(config); auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames(); auto input_names = predictor->GetInputNames();
...@@ -111,7 +94,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) { ...@@ -111,7 +94,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
// Delete serialization cache to perform serialization first rather than // Delete serialization cache to perform serialization first rather than
// deserialization. // deserialization.
std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache"; std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache";
DeleteCache(opt_cache_dir); delete_cache_files(opt_cache_dir);
SetConfig(&config, model_dir, true /* use_gpu */); SetConfig(&config, model_dir, true /* use_gpu */);
......
...@@ -23,6 +23,9 @@ namespace inference { ...@@ -23,6 +23,9 @@ namespace inference {
TEST(TensorRT, split_converter) { TEST(TensorRT, split_converter) {
std::string model_dir = FLAGS_infer_model + "/split_converter"; std::string model_dir = FLAGS_infer_model + "/split_converter";
std::string opt_cache_dir = model_dir + "/_opt_cache";
delete_cache_files(opt_cache_dir);
AnalysisConfig config; AnalysisConfig config;
int batch_size = 4; int batch_size = 4;
config.EnableUseGpu(100, 0); config.EnableUseGpu(100, 0);
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <dirent.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -134,5 +135,20 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) { ...@@ -134,5 +135,20 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) {
} }
} }
void delete_cache_files(std::string path) {
DIR* dir = opendir(path.c_str());
if (dir == NULL) return;
struct dirent* ptr;
while ((ptr = readdir(dir)) != NULL) {
if (std::strcmp(ptr->d_name, ".") == 0 ||
std::strcmp(ptr->d_name, "..") == 0) {
continue;
} else if (ptr->d_type == 8) {
std::string file_rm = path + "/" + ptr->d_name;
remove(file_rm.c_str());
}
}
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册