diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index a79560a06dbfe97244929d58dc70ca92c0790e0e..c08a73d0da72feb4e10ac90f8e9254cb38c09a78 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -130,6 +130,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(anakin_passes_filter_);
   CP_MEMBER(anakin_ops_filter_);
 
+  // profile related.
+  CP_MEMBER(with_profile_);
+
   // Ir related.
   CP_MEMBER(enable_ir_optim_);
   CP_MEMBER(use_feed_fetch_ops_);
@@ -255,6 +258,7 @@ void AnalysisConfig::Update() {
     } else {
       pass_builder_.reset(new CpuPassStrategy);
     }
+
   } else {
     if (use_gpu()) {
       pass_builder_.reset(new GpuPassStrategy(
@@ -272,7 +276,6 @@ void AnalysisConfig::Update() {
       pass_builder()->AppendPass(pass);
     }
   }
-
   if (use_gpu() && use_cudnn_) {
 #ifdef PADDLE_WITH_CUDA
     if (!enable_ir_optim_) {
@@ -381,6 +384,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << use_mkldnn_quantizer_;
   ss << model_from_memory_;
 
+  ss << with_profile_;
+
   ss << enable_ir_optim_;
   ss << use_feed_fetch_ops_;
   ss << ir_debug_;
@@ -455,6 +460,12 @@ void AnalysisConfig::SwitchIrDebug(int x) {
   ir_debug_ = x;
   Update();
 }
+
+void AnalysisConfig::EnableProfile() {
+  with_profile_ = true;
+  Update();
+}
+
 void AnalysisConfig::EnableAnakinEngine(
     int max_batch_size, std::map<std::string, std::vector<int>> max_input_shape,
     int min_subgraph_size, AnalysisConfig::Precision precision_mode,
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index df62c1fc9a65b54c87ad638ee752344be9966aea..5cf1942cb27f5b02605eb5e0e8e4681951a826b5 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -52,8 +52,6 @@
 #include "paddle/fluid/inference/anakin/convert/op_converter.h"
 #endif
 
-DECLARE_bool(profile);
-
 namespace paddle {
 
 using inference::Singleton;
@@ -79,12 +77,14 @@ bool AnalysisPredictor::Init(
     const std::shared_ptr<framework::Scope> &parent_scope,
     const std::shared_ptr<framework::ProgramDesc> &program) {
   VLOG(3) << "Predictor::init()";
-  if (FLAGS_profile) {
-    LOG(WARNING) << "Profiler is actived, might affect the performance";
-    LOG(INFO) << "You can turn off by set gflags '-profile false'";
+  if (config_.with_profile_) {
+    LOG(WARNING) << "Profiler is activated, which might affect the performance";
     auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll
                                              : platform::ProfilerState::kCPU;
     platform::EnableProfiler(tracking_device);
+  } else {
+    LOG(INFO) << "Profiler is deactivated, and no profiling report will be "
+                 "generated.";
   }
 
   // no matter with or without MKLDNN
@@ -472,7 +472,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   // when the predictor settings are complete, we release these stores.
   argument_.PartiallyRelease();
   config_.PartiallyRelease();
-  LOG(INFO) << "== optimize end ==";
+  LOG(INFO) << "======= optimize end =======";
 }
 
 template <>
@@ -498,7 +498,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
     }
 
     if (fraction_of_gpu_memory >= 0.0f || fraction_of_gpu_memory <= 0.95f) {
-      flags.push_back("dummpy");
+      flags.push_back("dummy");
       std::string flag = "--fraction_of_gpu_memory_to_use=" +
                          std::to_string(fraction_of_gpu_memory);
       flags.push_back(flag);
@@ -576,6 +576,18 @@ std::vector<std::string> AnalysisPredictor::GetInputNames() {
   return input_names;
 }
 
+std::map<std::string, std::vector<int64_t>>
+AnalysisPredictor::GetInputTensorShape() {
+  std::map<std::string, std::vector<int64_t>> input_shapes;
+  std::vector<std::string> names = GetInputNames();
+  for (std::string name : names) {
+    auto *var = inference_program_->Block(0).FindVar(name);
+    PADDLE_ENFORCE_NOT_NULL(var, "input %s does not exist.", name);
+    input_shapes[name] = var->GetShape();
+  }
+  return input_shapes;
+}
+
 std::vector<std::string> AnalysisPredictor::GetOutputNames() {
   std::vector<std::string> output_names;
   for (auto &item : idx2fetches_) {
@@ -792,7 +804,7 @@ AnalysisPredictor::~AnalysisPredictor() {
     SaveTrtCalibToDisk();
   }
 #endif
-  if (FLAGS_profile) {
+  if (config_.with_profile_) {
     platform::DisableProfiler(platform::EventSortingKey::kTotal,
                               "./profile.log");
   }
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index 0727c7b908b81e66373c9c2a3885edb51b540018..2426e67749053fde9e5e3055b8fd5fadc6f67889 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -65,6 +65,8 @@ class AnalysisPredictor : public PaddlePredictor {
   std::unique_ptr<ZeroCopyTensor> GetOutputTensor(
       const std::string &name) override;
 
+  std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override;
+
   bool ZeroCopyRun() override;
 
   void CreateFeedFetchVar(framework::Scope *scope);
diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
index 88489078eaef0d9214030edfbe49d31e14c1b88c..59ad2c09c0f94d9657c91879956810ccfacbcb35 100644
--- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc
+++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -43,6 +43,10 @@ void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
 template <typename T>
 T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
   EAGER_GET_TENSOR;
+  PADDLE_ENFORCE_GT(
+      tensor->numel(), 0,
+      "You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
+      "function before retrieving mutable_data from input tensor.");
   switch (static_cast<int>(place)) {
     case static_cast<int>(PaddlePlace::kCPU): {
       return tensor->mutable_data<T>(platform::CPUPlace());
@@ -83,8 +87,8 @@ PaddleDType ZeroCopyTensor::type() const {
     return PaddleDType::INT64;
   } else if (type == framework::proto::VarType::INT32) {
     return PaddleDType::INT32;
-  } else {
-    LOG(ERROR) << "unknown type, only support float32 and int64 now.";
+  } else if (type == framework::proto::VarType::UINT8) {
+    return PaddleDType::UINT8;
   }
   return PaddleDType::FLOAT32;
 }
@@ -95,7 +99,7 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
   PADDLE_ENFORCE_GE(
       tensor->numel(), 0,
       "You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
-      "function before copy data from cpu.");
+      "function before copying data from cpu.");
   size_t ele_size = tensor->numel() * sizeof(T);
 
   if (place_ == PaddlePlace::kCPU) {
@@ -112,7 +116,7 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
     memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
                  data, ele_size, dev_ctx->stream());
 #else
-    PADDLE_THROW("Not compile with CUDA, should not reach here.");
CUDA, should not reach here."); + PADDLE_THROW("Not compiled with CUDA, should not reach here."); #endif } } @@ -143,9 +147,11 @@ void ZeroCopyTensor::copy_to_cpu(T *data) { template void ZeroCopyTensor::copy_from_cpu(const float *data); template void ZeroCopyTensor::copy_from_cpu(const int64_t *data); template void ZeroCopyTensor::copy_from_cpu(const int32_t *data); +template void ZeroCopyTensor::copy_from_cpu(const uint8_t *data); template void ZeroCopyTensor::copy_to_cpu(float *data); template void ZeroCopyTensor::copy_to_cpu(int64_t *data); template void ZeroCopyTensor::copy_to_cpu(int32_t *data); +template void ZeroCopyTensor::copy_to_cpu(uint8_t *data); template float *ZeroCopyTensor::data(PaddlePlace *place, int *size) const; @@ -153,9 +159,12 @@ template int64_t *ZeroCopyTensor::data(PaddlePlace *place, int *size) const; template int32_t *ZeroCopyTensor::data(PaddlePlace *place, int *size) const; +template uint8_t *ZeroCopyTensor::data(PaddlePlace *place, + int *size) const; template float *ZeroCopyTensor::mutable_data(PaddlePlace place); template int64_t *ZeroCopyTensor::mutable_data(PaddlePlace place); template int32_t *ZeroCopyTensor::mutable_data(PaddlePlace place); +template uint8_t *ZeroCopyTensor::mutable_data(PaddlePlace place); void *ZeroCopyTensor::FindTensor() const { PADDLE_ENFORCE(!name_.empty(), diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index ec8951faf9852b28f0093588100f61cb64057401..4ab1ca9588c5d11fdf33d46a74d56a2c92252f4e 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -248,6 +248,16 @@ struct AnalysisConfig { bool force_update_static_cache = false); /** Tell whether the memory optimization is activated. */ bool enable_memory_optim() const; + + /** \brief Turn on profiling report. + * + * If not turned on, no profiling report will be generateed. + */ + void EnableProfile(); + /** A boolean state telling whether the profiler is activated. + */ + bool profile_enabled() const { return with_profile_; } + void SetInValid() const { is_valid_ = false; } bool is_valid() const { return is_valid_; } @@ -316,6 +326,8 @@ struct AnalysisConfig { int cpu_math_library_num_threads_{1}; + bool with_profile_{false}; + // A runtime cache, shouldn't be transferred to others. std::string serialized_info_cache_; diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 87f40f09eb9bb552bd246cb39bbbd41abac1c9ac..8c0adfcb0688920163bd8a2f960fa5332ff206e1 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -37,6 +38,7 @@ enum PaddleDType { FLOAT32, INT64, INT32, + UINT8, // TODO(Superjomn) support more data types if needed. }; @@ -149,8 +151,8 @@ class ZeroCopyTensor { /** Get the memory in CPU or GPU with specific data type, should Reshape first * to tell the data size. - * Once can directly call this data to feed the data. - * This is for write the input tensor. + * One can directly call this data to feed the data. + * This is for writing the input tensor. 
    */
   template <typename T>
   T* mutable_data(PaddlePlace place);
@@ -220,6 +222,12 @@ class PaddlePredictor {
    */
   virtual std::vector<std::string> GetInputNames() { return {}; }
 
+  /** \brief Get input shapes of the model
+   */
+  virtual std::map<std::string, std::vector<int64_t>> GetInputTensorShape() {
+    return {};
+  }
+
   /** \brief Get output names of the model
    */
   virtual std::vector<std::string> GetOutputNames() { return {}; }
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 9de67e9ca91d937c736fa907ba1b2e8929617416..997e7f44d9d5da1675282908a4473d649ff334c2 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -269,17 +269,20 @@ download_model_and_data(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_dat
 inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)
 
 if(WITH_GPU AND TENSORRT_FOUND)
-  set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt")
+  set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt_tests_models")
   if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
-    inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
+    inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_inference_test_models.tar.gz")
   endif()
   inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models)
+          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
   inference_analysis_test(trt_resnet50_test SRCS trt_resnet50_test.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models)
+          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
   inference_analysis_test(trt_resnext_test SRCS trt_resnext_test.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models)
+          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
+  inference_analysis_test(trt_fc_prelu_test SRCS trt_fc_prelu_test.cc
+          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+          ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
 endif()
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index f502e05dce41abe2d6aaa2e4c41fd12a8f4262c0..4b751a1cc94df06aa2760416e367239263684489 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -128,6 +128,14 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
         }
         break;
       }
+      case PaddleDType::UINT8: {
+        uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
+        uint8_t *pdata_ref = static_cast<uint8_t *>(ref_out.data.data());
+        for (size_t j = 0; j < size; ++j) {
+          EXPECT_EQ(pdata_ref[j], pdata[j]);
+        }
+        break;
+      }
     }
   }
 }
@@ -172,6 +180,15 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
         }
         break;
       }
+      case PaddleDType::UINT8: {
+        uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
+        uint8_t *pdata_ref = ref_out.data<uint8_t>(&place, &ref_size);
+        EXPECT_EQ(size, ref_size);
+        for (size_t j = 0; j < size; ++j) {
+          EXPECT_EQ(pdata_ref[j], pdata[j]);
+        }
+        break;
+      }
     }
   }
 }
@@ -286,6 +303,8 @@ void ConvertPaddleTensorToZeroCopyTensor(
       ZeroCopyTensorAssignData<float>(tensor.get(), input.data);
     } else if (input.dtype == PaddleDType::INT32) {
       ZeroCopyTensorAssignData<int32_t>(tensor.get(), input.data);
+    } else if (input.dtype == PaddleDType::UINT8) {
+      ZeroCopyTensorAssignData<uint8_t>(tensor.get(), input.data);
     } else {
       LOG(ERROR) << "unsupported feed type " << input.dtype;
     }
diff --git a/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc b/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2ee75f90b441f7d13cd50908078eaf925332dde6
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc
@@ -0,0 +1,58 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+
+namespace paddle {
+namespace inference {
+
+TEST(TensorRT_fc, compare) {
+  std::string model_dir = FLAGS_infer_model + "/fc_uint8";
+  compare(model_dir, /* use_tensorrt */ true);
+  // Open it when needed.
+  // profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
+}
+
+TEST(ZeroCopyTensor, uint8) {
+  std::string model_dir = FLAGS_infer_model + "/" + "fc_uint8";
+  AnalysisConfig config;
+  config.EnableUseGpu(100, 0);
+  config.SetModel(model_dir);
+  config.SwitchUseFeedFetchOps(false);
+  config.EnableProfile();
+
+  std::vector<std::vector<PaddleTensor>> inputs_all;
+  auto predictor = CreatePaddlePredictor(config);
+  auto input_names = predictor->GetInputNames();
+  auto name2shape = predictor->GetInputTensorShape();
+
+  int batch_size = 1;
+  int length = 4;
+  int input_num = batch_size * length;
+  uint8_t *input = new uint8_t[input_num];
+  memset(input, 1, input_num * sizeof(uint8_t));
+  auto input_t = predictor->GetInputTensor(input_names[0]);
+  input_t->Reshape({batch_size, length});
+  input_t->copy_from_cpu(input);
+  input_t->type();
+
+  ASSERT_TRUE(predictor->ZeroCopyRun());
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 812fa9db1af7d404870ceb618fe7fa75426498d8..f7a590222854c275acbeb995aa62a36224ccab2e 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -337,6 +337,7 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("x") = true)
       .def("ir_optim", &AnalysisConfig::ir_optim)
       .def("enable_memory_optim", &AnalysisConfig::EnableMemoryOptim)
+      .def("enable_profile", &AnalysisConfig::EnableProfile)
       .def("set_optim_cache_dir", &AnalysisConfig::SetOptimCacheDir)
       .def("switch_use_feed_fetch_ops", &AnalysisConfig::SwitchUseFeedFetchOps,
            py::arg("x") = true)