提交 12a3cea0 编写于 作者: C chengduo 提交者: Abhinav Arora

Add tuple type (#8519)

* add the type of tuple

* add lod_tensor to tuple
上级 d3fbede9
...@@ -96,5 +96,6 @@ cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_contex ...@@ -96,5 +96,6 @@ cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_contex
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(channel_test SRCS channel_test.cc) cc_test(channel_test SRCS channel_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc) channel_send_op channel_recv_op sum_op elementwise_add_op executor proto_desc)
...@@ -117,6 +117,7 @@ message VarType { ...@@ -117,6 +117,7 @@ message VarType {
// raw variables should manage their own allocations // raw variables should manage their own allocations
// in operators like nccl_op // in operators like nccl_op
RAW = 17; RAW = 17;
TUPLE = 18;
} }
required Type type = 1; required Type type = 1;
...@@ -148,6 +149,9 @@ message VarType { ...@@ -148,6 +149,9 @@ message VarType {
required int64 capacity = 2; required int64 capacity = 2;
} }
optional ChannelDesc channel = 6; optional ChannelDesc channel = 6;
message Tuple { repeated Type element_type = 1; }
optional Tuple tuple = 7;
} }
message VarDesc { message VarDesc {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
typedef boost::variant<int, int64_t, float, double, std::string, Tensor,
LoDTensor /*, ChannelHolder*/>
ElementVar;
class Tuple {
public:
using ElementVars = std::vector<ElementVar>;
Tuple(std::vector<ElementVar>& var, std::vector<VarDesc>& var_desc)
: var_(var), var_desc_(var_desc) {}
Tuple(std::vector<ElementVar>& var) : var_(var) {}
ElementVar get(int idx) const { return var_[idx]; };
ElementVar& get(int idx) { return var_[idx]; };
bool isSameType(Tuple& t) const;
size_t getSize() const { return var_.size(); };
private:
ElementVars var_;
std::vector<VarDesc> var_desc_;
};
bool Tuple::isSameType(Tuple& t) const {
size_t tuple_size = getSize();
if (tuple_size != t.getSize()) {
return false;
}
for (size_t j = 0; j < tuple_size; ++j) {
auto type1 = get(j).which();
auto type2 = t.get(j).which();
if (type1 != type2) return false;
}
return true;
}
Tuple* make_tuple(std::vector<ElementVar> tuple) { return new Tuple(tuple); }
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tuple.h"
TEST(Tuple, Make) {
std::vector<paddle::framework::ElementVar> element_type;
element_type.push_back(12);
element_type.push_back(12.0f);
element_type.push_back("ElementVar");
paddle::framework::Tuple* tuple = paddle::framework::make_tuple(element_type);
EXPECT_EQ(boost::get<int>(tuple->get(0)), 12);
EXPECT_EQ(boost::get<float>(tuple->get(1)), 12.0f);
EXPECT_EQ(boost::get<std::string>(tuple->get(2)), "ElementVar");
delete tuple;
}
TEST(Tuple, IsTheSameType) {
std::vector<paddle::framework::ElementVar> element_type1;
std::vector<paddle::framework::ElementVar> element_type2;
std::vector<paddle::framework::ElementVar> element_type3;
element_type1.push_back(12);
element_type1.push_back(12.0f);
element_type1.push_back("Tuple1");
element_type2.push_back(13);
element_type2.push_back(13.0f);
element_type2.push_back("Tuple2");
element_type3.push_back(14.0f);
element_type3.push_back(14);
element_type3.push_back("Tuple3");
paddle::framework::Tuple* tuple1 =
paddle::framework::make_tuple(element_type1);
paddle::framework::Tuple* tuple2 =
paddle::framework::make_tuple(element_type2);
paddle::framework::Tuple* tuple3 =
paddle::framework::make_tuple(element_type3);
EXPECT_TRUE(tuple1->isSameType(*tuple2));
EXPECT_FALSE(tuple1->isSameType(*tuple3));
delete tuple1;
delete tuple2;
delete tuple3;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册