未验证 提交 3ae3b864 编写于 作者: P Pei Yang 提交者: GitHub

fix trt_dynamic_shape_ernie_deserialize_test (#27290)

* fix trt_dynamic_shape_ernie_deserialize_test

* support when opt cache dir does not exist
上级 1483ea23
......@@ -480,10 +480,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif()
# disable test_trt_dynamic_shape_ernie_ser_deser temporary
#inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
# EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
# ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
endif()
......
......@@ -12,15 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <dirent.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <unistd.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
int DeleteCache(std::string path) {
DIR* dir = opendir(path.c_str());
if (dir == NULL) return 0;
struct dirent* ptr;
while ((ptr = readdir(dir)) != NULL) {
if (std::strcmp(ptr->d_name, ".") == 0 ||
std::strcmp(ptr->d_name, "..") == 0) {
continue;
} else if (ptr->d_type == 8) {
std::string file_rm = path + "/" + ptr->d_name;
return remove(file_rm.c_str());
}
}
return 0;
}
void run(const AnalysisConfig& config, std::vector<float>* out_data) {
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
......@@ -86,6 +104,11 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
void trt_ernie(bool with_fp16, std::vector<float> result) {
AnalysisConfig config;
std::string model_dir = FLAGS_infer_model;
// Delete serialization cache to perform serialization first rather than
// deserialization.
std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache";
DeleteCache(opt_cache_dir);
SetConfig(&config, model_dir, true /* use_gpu */);
config.SwitchUseFeedFetchOps(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册