diff --git a/lite/model_parser/flatbuffers/io_test.cc b/lite/model_parser/flatbuffers/io_test.cc index 19d586322e8016e6d0280e12a69d92a7b4c712c0..1fdd700358064f668c3c68bea4b6a3cecc4940c3 100644 --- a/lite/model_parser/flatbuffers/io_test.cc +++ b/lite/model_parser/flatbuffers/io_test.cc @@ -43,7 +43,7 @@ void set_tensor(paddle::lite::Tensor* tensor, TEST(CombinedParamsDesc, Scope) { /* --------- Save scope ---------- */ Scope scope; - std::vector params_name({"var_0", "var_1"}); + std::vector params_name({"var_0", "var_1", "var_2"}); // variable 0 Variable* var_0 = scope.Var(params_name[0]); Tensor* tensor_0 = var_0->GetMutable(); @@ -52,6 +52,10 @@ TEST(CombinedParamsDesc, Scope) { Variable* var_1 = scope.Var(params_name[1]); Tensor* tensor_1 = var_1->GetMutable(); set_tensor(tensor_1, std::vector({10, 1})); + // variable 3 + Variable* var_2 = scope.Var(params_name[2]); + Tensor* tensor_2 = var_2->GetMutable(); + set_tensor(tensor_2, std::vector({16, 1})); // Set combined parameters fbs::CombinedParamsDesc combined_param; std::set params_set(params_name.begin(), params_name.end()); @@ -71,6 +75,11 @@ TEST(CombinedParamsDesc, Scope) { CHECK(var_l1); const Tensor& tensor_l1 = var_l1->Get(); CHECK(TensorCompareWith(*tensor_1, tensor_l1)); + // variable 2 + Variable* var_l2 = scope_l.FindVar(params_name[2]); + CHECK(var_l2); + const Tensor& tensor_l2 = var_l2->Get(); + CHECK(TensorCompareWith(*tensor_2, tensor_l2)); }; check_params(combined_param);