From 12a3cea0879b625829d14729ec4bc180f013385c Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 1 Mar 2018 14:17:54 +0800 Subject: [PATCH] Add tuple type (#8519) * add the type of tuple * add lod_tensor to tuple --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/framework.proto | 4 ++ paddle/fluid/framework/tuple.h | 71 ++++++++++++++++++++++++++ paddle/fluid/framework/tuple_test.cc | 65 +++++++++++++++++++++++ 4 files changed, 141 insertions(+) create mode 100644 paddle/fluid/framework/tuple.h create mode 100644 paddle/fluid/framework/tuple_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e076b5003ba..82c7d4a2ec6 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -96,5 +96,6 @@ cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_contex cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(channel_test SRCS channel_test.cc) +cc_test(tuple_test SRCS tuple_test.cc ) cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc) diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 38f22b89143..96f53dc1bc8 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -117,6 +117,7 @@ message VarType { // raw variables should manage their own allocations // in operators like nccl_op RAW = 17; + TUPLE = 18; } required Type type = 1; @@ -148,6 +149,9 @@ message VarType { required int64 capacity = 2; } optional ChannelDesc channel = 6; + + message Tuple { repeated Type element_type = 1; } + optional Tuple tuple = 7; } message VarDesc { diff --git a/paddle/fluid/framework/tuple.h b/paddle/fluid/framework/tuple.h new file mode 100644 index 00000000000..78996908b18 --- /dev/null +++ b/paddle/fluid/framework/tuple.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/variant.h" + +namespace paddle { +namespace framework { + +typedef boost::variant + ElementVar; + +class Tuple { + public: + using ElementVars = std::vector; + + Tuple(std::vector& var, std::vector& var_desc) + : var_(var), var_desc_(var_desc) {} + Tuple(std::vector& var) : var_(var) {} + + ElementVar get(int idx) const { return var_[idx]; }; + + ElementVar& get(int idx) { return var_[idx]; }; + + bool isSameType(Tuple& t) const; + + size_t getSize() const { return var_.size(); }; + + private: + ElementVars var_; + std::vector var_desc_; +}; + +bool Tuple::isSameType(Tuple& t) const { + size_t tuple_size = getSize(); + if (tuple_size != t.getSize()) { + return false; + } + for (size_t j = 0; j < tuple_size; ++j) { + auto type1 = get(j).which(); + auto type2 = t.get(j).which(); + if (type1 != type2) return false; + } + return true; +} + +Tuple* make_tuple(std::vector tuple) { return new Tuple(tuple); } + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/tuple_test.cc b/paddle/fluid/framework/tuple_test.cc new file mode 100644 index 00000000000..810900f161c --- /dev/null +++ b/paddle/fluid/framework/tuple_test.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/tuple.h" + +TEST(Tuple, Make) { + std::vector element_type; + element_type.push_back(12); + element_type.push_back(12.0f); + element_type.push_back("ElementVar"); + + paddle::framework::Tuple* tuple = paddle::framework::make_tuple(element_type); + + EXPECT_EQ(boost::get(tuple->get(0)), 12); + EXPECT_EQ(boost::get(tuple->get(1)), 12.0f); + EXPECT_EQ(boost::get(tuple->get(2)), "ElementVar"); + + delete tuple; +} + +TEST(Tuple, IsTheSameType) { + std::vector element_type1; + std::vector element_type2; + std::vector element_type3; + + element_type1.push_back(12); + element_type1.push_back(12.0f); + element_type1.push_back("Tuple1"); + + element_type2.push_back(13); + element_type2.push_back(13.0f); + element_type2.push_back("Tuple2"); + + element_type3.push_back(14.0f); + element_type3.push_back(14); + element_type3.push_back("Tuple3"); + + paddle::framework::Tuple* tuple1 = + paddle::framework::make_tuple(element_type1); + paddle::framework::Tuple* tuple2 = + paddle::framework::make_tuple(element_type2); + paddle::framework::Tuple* tuple3 = + paddle::framework::make_tuple(element_type3); + + EXPECT_TRUE(tuple1->isSameType(*tuple2)); + EXPECT_FALSE(tuple1->isSameType(*tuple3)); + + delete tuple1; + delete tuple2; + delete tuple3; +} -- GitLab