diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index d3e1e56c4d8048dcfb88180ca872d0f790cca84a..e8d11370d118c742477d34cc73e177e9352ccd11 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -150,6 +150,11 @@ std::vector Predictor::GetInputNames() { return input_names_; } // get outputnames std::vector Predictor::GetOutputNames() { return output_names_; } +// get param names +std::vector Predictor::GetParamNames() { + return exec_scope_->LocalVarNames(); +} + // append the names of inputs and outputs into input_names_ and output_names_ void Predictor::PrepareFeedFetch() { if (!program_) { diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 0348ab8ef13d95c469f49e069dd6f5fc1f76f07b..a891bd74f966ee0fc40c68fd7ce91f670c2da7db 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -84,6 +84,9 @@ class LITE_API Predictor { // get inputnames and get outputnames. std::vector GetInputNames(); std::vector GetOutputNames(); + // get param names + std::vector GetParamNames(); + void PrepareFeedFetch(); // Get offset-th col of fetch results. @@ -144,6 +147,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { // get inputs names and get outputs names std::vector GetInputNames() override; std::vector GetOutputNames() override; + // get param names + std::vector GetParamNames() override; // get tensor according to tensor's name std::unique_ptr GetTensor( diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 24500dd04318551206a65a33d2e93ce32fcd6029..c58414b3add1330a81c0a5819e4bdfbe9fd3a01b 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -75,6 +75,10 @@ std::vector CxxPaddleApiImpl::GetInputNames() { return raw_predictor_.GetInputNames(); } +std::vector CxxPaddleApiImpl::GetParamNames() { + return raw_predictor_.GetParamNames(); +} + std::vector CxxPaddleApiImpl::GetOutputNames() { return raw_predictor_.GetOutputNames(); } diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index 663d3c4ed35e892f2dc054bee415c0dc1dcc7685..78071bed764a1fac42e9072d0ac04dfad959b2e8 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -173,6 +173,11 @@ std::unique_ptr PaddlePredictor::GetMutableTensor( return nullptr; } +std::vector PaddlePredictor::GetParamNames() { + LOG(FATAL) + << "The GetParamNames API is only supported by CxxConfig predictor."; +} + void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir, LiteModelType model_type, bool record_info) { diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index ac905560202892148b5459a2eb0ba80b107ec236..78d28f0c8315fa1ac1146e985fa775719f192e9a 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -86,6 +86,8 @@ class LITE_API PaddlePredictor { virtual std::vector GetInputNames() = 0; // Get output names virtual std::vector GetOutputNames() = 0; + // Get output names + virtual std::vector GetParamNames(); // Get Input by name virtual std::unique_ptr GetInputByName(const std::string& name) = 0;