Unverified commit 3ae3b864, authored by Pei Yang and committed by GitHub

fix trt_dynamic_shape_ernie_deserialize_test (#27290)

* fix trt_dynamic_shape_ernie_deserialize_test

* support the case where the opt cache dir does not exist
Parent 1483ea23
paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -480,10 +480,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
     inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
   endif()
-  # disable test_trt_dynamic_shape_ernie_ser_deser temporary
-  #inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
-  #        EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-  #        ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
+  inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
+          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+          ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
 endif()
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
@@ -12,15 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <dirent.h>
 #include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
+#include <unistd.h>

 #include "paddle/fluid/inference/tests/api/trt_test_helper.h"

 namespace paddle {
 namespace inference {

+int DeleteCache(std::string path) {
+  DIR* dir = opendir(path.c_str());
+  if (dir == NULL) return 0;  // nothing to delete if the cache dir is absent
+  struct dirent* ptr;
+  while ((ptr = readdir(dir)) != NULL) {
+    if (std::strcmp(ptr->d_name, ".") == 0 ||
+        std::strcmp(ptr->d_name, "..") == 0) {
+      continue;
+    } else if (ptr->d_type == 8) {  // 8 == DT_REG: a regular file
+      std::string file_rm = path + "/" + ptr->d_name;
+      return remove(file_rm.c_str());
+    }
+  }
+  return 0;
+}
+
 void run(const AnalysisConfig& config, std::vector<float>* out_data) {
   auto predictor = CreatePaddlePredictor(config);
   auto input_names = predictor->GetInputNames();
@@ -86,6 +104,11 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
 void trt_ernie(bool with_fp16, std::vector<float> result) {
   AnalysisConfig config;
   std::string model_dir = FLAGS_infer_model;
+  // Delete serialization cache to perform serialization first rather than
+  // deserialization.
+  std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache";
+  DeleteCache(opt_cache_dir);
   SetConfig(&config, model_dir, true /* use_gpu */);
   config.SwitchUseFeedFetchOps(false);
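For context: as committed, DeleteCache removes at most one regular file per call (it returns as soon as remove succeeds or fails) and never calls closedir, which is sufficient here if the _opt_cache directory holds a single serialized engine file. A variant that clears every regular file and releases the directory handle could look like the sketch below; this is an illustration under POSIX assumptions, not part of the commit, and the name ClearCacheDir is hypothetical.

#include <dirent.h>

#include <cstdio>   // std::remove
#include <cstring>  // std::strcmp
#include <string>

// Remove every regular file directly inside `path`; subdirectories and the
// directory itself are left alone. Returns 0 on success (including the case
// of a missing directory), non-zero if any file could not be removed.
int ClearCacheDir(const std::string& path) {
  DIR* dir = opendir(path.c_str());
  if (dir == NULL) return 0;  // cache dir absent: nothing to clear
  struct dirent* entry;
  int rc = 0;
  while ((entry = readdir(dir)) != NULL) {
    if (std::strcmp(entry->d_name, ".") == 0 ||
        std::strcmp(entry->d_name, "..") == 0) {
      continue;
    }
    if (entry->d_type == DT_REG) {  // same check as the test's literal 8
      std::string file_rm = path + "/" + entry->d_name;
      if (std::remove(file_rm.c_str()) != 0) rc = -1;
    }
  }
  closedir(dir);
  return rc;
}

One caveat applies to either version: readdir's d_type field is not populated on every filesystem and may come back as DT_UNKNOWN, in which case a stat() fallback is needed to identify regular files reliably.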