// Copyright (c) 2022 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/utils/functional.h" #include #include #include #include #include #include #include #include #include "paddle/cinn/utils/string.h" namespace cinn { namespace utils { TEST(Functional, IsVector) { static_assert(!IsVector::value, "int is not a vector"); static_assert(!IsVector::value, "string is not a vector"); static_assert(!IsVector::value, "const string* is not a vector"); static_assert(!IsVector>::value, "list is not a vector"); static_assert(!IsVector &>::value, "const list& is not a vector"); static_assert(!IsVector>::value, "set is not a vector"); static_assert(!IsVector *>::value, "set* is not a vector"); static_assert(IsVector>::value, "vector is a vector"); static_assert(IsVector &>::value, "vector& is a vector"); static_assert(IsVector *>::value, "vector* is a vector"); static_assert(IsVector>::value, "const vector is a vector"); static_assert(IsVector &>::value, "const vector& is a vector"); static_assert(IsVector *>::value, "const vector* is a vector"); static_assert(IsVector>::value, "volatile vector is a vector"); static_assert(IsVector &>::value, "volatile vector& is a vector"); static_assert(IsVector *>::value, "volatile vector* is a vector"); static_assert(IsVector>::value, "const volatile vector is a vector"); static_assert(IsVector &>::value, "const volatile vector& is a vector"); static_assert(IsVector *>::value, "const volatile vector* is a vector"); } TEST(Functional, Flatten) { double d = 3.14; auto flatten_d = Flatten(d); LOG(INFO) << utils::Join(flatten_d, ", "); ASSERT_EQ(flatten_d.size(), 1); ASSERT_TRUE(absl::c_equal(flatten_d, std::vector{3.14})); std::string s = "constant"; auto flatten_s = Flatten(s); LOG(INFO) << utils::Join(flatten_s, ", "); ASSERT_EQ(flatten_s.size(), 1); ASSERT_TRUE(absl::c_equal(flatten_s, std::vector{"constant"})); const std::string &sr = s; auto flatten_sr = Flatten(sr); LOG(INFO) << utils::Join(flatten_sr, ", "); ASSERT_EQ(flatten_sr.size(), 1); ASSERT_TRUE(absl::c_equal(flatten_sr, std::vector{"constant"})); std::vector> i{{3, 4, 5}, {7, 8, 9, 10}}; auto flatten_i = Flatten(i); LOG(INFO) << utils::Join(flatten_i, ", "); ASSERT_EQ(flatten_i.size(), 7); ASSERT_TRUE(absl::c_equal(flatten_i, std::vector{3, 4, 5, 7, 8, 9, 10})); std::vector>> v{{{true, false}, {true, false, true, false}}, {{false}, {true, true, false}}}; std::vector flatten_v = Flatten(v); LOG(INFO) << utils::Join(flatten_v, ", "); ASSERT_EQ(flatten_v.size(), 10); ASSERT_TRUE( absl::c_equal(flatten_v, std::vector{true, false, true, false, true, false, false, true, true, false})); std::vector>> str{{{"true", "false"}, {"true", "false", "true", "false"}}, {{"false"}, {"true", "true", "false"}}}; auto flatten_str = Flatten(str); LOG(INFO) << utils::Join(flatten_str, ", "); ASSERT_EQ(flatten_str.size(), 10); ASSERT_TRUE(absl::c_equal( flatten_str, std::vector{"true", "false", "true", "false", "true", "false", "false", "true", "true", "false"})); std::list>> a{{{1, 2, 3}, {1, 2, 3, 4, 5, 6}}, {{1, 2.2f, 3}, {1, 2, 3.3f, 4.5f}}}; auto flatten_a = Flatten(a); LOG(INFO) << utils::Join(flatten_a, ", "); ASSERT_EQ(flatten_a.size(), 16); ASSERT_TRUE(absl::c_equal(flatten_a, std::vector{1, 2, 3, 1, 2, 3, 4, 5, 6, 1, 2, 3.3, 4.5, 1, 2.2, 3})); std::list>> b; auto flatten_b = Flatten(b); LOG(INFO) << utils::Join(flatten_b, ", "); ASSERT_EQ(flatten_b.size(), 0); ASSERT_TRUE(absl::c_equal(flatten_b, std::vector{})); std::list>> empty_str; auto flatten_empty_str = Flatten(empty_str); LOG(INFO) << utils::Join(flatten_empty_str, ", "); ASSERT_EQ(flatten_empty_str.size(), 0); ASSERT_TRUE(absl::c_equal(flatten_empty_str, std::vector{})); } } // namespace utils } // namespace cinn