attribute_test.cc 5.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/framework/attribute.h"

17 18 19 20
#include <string>
#include <vector>

#include "gtest/gtest.h"
21
#include "paddle/fluid/framework/program_desc.h"
22
#include "paddle/fluid/framework/var_desc.h"
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
#include "paddle/utils/any.h"

TEST(Attribute, GetAttrValueToAny) {
  paddle::framework::Attribute x_int(100);
  auto rlt_int = paddle::framework::GetAttrValue(x_int);
  EXPECT_EQ(paddle::any_cast<int>(rlt_int), 100);

  float float_value = 3.14;
  paddle::framework::Attribute x_float(float_value);
  auto rlt_float = paddle::framework::GetAttrValue(x_float);
  EXPECT_NEAR(paddle::any_cast<float>(rlt_float), 3.14, 1e-6);

  std::string str_value("test");
  paddle::framework::Attribute x_str(str_value);
  auto rlt_str = paddle::framework::GetAttrValue(x_str);
  EXPECT_EQ(paddle::any_cast<std::string>(rlt_str), "test");

  std::vector<int> vec_int_var(2, 100);
  paddle::framework::Attribute x_vec_int = vec_int_var;
  auto rlt_vec_int = paddle::framework::GetAttrValue(x_vec_int);
  auto vec_int = paddle::any_cast<std::vector<int>>(rlt_vec_int);
  EXPECT_EQ(vec_int.size(), 2UL);
  EXPECT_EQ(vec_int[0], 100);
  EXPECT_EQ(vec_int[1], 100);

  std::vector<float> vec_float_var(2, 3.14);
  paddle::framework::Attribute x_vec_float = vec_float_var;
  auto rlt_vec_float = paddle::framework::GetAttrValue(x_vec_float);
  auto vec_float = paddle::any_cast<std::vector<float>>(rlt_vec_float);
  EXPECT_EQ(vec_float.size(), 2UL);
  EXPECT_NEAR(vec_float[0], 3.14, 1e-6);
  EXPECT_NEAR(vec_float[1], 3.14, 1e-6);

  std::vector<std::string> vec_str_var(2, "test");
  paddle::framework::Attribute x_vec_str = vec_str_var;
  auto rlt_vec_str = paddle::framework::GetAttrValue(x_vec_str);
  auto vec_str = paddle::any_cast<std::vector<std::string>>(rlt_vec_str);
  EXPECT_EQ(vec_str.size(), 2UL);
  EXPECT_EQ(vec_str[0], "test");
  EXPECT_EQ(vec_str[1], "test");

  paddle::framework::Attribute x_bool(true);
  auto rlt_bool = paddle::framework::GetAttrValue(x_bool);
  EXPECT_EQ(paddle::any_cast<bool>(rlt_bool), true);

  std::vector<bool> vec_bool_var(2, true);
  paddle::framework::Attribute x_vec_bool = vec_bool_var;
  auto rlt_vec_bool = paddle::framework::GetAttrValue(x_vec_bool);
  auto vec_bool = paddle::any_cast<std::vector<bool>>(rlt_vec_bool);
  EXPECT_EQ(vec_bool.size(), 2UL);
  EXPECT_EQ(vec_bool[0], true);
  EXPECT_EQ(vec_bool[1], true);

76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
  paddle::framework::VarDesc var_desc("axis");
  paddle::framework::Attribute var_attr(&var_desc);
  auto rlt_var_attr = paddle::framework::GetAttrValue(var_attr);
  auto var_desc_ptr =
      paddle::any_cast<paddle::framework::VarDesc*>(rlt_var_attr);
  EXPECT_NE(var_desc_ptr, nullptr);
  EXPECT_EQ(var_desc_ptr->Name(), var_desc.Name());

  paddle::framework::VarDesc var2_desc("prob");
  std::vector<paddle::framework::VarDesc*> vars_desc{&var_desc, &var2_desc};
  paddle::framework::Attribute vars_attr(vars_desc);

  auto rlt_vars_attr = paddle::framework::GetAttrValue(vars_attr);
  auto rlt_vars_desc =
      paddle::any_cast<std::vector<paddle::framework::VarDesc*>>(rlt_vars_attr);
  EXPECT_EQ(rlt_vars_desc.size(), vars_desc.size());
  EXPECT_EQ(rlt_vars_desc[0]->Name(), vars_desc[0]->Name());
  EXPECT_EQ(rlt_vars_desc[1]->Name(), vars_desc[1]->Name());

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
  paddle::framework::ProgramDesc prog;
  paddle::framework::proto::BlockDesc proto_block;
  paddle::framework::BlockDesc block_desc(&prog, &proto_block);
  paddle::framework::Attribute x_block_desc(&block_desc);
  auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc);
  auto block_desc_ptr =
      paddle::any_cast<paddle::framework::BlockDesc*>(rlt_block_desc);
  EXPECT_NE(block_desc_ptr, nullptr);

  std::vector<paddle::framework::BlockDesc*> vec_block_desc_var;
  vec_block_desc_var.emplace_back(&block_desc);
  paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var);
  auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc);
  auto vec_block_desc =
      paddle::any_cast<std::vector<paddle::framework::BlockDesc*>>(
          rlt_vec_block_desc);
  EXPECT_EQ(vec_block_desc.size(), 1UL);
  EXPECT_NE(vec_block_desc[0], nullptr);

  int64_t int64_value = 100;
  paddle::framework::Attribute x_int64(int64_value);
  auto rlt_int64 = paddle::framework::GetAttrValue(x_int64);
  EXPECT_EQ(paddle::any_cast<int64_t>(rlt_int64), 100);

  std::vector<int64_t> vec_int64_var(2, 100);
  paddle::framework::Attribute x_vec_int64 = vec_int64_var;
  auto rlt_vec_int64 = paddle::framework::GetAttrValue(x_vec_int64);
  auto vec_int64 = paddle::any_cast<std::vector<int64_t>>(rlt_vec_int64);
  EXPECT_EQ(vec_int64.size(), 2UL);
  EXPECT_EQ(vec_int64[0], 100);
  EXPECT_EQ(vec_int64[1], 100);

  std::vector<double> vec_double_var(2, 3.14);
  paddle::framework::Attribute x_vec_double = vec_double_var;
  auto rlt_vec_double = paddle::framework::GetAttrValue(x_vec_double);
  auto vec_double = paddle::any_cast<std::vector<double>>(rlt_vec_double);
  EXPECT_EQ(vec_double.size(), 2UL);
  EXPECT_NEAR(vec_double[0], 3.14, 1e-6);
  EXPECT_NEAR(vec_double[1], 3.14, 1e-6);
}