From f97f69feec521e438a94cdfd1ac2355f7b72dacd Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 26 Dec 2017 19:10:11 +0800 Subject: [PATCH] Add data transform fn (#6953) * init data_transform * complete DataTransform * fix build error * add data_transform_test * add a register test for data_transform_fn * use function to simulate registration macro * add register macro * update test * clean code * restore unrelated code * update data transform test * generate unique name for REGISTER_DATA_TRANSFORM_FN * add const * follow comment * update KernelTypePair hash function --- paddle/framework/CMakeLists.txt | 5 +- paddle/framework/data_transform.cc | 26 ++++++ paddle/framework/data_transform.h | 110 ++++++++++++++++++++++++ paddle/framework/data_transform_test.cc | 78 +++++++++++++++++ 4 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 paddle/framework/data_transform.cc create mode 100644 paddle/framework/data_transform.h create mode 100644 paddle/framework/data_transform_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index c2a57a95ee6..968fefcfafb 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -64,4 +64,7 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) -cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context) +cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) + +cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) +cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc new file mode 100644 index 00000000000..35f16025a9a --- /dev/null +++ b/paddle/framework/data_transform.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_transform.h" + +namespace paddle { +namespace framework { + +DataTransformFnMap& DataTransformFnMap::Instance() { + static DataTransformFnMap data_transform_map; + return data_transform_map; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h new file mode 100644 index 00000000000..c83c08ba5ce --- /dev/null +++ b/paddle/framework/data_transform.h @@ -0,0 +1,110 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/tensor.h" +#include "paddle/framework/variable.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/macros.h" + +namespace paddle { +namespace framework { + +using DataTransformFN = + std::function ctx, + const Variable& in, Variable* out)>; +using KernelTypePair = std::pair; + +static void hash_combine(std::size_t& seed, const OpKernelType& t) { + OpKernelType::Hash kernel_type_hasher; + seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +struct KernelTypePairHash { + size_t operator()(const KernelTypePair& kernel_pair) const { + std::size_t seed = 0; + hash_combine(seed, kernel_pair.first); + hash_combine(seed, kernel_pair.second); + + return seed; + } +}; + +using DataTransformMap = + std::unordered_map; + +class DataTransformFnMap { + public: + static DataTransformFnMap& Instance(); + + bool Has(const KernelTypePair& key_pair) const { + return map_.find(key_pair) != map_.end(); + } + + void Insert(const OpKernelType& left, const OpKernelType& right, + const DataTransformFN& data_tranform_fn) { + Insert(std::make_pair(left, right), data_tranform_fn); + } + + void Insert(const KernelTypePair& kernel_type_pair, + const DataTransformFN& data_tranform_fn) { + PADDLE_ENFORCE(!Has(kernel_type_pair), + "KernelTypePair %s has been registered", ""); + map_.insert({kernel_type_pair, data_tranform_fn}); + } + + const DataTransformFN& Get(const KernelTypePair& key_pair) const { + auto data_transformer = GetNullable(key_pair); + PADDLE_ENFORCE_NOT_NULL(data_transformer, + "DataTransformFN should not be NULL"); + return *data_transformer; + } + + const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const { + auto it = map_.find(key_pair); + if (it == map_.end()) { + return nullptr; + } else { + return &(it->second); + } + } + + const DataTransformMap& Map() const { return map_; } + + private: + DataTransformFnMap() = default; + DataTransformMap map_; + DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); +}; + +// generate unique name with __LINE__ +// refs https://stackoverflow.com/questions/1597007 +#define TOKENPASTE(x, y) x##y +#define TOKENPASTE2(x, y) TOKENPASTE(x, y) +#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \ + static int TOKENPASTE2(fn_, __LINE__)() { \ + ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \ + return 0; \ + } \ + static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \ + TOKENPASTE2(fn_, __LINE__)() + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc new file mode 100644 index 00000000000..f93a47eeb56 --- /dev/null +++ b/paddle/framework/data_transform_test.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_transform.h" +#include + +namespace paddle { +namespace framework { + +using namespace platform; + +int test_value = 0; + +OpKernelType kernel_type_1(proto::DataType::FP32, CPUPlace(), DataLayout::kNCHW, + LibraryType::kCUDNN); +OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0), + DataLayout::kNCHW, LibraryType::kCUDNN); +OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0), + DataLayout::kNCHW, LibraryType::kCUDNN); + +void type1_to_type2(std::vector ctx, + const Variable& in, Variable* out) { + test_value++; +} + +void type2_to_type3(std::vector ctx, + const Variable& in, Variable* out) { + test_value--; +} + +void type1_to_type3(std::vector ctx, + const Variable& in, Variable* out) { + test_value += 2; +} + +} // namespace framework +} // namespace paddle + +namespace frw = paddle::framework; + +REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2, + frw::type1_to_type2); +REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_2, frw::kernel_type_3, + frw::type2_to_type3); +REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_3, + frw::type1_to_type3); + +TEST(DataTransform, Register) { + using namespace paddle::framework; + using namespace paddle::platform; + + auto& instance = DataTransformFnMap::Instance(); + ASSERT_EQ(instance.Map().size(), 3UL); + std::vector ctx; + paddle::framework::Variable in; + paddle::framework::Variable out; + + instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in, + &out); + ASSERT_EQ(test_value, 1); + instance.Get(std::make_pair(frw::kernel_type_2, frw::kernel_type_3))(ctx, in, + &out); + ASSERT_EQ(test_value, 0); + instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_3))(ctx, in, + &out); + ASSERT_EQ(test_value, 2); +} -- GitLab