Unverified commit 9cbc1eff, authored by: Pei Yang, committed by: GitHub


zerocopytensor support uint8, analysis config support profile, analysis predictor support GetInputTensorShape, test=develop (#19822)
Parent 00efd1d8
......@@ -130,6 +130,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(anakin_passes_filter_);
CP_MEMBER(anakin_ops_filter_);
// profile related.
CP_MEMBER(with_profile_);
// Ir related.
CP_MEMBER(enable_ir_optim_);
CP_MEMBER(use_feed_fetch_ops_);
......@@ -255,6 +258,7 @@ void AnalysisConfig::Update() {
} else {
pass_builder_.reset(new CpuPassStrategy);
}
} else {
if (use_gpu()) {
pass_builder_.reset(new GpuPassStrategy(
......@@ -272,7 +276,6 @@ void AnalysisConfig::Update() {
pass_builder()->AppendPass(pass);
}
}
if (use_gpu() && use_cudnn_) {
#ifdef PADDLE_WITH_CUDA
if (!enable_ir_optim_) {
......@@ -381,6 +384,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << use_mkldnn_quantizer_;
ss << model_from_memory_;
ss << with_profile_;
ss << enable_ir_optim_;
ss << use_feed_fetch_ops_;
ss << ir_debug_;
......@@ -455,6 +460,12 @@ void AnalysisConfig::SwitchIrDebug(int x) {
ir_debug_ = x;
Update();
}
void AnalysisConfig::EnableProfile() {
with_profile_ = true;
Update();
}
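How the new profiling switch is intended to be used from the C++ API — a minimal sketch, assuming the public header is installed as paddle_inference_api.h; the model directory is a placeholder:

#include "paddle_inference_api.h"  // assumed public header name

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./model_dir");   // placeholder model directory
  config.EnableProfile();           // new switch added in this commit
  auto predictor = paddle::CreatePaddlePredictor(config);
  // ... feed inputs and run; when the predictor is destroyed the profiler
  // report is dumped (see DisableProfiler(..., "./profile.log") below).
  return 0;
}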
void AnalysisConfig::EnableAnakinEngine(
int max_batch_size, std::map<std::string, std::vector<int>> max_input_shape,
int min_subgraph_size, AnalysisConfig::Precision precision_mode,
......
......@@ -52,8 +52,6 @@
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#endif
DECLARE_bool(profile);
namespace paddle {
using inference::Singleton;
......@@ -79,12 +77,14 @@ bool AnalysisPredictor::Init(
const std::shared_ptr<framework::Scope> &parent_scope,
const std::shared_ptr<framework::ProgramDesc> &program) {
VLOG(3) << "Predictor::init()";
if (FLAGS_profile) {
LOG(WARNING) << "Profiler is actived, might affect the performance";
LOG(INFO) << "You can turn off by set gflags '-profile false'";
if (config_.with_profile_) {
LOG(WARNING) << "Profiler is activated, which might affect the performance";
auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll
: platform::ProfilerState::kCPU;
platform::EnableProfiler(tracking_device);
} else {
LOG(INFO) << "Profiler is deactivated, and no profiling report will be "
"generated.";
}
// no matter with or without MKLDNN
......@@ -472,7 +472,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// when the predictor settings are complete, we release these stores.
argument_.PartiallyRelease();
config_.PartiallyRelease();
LOG(INFO) << "== optimize end ==";
LOG(INFO) << "======= optimize end =======";
}
template <>
......@@ -498,7 +498,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
}
if (fraction_of_gpu_memory >= 0.0f || fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
flags.push_back("dummy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(fraction_of_gpu_memory);
flags.push_back(flag);
......@@ -576,6 +576,18 @@ std::vector<std::string> AnalysisPredictor::GetInputNames() {
return input_names;
}
std::map<std::string, std::vector<int64_t>>
AnalysisPredictor::GetInputTensorShape() {
std::map<std::string, std::vector<int64_t>> input_shapes;
std::vector<std::string> names = GetInputNames();
for (std::string name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "input %s does not exist.", name);
input_shapes[name] = var->GetShape();
}
return input_shapes;
}
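A hedged sketch of how the new GetInputTensorShape() might be consumed together with the existing zero-copy input API; the header name and helper function are assumptions, not part of this commit:

#include <cstdint>
#include <vector>
#include "paddle_inference_api.h"  // assumed public header name

// Query the model's declared input shapes via the new GetInputTensorShape()
// and use them to reshape the zero-copy input tensors. Dynamic (-1)
// dimensions, typically the batch, are filled in by the caller.
void PrepareInputs(paddle::PaddlePredictor *predictor, int batch_size) {
  auto shapes = predictor->GetInputTensorShape();  // name -> int64_t shape
  for (auto &name : predictor->GetInputNames()) {
    std::vector<int64_t> shape64 = shapes[name];
    std::vector<int> shape(shape64.begin(), shape64.end());
    if (!shape.empty() && shape[0] < 0) shape[0] = batch_size;
    auto tensor = predictor->GetInputTensor(name);
    tensor->Reshape(shape);
    // ... then feed data with tensor->copy_from_cpu(...) or mutable_data<T>().
  }
}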
std::vector<std::string> AnalysisPredictor::GetOutputNames() {
std::vector<std::string> output_names;
for (auto &item : idx2fetches_) {
......@@ -792,7 +804,7 @@ AnalysisPredictor::~AnalysisPredictor() {
SaveTrtCalibToDisk();
}
#endif
if (FLAGS_profile) {
if (config_.with_profile_) {
platform::DisableProfiler(platform::EventSortingKey::kTotal,
"./profile.log");
}
......
......@@ -65,6 +65,8 @@ class AnalysisPredictor : public PaddlePredictor {
std::unique_ptr<ZeroCopyTensor> GetOutputTensor(
const std::string &name) override;
std::map<std::string, std::vector<int64_t>> GetInputTensorShape() override;
bool ZeroCopyRun() override;
void CreateFeedFetchVar(framework::Scope *scope);
......
......@@ -43,6 +43,10 @@ void ZeroCopyTensor::Reshape(const std::vector<int> &shape) {
template <typename T>
T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
EAGER_GET_TENSOR;
PADDLE_ENFORCE_GT(
tensor->numel(), 0,
"You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
"function before retrieving mutable_data from input tensor.");
switch (static_cast<int>(place)) {
case static_cast<int>(PaddlePlace::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
......@@ -83,8 +87,8 @@ PaddleDType ZeroCopyTensor::type() const {
return PaddleDType::INT64;
} else if (type == framework::proto::VarType::INT32) {
return PaddleDType::INT32;
} else {
LOG(ERROR) << "unknown type, only support float32 and int64 now.";
} else if (type == framework::proto::VarType::UINT8) {
return PaddleDType::UINT8;
}
return PaddleDType::FLOAT32;
}
......@@ -95,7 +99,7 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
PADDLE_ENFORCE_GE(
tensor->numel(), 0,
"You should call ZeroCopyTensor::Reshape(const std::vector<int> &shape)"
"function before copy data from cpu.");
"function before copying data from cpu.");
size_t ele_size = tensor->numel() * sizeof(T);
if (place_ == PaddlePlace::kCPU) {
......@@ -112,7 +116,7 @@ void ZeroCopyTensor::copy_from_cpu(const T *data) {
memory::Copy(gpu_place, static_cast<void *>(t_data), platform::CPUPlace(),
data, ele_size, dev_ctx->stream());
#else
PADDLE_THROW("Not compile with CUDA, should not reach here.");
PADDLE_THROW("Not compiled with CUDA, should not reach here.");
#endif
}
}
......@@ -143,9 +147,11 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
template void ZeroCopyTensor::copy_from_cpu<float>(const float *data);
template void ZeroCopyTensor::copy_from_cpu<int64_t>(const int64_t *data);
template void ZeroCopyTensor::copy_from_cpu<int32_t>(const int32_t *data);
template void ZeroCopyTensor::copy_from_cpu<uint8_t>(const uint8_t *data);
template void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
template void ZeroCopyTensor::copy_to_cpu<uint8_t>(uint8_t *data);
template float *ZeroCopyTensor::data<float>(PaddlePlace *place,
int *size) const;
......@@ -153,9 +159,12 @@ template int64_t *ZeroCopyTensor::data<int64_t>(PaddlePlace *place,
int *size) const;
template int32_t *ZeroCopyTensor::data<int32_t>(PaddlePlace *place,
int *size) const;
template uint8_t *ZeroCopyTensor::data<uint8_t>(PaddlePlace *place,
int *size) const;
template float *ZeroCopyTensor::mutable_data<float>(PaddlePlace place);
template int64_t *ZeroCopyTensor::mutable_data<int64_t>(PaddlePlace place);
template int32_t *ZeroCopyTensor::mutable_data<int32_t>(PaddlePlace place);
template uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(PaddlePlace place);
void *ZeroCopyTensor::FindTensor() const {
PADDLE_ENFORCE(!name_.empty(),
......
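With the uint8_t instantiations above, uint8 data can be fed through the zero-copy path. A minimal sketch, mirroring the TEST(ZeroCopyTensor, uint8) added later in this diff; the header name and wrapper function are assumptions:

#include <cstdint>
#include <vector>
#include "paddle_inference_api.h"  // assumed public header name

// Feed a small uint8 input through a zero-copy tensor and run inference.
void RunUint8(paddle::PaddlePredictor *predictor) {
  const int batch_size = 1, length = 4;
  std::vector<uint8_t> input(batch_size * length, 1);
  auto names = predictor->GetInputNames();
  auto input_t = predictor->GetInputTensor(names[0]);
  input_t->Reshape({batch_size, length});
  input_t->copy_from_cpu(input.data());  // uses the new uint8_t instantiation
  predictor->ZeroCopyRun();
  // input_t->type() now reports PaddleDType::UINT8 for uint8 tensors.
}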
......@@ -248,6 +248,16 @@ struct AnalysisConfig {
bool force_update_static_cache = false);
/** Tell whether the memory optimization is activated. */
bool enable_memory_optim() const;
/** \brief Turn on profiling report.
*
* If not turned on, no profiling report will be generated.
*/
void EnableProfile();
/** A boolean state telling whether the profiler is activated.
*/
bool profile_enabled() const { return with_profile_; }
void SetInValid() const { is_valid_ = false; }
bool is_valid() const { return is_valid_; }
......@@ -316,6 +326,8 @@ struct AnalysisConfig {
int cpu_math_library_num_threads_{1};
bool with_profile_{false};
// A runtime cache, shouldn't be transferred to others.
std::string serialized_info_cache_;
......
......@@ -23,6 +23,7 @@
*/
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>
......@@ -37,6 +38,7 @@ enum PaddleDType {
FLOAT32,
INT64,
INT32,
UINT8,
// TODO(Superjomn) support more data types if needed.
};
......@@ -149,8 +151,8 @@ class ZeroCopyTensor {
/** Get the memory in CPU or GPU with specific data type, should Reshape first
* to tell the data size.
* Once can directly call this data to feed the data.
* This is for write the input tensor.
* One can directly call this data to feed the data.
* This is for writing the input tensor.
*/
template <typename T>
T* mutable_data(PaddlePlace place);
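A hedged sketch of the write path documented above: Reshape() first so the tensor has a size, then fill the buffer returned by mutable_data<T>(); the header name and helper function are assumptions:

#include <cstdint>
#include <string>
#include "paddle_inference_api.h"  // assumed public header name

// Write directly into an input tensor's buffer (uint8_t is made available
// by the new template instantiations in this commit).
void FillUint8Input(paddle::PaddlePredictor *predictor,
                    const std::string &name) {
  auto tensor = predictor->GetInputTensor(name);
  tensor->Reshape({1, 4});  // illustrative shape
  uint8_t *buf = tensor->mutable_data<uint8_t>(paddle::PaddlePlace::kCPU);
  for (int i = 0; i < 4; ++i) buf[i] = 1;
}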
......@@ -220,6 +222,12 @@ class PaddlePredictor {
*/
virtual std::vector<std::string> GetInputNames() { return {}; }
/** \brief Get input shapes of the model
*/
virtual std::map<std::string, std::vector<int64_t>> GetInputTensorShape() {
return {};
}
/** \brief Get output names of the model
*/
virtual std::vector<std::string> GetOutputNames() { return {}; }
......
......@@ -269,17 +269,20 @@ download_model_and_data(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_dat
inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert_tester.cc)
if(WITH_GPU AND TENSORRT_FOUND)
set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt")
set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt_tests_models")
if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_inference_test_models.tar.gz")
endif()
inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models)
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_resnet50_test SRCS trt_resnet50_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models)
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_resnext_test SRCS trt_resnext_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models)
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_fc_prelu_test SRCS trt_fc_prelu_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
endif()
......@@ -128,6 +128,14 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
}
break;
}
case PaddleDType::UINT8: {
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
uint8_t *pdata_ref = static_cast<uint8_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
}
}
}
......@@ -172,6 +180,15 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
}
break;
}
case PaddleDType::UINT8: {
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
uint8_t *pdata_ref = ref_out.data<uint8_t>(&place, &ref_size);
EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
}
}
}
......@@ -286,6 +303,8 @@ void ConvertPaddleTensorToZeroCopyTensor(
ZeroCopyTensorAssignData<float>(tensor.get(), input.data);
} else if (input.dtype == PaddleDType::INT32) {
ZeroCopyTensorAssignData<int32_t>(tensor.get(), input.data);
} else if (input.dtype == PaddleDType::UINT8) {
ZeroCopyTensorAssignData<uint8_t>(tensor.get(), input.data);
} else {
LOG(ERROR) << "unsupported feed type " << input.dtype;
}
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(TensorRT_fc, compare) {
std::string model_dir = FLAGS_infer_model + "/fc_uint8";
compare(model_dir, /* use_tensorrt */ true);
// Uncomment it when needed.
// profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
}
TEST(ZeroCopyTensor, uint8) {
std::string model_dir = FLAGS_infer_model + "/" + "fc_uint8";
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
config.EnableProfile();
std::vector<std::vector<PaddleTensor>> inputs_all;
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
auto name2shape = predictor->GetInputTensorShape();
int batch_size = 1;
int length = 4;
int input_num = batch_size * length;
uint8_t *input = new uint8_t[input_num];
memset(input, 1, input_num * sizeof(uint8_t));
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({batch_size, length});
input_t->copy_from_cpu(input);
input_t->type();
ASSERT_TRUE(predictor->ZeroCopyRun());
}
} // namespace inference
} // namespace paddle
......@@ -337,6 +337,7 @@ void BindAnalysisConfig(py::module *m) {
py::arg("x") = true)
.def("ir_optim", &AnalysisConfig::ir_optim)
.def("enable_memory_optim", &AnalysisConfig::EnableMemoryOptim)
.def("enable_profile", &AnalysisConfig::EnableProfile)
.def("set_optim_cache_dir", &AnalysisConfig::SetOptimCacheDir)
.def("switch_use_feed_fetch_ops", &AnalysisConfig::SwitchUseFeedFetchOps,
py::arg("x") = true)
......