attribute_test.cc 16.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/framework/attribute.h"

17 18 19 20
#include <string>
#include <vector>

#include "gtest/gtest.h"
21
#include "paddle/fluid/framework/program_desc.h"
22
#include "paddle/fluid/framework/var_desc.h"
23
#include "paddle/phi/common/scalar.h"
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
#include "paddle/utils/any.h"

// Wraps one value of every Attribute alternative, converts it with
// GetAttrValue(), and checks that paddle::any_cast recovers the same value
// with the expected C++ type. Container sizes are checked with ASSERT_* so a
// wrong-sized result aborts before the element accesses that follow; pointers
// are checked with ASSERT_* before they are dereferenced.
TEST(Attribute, GetAttrValueToAny) {
  paddle::framework::Attribute x_int(100);
  auto rlt_int = paddle::framework::GetAttrValue(x_int);
  EXPECT_EQ(paddle::any_cast<int>(rlt_int), 100);

  float float_value = 3.14f;
  paddle::framework::Attribute x_float(float_value);
  auto rlt_float = paddle::framework::GetAttrValue(x_float);
  EXPECT_NEAR(paddle::any_cast<float>(rlt_float), 3.14, 1e-6);

  std::string str_value("test");
  paddle::framework::Attribute x_str(str_value);
  auto rlt_str = paddle::framework::GetAttrValue(x_str);
  EXPECT_EQ(paddle::any_cast<std::string>(rlt_str), "test");

  std::vector<int> vec_int_var(2, 100);
  paddle::framework::Attribute x_vec_int = vec_int_var;
  auto rlt_vec_int = paddle::framework::GetAttrValue(x_vec_int);
  auto vec_int = paddle::any_cast<std::vector<int>>(rlt_vec_int);
  ASSERT_EQ(vec_int.size(), 2UL);
  EXPECT_EQ(vec_int[0], 100);
  EXPECT_EQ(vec_int[1], 100);

  std::vector<float> vec_float_var(2, 3.14f);
  paddle::framework::Attribute x_vec_float = vec_float_var;
  auto rlt_vec_float = paddle::framework::GetAttrValue(x_vec_float);
  auto vec_float = paddle::any_cast<std::vector<float>>(rlt_vec_float);
  ASSERT_EQ(vec_float.size(), 2UL);
  EXPECT_NEAR(vec_float[0], 3.14, 1e-6);
  EXPECT_NEAR(vec_float[1], 3.14, 1e-6);

  std::vector<std::string> vec_str_var(2, "test");
  paddle::framework::Attribute x_vec_str = vec_str_var;
  auto rlt_vec_str = paddle::framework::GetAttrValue(x_vec_str);
  auto vec_str = paddle::any_cast<std::vector<std::string>>(rlt_vec_str);
  ASSERT_EQ(vec_str.size(), 2UL);
  EXPECT_EQ(vec_str[0], "test");
  EXPECT_EQ(vec_str[1], "test");

  paddle::framework::Attribute x_bool(true);
  auto rlt_bool = paddle::framework::GetAttrValue(x_bool);
  EXPECT_TRUE(paddle::any_cast<bool>(rlt_bool));

  std::vector<bool> vec_bool_var(2, true);
  paddle::framework::Attribute x_vec_bool = vec_bool_var;
  auto rlt_vec_bool = paddle::framework::GetAttrValue(x_vec_bool);
  auto vec_bool = paddle::any_cast<std::vector<bool>>(rlt_vec_bool);
  ASSERT_EQ(vec_bool.size(), 2UL);
  EXPECT_TRUE(vec_bool[0]);
  EXPECT_TRUE(vec_bool[1]);

  paddle::framework::VarDesc var_desc("axis");
  paddle::framework::Attribute var_attr(&var_desc);
  auto rlt_var_attr = paddle::framework::GetAttrValue(var_attr);
  auto var_desc_ptr =
      paddle::any_cast<paddle::framework::VarDesc *>(rlt_var_attr);
  // Fatal check: the pointer is dereferenced on the next line.
  ASSERT_NE(var_desc_ptr, nullptr);
  EXPECT_EQ(var_desc_ptr->Name(), var_desc.Name());

  paddle::framework::VarDesc var2_desc("prob");
  std::vector<paddle::framework::VarDesc *> vars_desc{&var_desc, &var2_desc};
  paddle::framework::Attribute vars_attr(vars_desc);

  auto rlt_vars_attr = paddle::framework::GetAttrValue(vars_attr);
  auto rlt_vars_desc =
      paddle::any_cast<std::vector<paddle::framework::VarDesc *>>(
          rlt_vars_attr);
  // Fatal check: elements [0] and [1] are indexed below.
  ASSERT_EQ(rlt_vars_desc.size(), vars_desc.size());
  EXPECT_EQ(rlt_vars_desc[0]->Name(), vars_desc[0]->Name());
  EXPECT_EQ(rlt_vars_desc[1]->Name(), vars_desc[1]->Name());

  paddle::framework::ProgramDesc prog;
  paddle::framework::proto::BlockDesc proto_block;
  paddle::framework::BlockDesc block_desc(&prog, &proto_block);
  paddle::framework::Attribute x_block_desc(&block_desc);
  auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc);
  auto block_desc_ptr =
      paddle::any_cast<paddle::framework::BlockDesc *>(rlt_block_desc);
  // The attribute stores the pointer itself, so the same address must come
  // back (which also implies it is non-null).
  EXPECT_EQ(block_desc_ptr, &block_desc);

  std::vector<paddle::framework::BlockDesc *> vec_block_desc_var;
  vec_block_desc_var.emplace_back(&block_desc);
  paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var);
  auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc);
  auto vec_block_desc =
      paddle::any_cast<std::vector<paddle::framework::BlockDesc *>>(
          rlt_vec_block_desc);
  ASSERT_EQ(vec_block_desc.size(), 1UL);
  EXPECT_EQ(vec_block_desc[0], &block_desc);

  int64_t int64_value = 100;
  paddle::framework::Attribute x_int64(int64_value);
  auto rlt_int64 = paddle::framework::GetAttrValue(x_int64);
  EXPECT_EQ(paddle::any_cast<int64_t>(rlt_int64), 100);

  std::vector<int64_t> vec_int64_var(2, 100);
  paddle::framework::Attribute x_vec_int64 = vec_int64_var;
  auto rlt_vec_int64 = paddle::framework::GetAttrValue(x_vec_int64);
  auto vec_int64 = paddle::any_cast<std::vector<int64_t>>(rlt_vec_int64);
  ASSERT_EQ(vec_int64.size(), 2UL);
  EXPECT_EQ(vec_int64[0], 100);
  EXPECT_EQ(vec_int64[1], 100);

  std::vector<double> vec_double_var(2, 3.14);
  paddle::framework::Attribute x_vec_double = vec_double_var;
  auto rlt_vec_double = paddle::framework::GetAttrValue(x_vec_double);
  auto vec_double = paddle::any_cast<std::vector<double>>(rlt_vec_double);
  ASSERT_EQ(vec_double.size(), 2UL);
  EXPECT_NEAR(vec_double[0], 3.14, 1e-6);
  EXPECT_NEAR(vec_double[1], 3.14, 1e-6);

  double x_double_val = 42.1;
  paddle::framework::Attribute x_double(x_double_val);
  ASSERT_EQ(paddle::framework::AttrTypeID(x_double),
            paddle::framework::proto::FLOAT64);
  EXPECT_NEAR(
      paddle::any_cast<double>(paddle::framework::GetAttrValue(x_double)),
      42.1,
      1e-6);

  paddle::framework::Attribute x_scalar = paddle::experimental::Scalar(42.1);
  ASSERT_EQ(paddle::framework::AttrTypeID(x_scalar),
            paddle::framework::proto::SCALAR);
  EXPECT_EQ(paddle::any_cast<paddle::experimental::Scalar>(
                paddle::framework::GetAttrValue(x_scalar)),
            paddle::experimental::Scalar(42.1));

  std::vector<paddle::experimental::Scalar> scalars =
      paddle::experimental::WrapAsScalars(std::vector<int64_t>{1, 2, 3});
  paddle::framework::Attribute x_scalars(scalars);
  ASSERT_EQ(paddle::framework::AttrTypeID(x_scalars),
            paddle::framework::proto::SCALARS);
  auto x_extracted =
      paddle::any_cast<std::vector<paddle::experimental::Scalar>>(
          paddle::framework::GetAttrValue(x_scalars));
  ASSERT_EQ(x_extracted.size(), 3UL);
  EXPECT_EQ(x_extracted.at(0), scalars.at(0));
  EXPECT_EQ(x_extracted.at(1), scalars.at(1));
  EXPECT_EQ(x_extracted.at(2), scalars.at(2));
}

// Builds a FLOAT64 proto attribute and converts it to an Attribute. Checks
// both the resulting type id and that the stored value survives the trip.
TEST(Attribute, ProtoAttrToAttribute_double) {
  paddle::framework::proto::OpDesc::Attr proto_attr_double;
  proto_attr_double.set_name("anon");
  proto_attr_double.set_type(paddle::framework::proto::FLOAT64);
  proto_attr_double.set_float64(42.1);
  paddle::framework::Attribute attr_double =
      paddle::framework::GetAttrValue(proto_attr_double);
  ASSERT_EQ(paddle::framework::AttrTypeID(attr_double),
            paddle::framework::proto::FLOAT64);
  EXPECT_NEAR(
      paddle::any_cast<double>(paddle::framework::GetAttrValue(attr_double)),
      42.1,
      1e-6);
}

// For a Scalar of each data type listed below, serializes it with
// MakeScalarProto() into a SCALAR-typed proto attribute, converts that back
// into an Attribute, and checks that the Attribute reports the SCALAR type id.
TEST(Attribute, ProtoAttrToAttribute_scalar) {
  using paddle::experimental::Scalar;

  paddle::framework::proto::OpDesc::Attr proto_attr_scalar;
  proto_attr_scalar.set_name("anon");
  proto_attr_scalar.set_type(paddle::framework::proto::SCALAR);

  const std::vector<Scalar> scalars = {
      Scalar(static_cast<bool>(true)),

      Scalar(static_cast<int8_t>(42.1)),
      Scalar(static_cast<int16_t>(42.1)),
      Scalar(static_cast<int32_t>(42.1)),
      Scalar(static_cast<int64_t>(42.1)),

      Scalar(static_cast<uint8_t>(42.1)),
      Scalar(static_cast<uint16_t>(42.1)),
      Scalar(static_cast<uint32_t>(42.1)),
      Scalar(static_cast<uint64_t>(42.1)),

      Scalar(static_cast<phi::float16>(42.1)),
      Scalar(static_cast<phi::bfloat16>(42.1)),
      Scalar(static_cast<float>(42.1)),
      Scalar(static_cast<double>(42.1)),

      Scalar(std::complex<float>(42.1, 42.1)),
      Scalar(std::complex<double>(42.1, 42.1)),
  };

  for (size_t i = 0; i < scalars.size(); ++i) {
    // mutable_scalar() keeps ownership of the sub-message inside the proto,
    // so no manual heap allocation is needed.
    *proto_attr_scalar.mutable_scalar() =
        paddle::framework::MakeScalarProto(scalars[i]);
    ASSERT_EQ(paddle::framework::AttrTypeID(
                  paddle::framework::GetAttrValue(proto_attr_scalar)),
              paddle::framework::proto::SCALAR)
        << "scalar index " << i;
  }
}

// Serializes ten integer Scalars into the repeated `scalars` field of a
// SCALARS-typed proto attribute and converts it back into an Attribute.
// Checks the resulting type id and that no element was lost.
TEST(Attribute, ProtoAttrToAttribute_scalars) {
  paddle::framework::proto::OpDesc::Attr proto_attr_scalars;
  proto_attr_scalars.set_name("anon");
  proto_attr_scalars.set_type(paddle::framework::proto::SCALARS);

  std::vector<paddle::experimental::Scalar> scalars;
  scalars.reserve(10);
  for (int i = 0; i < 10; i++) {
    scalars.push_back(paddle::experimental::Scalar(i));
  }
  std::vector<paddle::framework::proto::Scalar> proto_scalars;
  proto_scalars.reserve(scalars.size());
  for (const auto &item : scalars) {
    proto_scalars.emplace_back(paddle::framework::MakeScalarProto(item));
  }
  paddle::framework::VectorToRepeated(proto_scalars,
                                      proto_attr_scalars.mutable_scalars());

  const paddle::framework::Attribute attr =
      paddle::framework::GetAttrValue(proto_attr_scalars);
  ASSERT_EQ(paddle::framework::AttrTypeID(attr),
            paddle::framework::proto::SCALARS);
  // Every element of the repeated field must survive the conversion.
  const auto extracted =
      paddle::any_cast<std::vector<paddle::experimental::Scalar>>(
          paddle::framework::GetAttrValue(attr));
  EXPECT_EQ(extracted.size(), scalars.size());
}

313 314
// Wraps scalar-like values (bool, int32, int64, float, double, and an existing
// Scalar) in an Attribute. Checks that MakeScalarFromAttribute() yields the
// same Scalar as constructing one directly from the value. The checks are
// independent, so EXPECT_* lets every mismatch be reported in one run.
TEST(Attribute, MakeScalarFromAttribute) {
  using paddle::experimental::Scalar;
  using paddle::framework::Attribute;
  using paddle::framework::MakeScalarFromAttribute;

  auto s_bool = true;
  auto s_int32 = static_cast<int32_t>(42.1);
  auto s_int64 = static_cast<int64_t>(42.1);

  auto s_float = static_cast<float>(42.1);
  auto s_double = static_cast<double>(42.1);

  auto s_scalar = Scalar(42.1);

  EXPECT_EQ(MakeScalarFromAttribute(Attribute(s_bool)), Scalar(s_bool));
  EXPECT_EQ(MakeScalarFromAttribute(Attribute(s_int32)), Scalar(s_int32));
  EXPECT_EQ(MakeScalarFromAttribute(Attribute(s_int64)), Scalar(s_int64));
  EXPECT_EQ(MakeScalarFromAttribute(Attribute(s_float)), Scalar(s_float));
  EXPECT_EQ(MakeScalarFromAttribute(Attribute(s_double)), Scalar(s_double));
  EXPECT_EQ(MakeScalarFromAttribute(Attribute(s_scalar)), s_scalar);
}

338 339
// Wraps vectors of scalar-like values in an Attribute. Checks that
// MakeScalarsFromAttribute() returns one Scalar per input element, each equal
// to a Scalar constructed directly from the corresponding value.
TEST(Attribute, MakeScalarsFromAttribute) {
  using paddle::experimental::Scalar;
  using paddle::framework::Attribute;
  using paddle::framework::MakeScalarsFromAttribute;

  std::vector<bool> v_bool(4, true);
  std::vector<int> v_int(4, 42);
  std::vector<int64_t> v_int64(4, 42);
  std::vector<float> v_float(4, 42.1f);
  std::vector<double> v_double(4, 42.1);
  std::vector<Scalar> v_scalar(4, Scalar(std::complex<float>(42.1, 42.1)));

  // Converts `attr` and compares the result against `values` element by
  // element. The size is checked first (fatally, within this lambda) so a
  // short or empty result can never be indexed out of bounds.
  auto check_all = [](const Attribute &attr, const auto &values) {
    const auto scalars = MakeScalarsFromAttribute(attr);
    ASSERT_EQ(scalars.size(), values.size());
    for (size_t i = 0; i < values.size(); ++i) {
      EXPECT_EQ(scalars[i], Scalar(values[i])) << "element " << i;
    }
  };

  check_all(Attribute(v_bool), v_bool);
  check_all(Attribute(v_int), v_int);
  check_all(Attribute(v_int64), v_int64);
  check_all(Attribute(v_float), v_float);
  check_all(Attribute(v_double), v_double);
  check_all(Attribute(v_scalar), v_scalar);
}