From 377424bf21bd6f9911aa55f3ba8161363382a115 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 10 Jan 2018 19:12:06 +0800 Subject: [PATCH] reorganize data transform related code (#7391) * init data_type_transform * split data_layout_transform * tmp rm data_transform_test * change device_data_transform to data_device_transform * clean code * clean code --- paddle/framework/CMakeLists.txt | 10 +- ..._transform.cc => data_device_transform.cc} | 2 +- ...ta_transform.h => data_device_transform.h} | 0 ..._test.cu => data_device_transform_test.cu} | 0 paddle/framework/data_layout.h | 1 - paddle/framework/data_layout_transform.cc | 82 +++++++++ paddle/framework/data_layout_transform.h | 31 ++++ paddle/framework/data_transform.cc | 141 +-------------- paddle/framework/data_transform.h | 140 --------------- paddle/framework/data_transform_test.cc | 168 ------------------ paddle/framework/data_type_transform.cc | 99 +++++++++++ paddle/framework/data_type_transform.h | 31 ++++ paddle/framework/operator.cc | 1 - 13 files changed, 252 insertions(+), 454 deletions(-) rename paddle/framework/{device_data_transform.cc => data_device_transform.cc} (96%) rename paddle/framework/{device_data_transform.h => data_device_transform.h} (100%) rename paddle/framework/{device_data_transform_test.cu => data_device_transform_test.cu} (100%) create mode 100644 paddle/framework/data_layout_transform.cc create mode 100644 paddle/framework/data_layout_transform.h delete mode 100644 paddle/framework/data_transform_test.cc create mode 100644 paddle/framework/data_type_transform.cc create mode 100644 paddle/framework/data_type_transform.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index af4079875..ed5f6310f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -32,10 +32,12 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool) cc_library(scope SRCS scope.cc DEPS glog threadpool) cc_test(scope_test SRCS scope_test.cc DEPS scope) -cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor) +cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor) +cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor) +cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function) -cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform) -cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) +cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor + framework_proto selected_rows data_device_transform data_type_transform data_layout_transform) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc @@ -80,5 +82,5 @@ cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) -nv_test(device_data_transform_test SRCS device_data_transform_test.cu +nv_test(data_device_transform_test SRCS data_device_transform_test.cu DEPS operator op_registry init math_function) diff --git a/paddle/framework/device_data_transform.cc b/paddle/framework/data_device_transform.cc similarity index 96% rename from paddle/framework/device_data_transform.cc rename to paddle/framework/data_device_transform.cc index cd5104cc6..b3fd48ae1 100644 --- a/paddle/framework/device_data_transform.cc +++ b/paddle/framework/data_device_transform.cc @@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/device_data_transform.h" +#include "paddle/framework/data_device_transform.h" namespace paddle { namespace framework { diff --git a/paddle/framework/device_data_transform.h b/paddle/framework/data_device_transform.h similarity index 100% rename from paddle/framework/device_data_transform.h rename to paddle/framework/data_device_transform.h diff --git a/paddle/framework/device_data_transform_test.cu b/paddle/framework/data_device_transform_test.cu similarity index 100% rename from paddle/framework/device_data_transform_test.cu rename to paddle/framework/data_device_transform_test.cu diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h index 4a8669c3a..3ab976eca 100644 --- a/paddle/framework/data_layout.h +++ b/paddle/framework/data_layout.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/platform/enforce.h" #include #include "paddle/platform/enforce.h" diff --git a/paddle/framework/data_layout_transform.cc b/paddle/framework/data_layout_transform.cc new file mode 100644 index 000000000..96794cae9 --- /dev/null +++ b/paddle/framework/data_layout_transform.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_layout_transform.h" + +#include "paddle/framework/tensor.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace framework { + +struct CastDataLayout { + CastDataLayout(const platform::DeviceContext* ctx, + const std::vector& axis, const framework::Tensor& in, + framework::Tensor* out) + : in_(in), out_(out), ctx_(ctx), axis_(axis) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + const std::vector axis_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + if (platform::is_cpu_place(place)) { + operators::math::Transpose trans4; + auto* context = static_cast(ctx_); + trans4(*context, in_, out_, axis_); + } else { + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataLayout only support DataLayout transform on same place!"); + PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, + "TransDataLayout only support Datatype are same!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); + + auto src_dim = src.dims(); + std::vector dst_dim; + + dst_dim.resize(axis.size()); + for (size_t i = 0; i < axis.size(); i++) { + dst_dim[i] = src_dim[axis[i]]; + } + + dst->Resize(make_ddim(dst_dim)); + auto place = kernel_pair.second.place_; + dst->mutable_data(place, src.type()); + + auto src_type = kernel_pair.first.data_type_; + framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); + + dst->set_layout(kernel_pair.second.data_layout_); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_layout_transform.h b/paddle/framework/data_layout_transform.h new file mode 100644 index 000000000..befae1f63 --- /dev/null +++ b/paddle/framework/data_layout_transform.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/variable.h" + +namespace paddle { +namespace framework { + +using KernelTypePair = std::pair; + +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index fed958db1..e56edb953 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -11,22 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include "paddle/framework/data_transform.h" -#include "paddle/framework/device_data_transform.h" -#include "paddle/framework/lod_tensor.h" -#include "paddle/framework/selected_rows.h" -#include "paddle/platform/device_context.h" + +#include "paddle/framework/data_device_transform.h" namespace paddle { namespace framework { -DataTransformFnMap& DataTransformFnMap::Instance() { - static DataTransformFnMap data_transform_map; - return data_transform_map; -} - Tensor* DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, const Tensor& input_tensor) { @@ -58,134 +50,5 @@ void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, } } -auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNHWC, LibraryType::kPlain); - -auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), - DataLayout::kNCHW, LibraryType::kPlain); - -// TODO(dzhwinter): Only for testing multiple op kernel. -// Dummy transform function for library_type -// should be removed. -auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0), - DataLayout::kAnyLayout, LibraryType::kPlain); - -auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0), - DataLayout::kAnyLayout, LibraryType::kCUDNN); - -void DummyTrans(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - auto src = in.Get(); - auto* dst = out->GetMutable(); - *dst = src; -} - -void TransDataType(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - - auto src = in.Get(); - auto* dst = out->GetMutable(); - - auto dims = src.dims(); - dst->Resize(dims); - auto dst_type = kernel_pair.second.data_type_; - auto src_type = kernel_pair.first.data_type_; - - switch (src_type) { - case proto::DataType::FP32: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::FP64: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::INT32: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::INT64: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - case proto::DataType::BOOL: - framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); - break; - default: - PADDLE_THROW("Not support type %d", src_type); - } -} - -void TransDataLayout(const std::vector& axis, - const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); - PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataLayout only support DataLayout transform on same place!"); - PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, - "TransDataLayout only support Datatype are same!"); - - auto src = in.Get(); - auto* dst = out->GetMutable(); - PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); - - auto src_dim = src.dims(); - std::vector dst_dim; - - dst_dim.resize(axis.size()); - for (size_t i = 0; i < axis.size(); i++) { - dst_dim[i] = src_dim[axis[i]]; - } - - dst->Resize(make_ddim(dst_dim)); - auto place = kernel_pair.second.place_; - dst->mutable_data(place, src.type()); - - auto src_type = kernel_pair.first.data_type_; - framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); - - dst->set_layout(kernel_pair.second.data_layout_); -} - } // namespace framework } // namespace paddle - -namespace f = paddle::framework; - -namespace { -std::vector NHWC2NCHW = {0, 3, 1, 2}; -std::vector NCHW2NHWC = {0, 2, 3, 1}; -} - -REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType); -REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans); -REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans); -REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, - std::bind(f::TransDataLayout, NHWC2NCHW, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3, - std::placeholders::_4)); -REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC, - std::bind(f::TransDataLayout, NCHW2NHWC, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3, - std::placeholders::_4)); diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index e4e5c30a9..ee95c7e85 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -30,26 +30,6 @@ limitations under the License. */ namespace paddle { namespace framework { -using KernelTypePair = std::pair; - -using DataTransformFn = - std::function; - -struct KernelTypePairHash { - static void HashCombine(const OpKernelType& t, std::size_t* seed) { - OpKernelType::Hash kernel_type_hasher; - (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); - } - - size_t operator()(const KernelTypePair& kernel_pair) const { - std::size_t seed = 0; - HashCombine(kernel_pair.first, &seed); - HashCombine(kernel_pair.second, &seed); - return seed; - } -}; - Tensor* DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, const Tensor& input_tensor); @@ -57,125 +37,5 @@ Tensor* DataTransform(const OpKernelType& expected_kernel_type, void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, Variable& out_var); -template -struct CastDataTypeFunctor { - HOSTDEVICE inline OutType operator()(InType in) const { - return static_cast(in); - } -}; - -template -struct CastDataType { - CastDataType(const framework::Tensor& in, framework::Tensor* out, - const platform::DeviceContext* ctx) - : in_(in), out_(out), ctx_(ctx) {} - const framework::Tensor in_; - framework::Tensor* out_; - const platform::DeviceContext* ctx_; - - template - void operator()() { - auto place = ctx_->GetPlace(); - - auto* in_begin = in_.data(); - auto numel = in_.numel(); - auto* in_end = in_begin + numel; - auto* out_begin = out_->mutable_data(place); - - if (platform::is_cpu_place(place)) { - platform::Transform trans; - auto* context = static_cast(ctx_); - trans(*context, in_begin, in_end, out_begin, - CastDataTypeFunctor()); - } else { - // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? - PADDLE_THROW("Unsupport CPU <-> GPU!"); - } - } -}; - -struct CastDataLayout { - CastDataLayout(const platform::DeviceContext* ctx, - const std::vector& axis, const framework::Tensor& in, - framework::Tensor* out) - : in_(in), out_(out), ctx_(ctx), axis_(axis) {} - const framework::Tensor in_; - framework::Tensor* out_; - const platform::DeviceContext* ctx_; - const std::vector axis_; - - template - void operator()() { - auto place = ctx_->GetPlace(); - - if (platform::is_cpu_place(place)) { - operators::math::Transpose trans4; - auto* context = static_cast(ctx_); - trans4(*context, in_, out_, axis_); - } else { - PADDLE_THROW("Unsupport CPU <-> GPU!"); - } - } -}; - -using DataTransformMap = - std::unordered_map; - -class DataTransformFnMap { - public: - static DataTransformFnMap& Instance(); - - bool Has(const KernelTypePair& key_pair) const { - return map_.find(key_pair) != map_.end(); - } - - void Insert(const OpKernelType& left, const OpKernelType& right, - const DataTransformFn& data_tranform_fn) { - Insert(std::make_pair(left, right), data_tranform_fn); - } - - void Insert(const KernelTypePair& kernel_type_pair, - const DataTransformFn& data_tranform_fn) { - PADDLE_ENFORCE(!Has(kernel_type_pair), - "KernelTypePair %s has been registered", ""); - map_.insert({kernel_type_pair, data_tranform_fn}); - } - - const DataTransformFn& Get(const KernelTypePair& key_pair) const { - auto data_transformer = GetNullable(key_pair); - PADDLE_ENFORCE_NOT_NULL(data_transformer, - "DataTransformFn should not be NULL"); - return *data_transformer; - } - - const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const { - auto it = map_.find(key_pair); - if (it == map_.end()) { - return nullptr; - } else { - return &(it->second); - } - } - - const DataTransformMap& Map() const { return map_; } - - private: - DataTransformFnMap() = default; - DataTransformMap map_; - DISABLE_COPY_AND_ASSIGN(DataTransformFnMap); -}; - -// generate unique name with __LINE__ -// refs https://stackoverflow.com/questions/1597007 -#define TOKENPASTE(x, y) x##y -#define TOKENPASTE2(x, y) TOKENPASTE(x, y) -#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \ - static int TOKENPASTE2(fn_, __LINE__)() { \ - ::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \ - return 0; \ - } \ - static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \ - TOKENPASTE2(fn_, __LINE__)() - } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc deleted file mode 100644 index edd305fd1..000000000 --- a/paddle/framework/data_transform_test.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include -#include - -#include - -#include "paddle/framework/data_transform.h" -#include "paddle/platform/device_context.h" - -namespace paddle { -namespace framework { -using namespace platform; - -/** - * @brief cross validation of different kernel type transform - * We use four bit map represent different combination. - * If the field has multiple possible value, only choose two of them. - * For DataType, only test the FP32(float), FP64(double). - * e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain - * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN - */ - -std::array kDataType = { - {proto::DataType::FP32, proto::DataType::FP64}}; - -std::array kPlace = {{CPUPlace(), CUDAPlace(0)}}; - -std::array kDataLayout = {{ - DataLayout::kNHWC, DataLayout::kNCHW, -}}; - -std::array kLibraryType = {{ - LibraryType::kPlain, LibraryType::kMKLDNN, -}}; - -OpKernelType GenFromBit(const std::vector bits) { - return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]], - kLibraryType[bits[3]]); -} - -int test_value = 0; - -auto kernel0 = GenFromBit({0, 0, 0, 0}); -auto kernel1 = GenFromBit({0, 0, 0, 1}); -auto kernel2 = GenFromBit({0, 0, 1, 0}); -auto kernel3 = GenFromBit({0, 0, 1, 1}); - -void TransDataType_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value++; -} - -void TransDataLayout_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value--; -} - -void TransLibraryType_t(const platform::DeviceContext* ctx, - const KernelTypePair& p, const Variable& in, - Variable* out) { - test_value += 2; -} - -} // namespace framework -} // namespace paddle - -namespace frw = paddle::framework; - -REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t); -REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t); -REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t); - -TEST(DataTransform, Register) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - paddle::framework::Variable in; - paddle::framework::Variable out; - - DeviceContext* ctx = new CPUDeviceContext(); - auto pair0 = std::make_pair(frw::kernel0, frw::kernel1); - instance.Get(pair0)(ctx, pair0, in, &out); - ASSERT_EQ(test_value, 1); - - auto pair1 = std::make_pair(frw::kernel1, frw::kernel2); - instance.Get(pair1)(ctx, pair1, in, &out); - ASSERT_EQ(test_value, 0); - - auto pair3 = std::make_pair(frw::kernel0, frw::kernel2); - instance.Get(pair3)(ctx, pair3, in, &out); - ASSERT_EQ(test_value, 2); -} - -TEST(DataTransform, DataLayout) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - Variable in; - Variable out; - Tensor* src = in.GetMutable(); - src->mutable_data(make_ddim({2, 3, 1, 2}), CPUPlace()); - src->set_layout(DataLayout::kNHWC); - - DeviceContext* ctx = new CPUDeviceContext(); - - { - auto kernel1 = GenFromBit({1, 0, 0, 0}); - auto kernel2 = GenFromBit({1, 0, 1, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, in, &out); - } - - Tensor dst = out.Get(); - - EXPECT_TRUE(dst.layout() == DataLayout::kNCHW); - EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1})); - - { - auto kernel1 = GenFromBit({1, 0, 1, 0}); - auto kernel2 = GenFromBit({1, 0, 0, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, out, &in); - } - - EXPECT_TRUE(src->layout() == DataLayout::kNHWC); - EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2})); -} - -TEST(DataTransform, DataType) { - using namespace paddle::framework; - using namespace paddle::platform; - - auto& instance = DataTransformFnMap::Instance(); - DeviceContext* ctx = new CPUDeviceContext(); - - Variable in; - Variable out; - Tensor* src = in.GetMutable(); - float* ptr = src->mutable_data(make_ddim({2, 3}), CPUPlace()); - for (int i = 0; i < 6; ++i) { - ptr[i] = i / 3; - } - - { - auto kernel1 = GenFromBit({0, 0, 0, 0}); - auto kernel2 = GenFromBit({1, 0, 0, 0}); - auto pair0 = std::make_pair(kernel1, kernel2); - instance.Get(pair0)(ctx, pair0, in, &out); - } - Tensor dst = out.Get(); - EXPECT_TRUE(dst.data() != nullptr); -} diff --git a/paddle/framework/data_type_transform.cc b/paddle/framework/data_type_transform.cc new file mode 100644 index 000000000..63373232e --- /dev/null +++ b/paddle/framework/data_type_transform.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_type_transform.h" + +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace framework { + +template +struct CastDataTypeFunctor { + HOSTDEVICE inline OutType operator()(InType in) const { + return static_cast(in); + } +}; + +template +struct CastDataType { + CastDataType(const framework::Tensor& in, framework::Tensor* out, + const platform::DeviceContext* ctx) + : in_(in), out_(out), ctx_(ctx) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + auto* in_begin = in_.data(); + auto numel = in_.numel(); + auto* in_end = in_begin + numel; + auto* out_begin = out_->mutable_data(place); + + if (platform::is_cpu_place(place)) { + platform::Transform trans; + auto* context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); + } else { + // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type? + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataType Only Support DataType transform on same place!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + + auto dims = src.dims(); + dst->Resize(dims); + auto dst_type = kernel_pair.second.data_type_; + auto src_type = kernel_pair.first.data_type_; + + switch (src_type) { + case proto::DataType::FP32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::FP64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::BOOL: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + default: + PADDLE_THROW("Not support type %d", src_type); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/data_type_transform.h b/paddle/framework/data_type_transform.h new file mode 100644 index 000000000..8ec907422 --- /dev/null +++ b/paddle/framework/data_type_transform.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/variable.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace framework { + +using KernelTypePair = std::pair; + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 35ebe48ba..ef2c55cc3 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/framework/data_transform.h" -#include "paddle/framework/device_data_transform.h" #include "paddle/framework/executor.h" #include "paddle/framework/operator.h" #include "paddle/framework/shape_inference.h" -- GitLab