未验证 提交 86070597 编写于 作者: 石晓伟 提交者: GitHub

add test cases for flatbuffers, test=develop (#4282)

上级 a17d7be3
...@@ -43,7 +43,7 @@ void set_tensor(paddle::lite::Tensor* tensor, ...@@ -43,7 +43,7 @@ void set_tensor(paddle::lite::Tensor* tensor,
TEST(CombinedParamsDesc, Scope) { TEST(CombinedParamsDesc, Scope) {
/* --------- Save scope ---------- */ /* --------- Save scope ---------- */
Scope scope; Scope scope;
std::vector<std::string> params_name({"var_0", "var_1"}); std::vector<std::string> params_name({"var_0", "var_1", "var_2"});
// variable 0 // variable 0
Variable* var_0 = scope.Var(params_name[0]); Variable* var_0 = scope.Var(params_name[0]);
Tensor* tensor_0 = var_0->GetMutable<Tensor>(); Tensor* tensor_0 = var_0->GetMutable<Tensor>();
...@@ -52,6 +52,10 @@ TEST(CombinedParamsDesc, Scope) { ...@@ -52,6 +52,10 @@ TEST(CombinedParamsDesc, Scope) {
Variable* var_1 = scope.Var(params_name[1]); Variable* var_1 = scope.Var(params_name[1]);
Tensor* tensor_1 = var_1->GetMutable<Tensor>(); Tensor* tensor_1 = var_1->GetMutable<Tensor>();
set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1})); set_tensor<int8_t>(tensor_1, std::vector<int64_t>({10, 1}));
// variable 3
Variable* var_2 = scope.Var(params_name[2]);
Tensor* tensor_2 = var_2->GetMutable<Tensor>();
set_tensor<int16_t>(tensor_2, std::vector<int64_t>({16, 1}));
// Set combined parameters // Set combined parameters
fbs::CombinedParamsDesc combined_param; fbs::CombinedParamsDesc combined_param;
std::set<std::string> params_set(params_name.begin(), params_name.end()); std::set<std::string> params_set(params_name.begin(), params_name.end());
...@@ -71,6 +75,11 @@ TEST(CombinedParamsDesc, Scope) { ...@@ -71,6 +75,11 @@ TEST(CombinedParamsDesc, Scope) {
CHECK(var_l1); CHECK(var_l1);
const Tensor& tensor_l1 = var_l1->Get<Tensor>(); const Tensor& tensor_l1 = var_l1->Get<Tensor>();
CHECK(TensorCompareWith(*tensor_1, tensor_l1)); CHECK(TensorCompareWith(*tensor_1, tensor_l1));
// variable 2
Variable* var_l2 = scope_l.FindVar(params_name[2]);
CHECK(var_l2);
const Tensor& tensor_l2 = var_l2->Get<Tensor>();
CHECK(TensorCompareWith(*tensor_2, tensor_l2));
}; };
check_params(combined_param); check_params(combined_param);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册