提交 fecc52a4 编写于 作者: D DannyIsFunny

add GetParamNames() test=develop

上级 37fdb393
......@@ -150,6 +150,11 @@ std::vector<std::string> Predictor::GetInputNames() { return input_names_; }
// get outputnames
std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
// get param names
std::vector<std::string> Predictor::GetParamNames() {
return exec_scope_->LocalVarNames();
}
// append the names of inputs and outputs into input_names_ and output_names_
void Predictor::PrepareFeedFetch() {
if (!program_) {
......
......@@ -84,6 +84,9 @@ class LITE_API Predictor {
// get inputnames and get outputnames.
std::vector<std::string> GetInputNames();
std::vector<std::string> GetOutputNames();
// get param names
std::vector<std::string> GetParamNames();
void PrepareFeedFetch();
// Get offset-th col of fetch results.
......@@ -144,6 +147,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
// get inputs names and get outputs names
std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override;
// get param names
std::vector<std::string> GetParamNames() override;
// get tensor according to tensor's name
std::unique_ptr<const lite_api::Tensor> GetTensor(
......
......@@ -75,6 +75,10 @@ std::vector<std::string> CxxPaddleApiImpl::GetInputNames() {
return raw_predictor_.GetInputNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetParamNames() {
return raw_predictor_.GetParamNames();
}
std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
return raw_predictor_.GetOutputNames();
}
......
......@@ -173,6 +173,11 @@ std::unique_ptr<Tensor> PaddlePredictor::GetMutableTensor(
return nullptr;
}
std::vector<std::string> PaddlePredictor::GetParamNames() {
LOG(FATAL)
<< "The GetParamNames API is only supported by CxxConfig predictor.";
}
void PaddlePredictor::SaveOptimizedModel(const std::string &model_dir,
LiteModelType model_type,
bool record_info) {
......
......@@ -86,6 +86,8 @@ class LITE_API PaddlePredictor {
virtual std::vector<std::string> GetInputNames() = 0;
// Get output names
virtual std::vector<std::string> GetOutputNames() = 0;
// Get output names
virtual std::vector<std::string> GetParamNames();
// Get Input by name
virtual std::unique_ptr<Tensor> GetInputByName(const std::string& name) = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册