提交 fecc52a4 编写于 作者: D DannyIsFunny

add GetParamNames() test=develop

上级 37fdb393
...@@ -150,6 +150,11 @@ std::vector<std::string> Predictor::GetInputNames() { return input_names_; } ...@@ -150,6 +150,11 @@ std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
// get outputnames // get outputnames
std::vector<std::string> Predictor::GetOutputNames() { return output_names_; } std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// get param names
std::vector<std::string> Predictor::GetParamNames() {
return exec_scope_->LocalVarNames();
}
// append the names of inputs and outputs into input_names_ and output_names_ // append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() { void Predictor::PrepareFeedFetch() {
if (!program_) { if (!program_) {
......
...@@ -84,6 +84,9 @@ class LITE_API Predictor { ...@@ -84,6 +84,9 @@ class LITE_API Predictor {
// get inputnames and get outputnames. // get inputnames and get outputnames.
std::vector<std::string> GetInputNames(); std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames(); std::vector<std::string> GetOutputNames();
// get param names
std::vector<std::string> GetParamNames();
void PrepareFeedFetch(); void PrepareFeedFetch();
// Get offset-th col of fetch results. // Get offset-th col of fetch results.
...@@ -144,6 +147,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -144,6 +147,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
// get inputs names and get outputs names // get inputs names and get outputs names
std::vector<std::string> GetInputNames() override; std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override; std::vector<std::string> GetOutputNames() override;
// get param names
std::vector<std::string> GetParamNames() override;
// get tensor according to tensor's name // get tensor according to tensor's name
std::unique_ptr<const lite_api::Tensor> GetTensor( std::unique_ptr<const lite_api::Tensor> GetTensor(
......
...@@ -75,6 +75,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetInputNames() { ...@@ -75,6 +75,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames(); return raw_predictor_.GetInputNames();
} }
std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_.GetParamNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames(); return raw_predictor_.GetOutputNames();
} }
......
...@@ -173,6 +173,11 @@ std::unique_ptr<Tensor> PaddlePredictor::GetMutableTensor( ...@@ -173,6 +173,11 @@ std::unique_ptr<Tensor> PaddlePredictor::GetMutableTensor(
return nullptr; return nullptr;
} }
std::vector<std::string> PaddlePredictor::GetParamNames() {
LOG(FATAL)
<< "The GetParamNames API is only supported by CxxConfig predictor.";
}
void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir, void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir,
LiteModelType model_type, LiteModelType model_type,
bool record_info) { bool record_info) {
......
...@@ -86,6 +86,8 @@ class LITE_API PaddlePredictor { ...@@ -86,6 +86,8 @@ class LITE_API PaddlePredictor {
virtual std::vector<std::string> GetInputNames() = 0; virtual std::vector<std::string> GetInputNames() = 0;
// Get output names // Get output names
virtual std::vector<std::string> GetOutputNames() = 0; virtual std::vector<std::string> GetOutputNames() = 0;
// Get output names
virtual std::vector<std::string> GetParamNames();
// Get Input by name // Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0; virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册