提交 6c812306 编写于 作者: N nhzlx

update code for config change

test=develop
上级 5c57e150
...@@ -93,7 +93,9 @@ endif() ...@@ -93,7 +93,9 @@ endif()
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt") set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt")
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz") if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
endif()
cc_test(test_trt_models SRCS trt_models_tester.cc cc_test(test_trt_models SRCS trt_models_tester.cc
ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models
DEPS paddle_inference_tensorrt_subgraph_engine) DEPS paddle_inference_tensorrt_subgraph_engine)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle { namespace paddle {
using paddle::contrib::MixedRTConfig;
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
...@@ -32,8 +33,8 @@ NativeConfig GetConfigNative() { ...@@ -32,8 +33,8 @@ NativeConfig GetConfigNative() {
return config; return config;
} }
TensorRTConfig GetConfigTRT() { MixedRTConfig GetConfigTRT() {
TensorRTConfig config; MixedRTConfig config;
config.model_dir = FLAGS_dirname; config.model_dir = FLAGS_dirname;
config.use_gpu = true; config.use_gpu = true;
config.fraction_of_gpu_memory = 0.2; config.fraction_of_gpu_memory = 0.2;
...@@ -46,14 +47,14 @@ void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) { ...@@ -46,14 +47,14 @@ void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
NativeConfig config0 = GetConfigNative(); NativeConfig config0 = GetConfigNative();
config0.model_dir = model_dirname; config0.model_dir = model_dirname;
TensorRTConfig config1 = GetConfigTRT(); MixedRTConfig config1 = GetConfigTRT();
config1.model_dir = model_dirname; config1.model_dir = model_dirname;
config1.max_batch_size = batch_size; config1.max_batch_size = batch_size;
auto predictor0 = auto predictor0 =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0); CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
auto predictor1 = auto predictor1 =
CreatePaddlePredictor<TensorRTConfig, CreatePaddlePredictor<MixedRTConfig,
PaddleEngineKind::kAutoMixedTensorRT>(config1); PaddleEngineKind::kAutoMixedTensorRT>(config1);
// Prepare inputs // Prepare inputs
int height = 224; int height = 224;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册