From 899a79cceb5b949d41d25a93c6c4d79446ba41b9 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 2 Jan 2018 15:51:53 +0800 Subject: [PATCH] Feature/transform (#7111) * "fix data transform" * "data transformer" * "add device pool" * "add test" * "fix ci" * "fix datalayout implementation " * "fix based on comment" --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/data_transform.cc | 79 +++++++++++++++++++++++ paddle/framework/data_transform.h | 67 +++++++++++++++++++- paddle/framework/data_transform_test.cc | 83 +++++++++++++++++++++---- paddle/framework/operator.cc | 2 +- paddle/operators/math/math_function.cc | 9 ++- 6 files changed, 222 insertions(+), 20 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6788cb34fba..b4458eb9551 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -29,7 +29,7 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc DEPS glog) cc_test(scope_test SRCS scope_test.cc DEPS scope) -cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) +cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto) cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) cc_library(attribute SRCS attribute.cc DEPS framework_proto) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index 376268888e7..58780e38635 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/framework/data_transform.h" #include "paddle/framework/lod_tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { @@ -23,5 +24,83 @@ DataTransformFnMap& DataTransformFnMap::Instance() { return data_transform_map; } +auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(), + DataLayout::kNHWC, LibraryType::kPlain); + +auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), + DataLayout::kNHWC, LibraryType::kPlain); + +auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), + DataLayout::kNHWC, LibraryType::kPlain); + +auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(), + DataLayout::kNCHW, LibraryType::kPlain); + +void TransDataType(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataType Only Support DataType transform on same place!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + + auto dims = src.dims(); + dst->Resize(dims); + auto dst_type = kernel_pair.second.data_type_; + auto src_type = kernel_pair.first.data_type_; + + switch (src_type) { + case proto::DataType::FP32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::FP64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT32: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::INT64: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + case proto::DataType::BOOL: + framework::VisitDataType(dst_type, CastDataType(src, dst, ctx)); + break; + default: + PADDLE_THROW("Not support type %d", src_type); + } +} + +void TransDataLayout(const platform::DeviceContext* ctx, + const KernelTypePair& kernel_pair, const Variable& in, + Variable* out) { + PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); + PADDLE_ENFORCE( + platform::places_are_same_class(kernel_pair.first.place_, + kernel_pair.second.place_), + "TransDataType Only Support DataType transform on same place!"); + + auto src = in.Get(); + auto* dst = out->GetMutable(); + PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); + + dst->Resize(src.dims()); + auto place = kernel_pair.second.place_; + CopyFrom(src, place, *ctx, dst); + const std::vector axis = {0, 2, 3, 1}; + + auto src_type = kernel_pair.first.data_type_; + framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis)); + + dst->set_layout(kernel_pair.second.data_layout_); +} + } // namespace framework } // namespace paddle + +namespace f = paddle::framework; +REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType); +REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, f::TransDataLayout); diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index bd6d301c12e..9abb3c99bf3 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -21,16 +21,20 @@ limitations under the License. */ #include "paddle/framework/op_kernel_type.h" #include "paddle/framework/tensor.h" #include "paddle/framework/variable.h" +#include "paddle/operators/math/math_function.h" #include "paddle/platform/device_context.h" #include "paddle/platform/macros.h" +#include "paddle/platform/transform.h" namespace paddle { namespace framework { -using DataTransformFn = std::function; using KernelTypePair = std::pair; +using DataTransformFn = + std::function; + struct KernelTypePairHash { static void HashCombine(const OpKernelType& t, std::size_t* seed) { OpKernelType::Hash kernel_type_hasher; @@ -45,6 +49,65 @@ struct KernelTypePairHash { } }; +template +struct CastDataTypeFunctor { + HOSTDEVICE inline OutType operator()(InType in) const { + return static_cast(in); + } +}; + +template +struct CastDataType { + CastDataType(const framework::Tensor& in, framework::Tensor* out, + const platform::DeviceContext* ctx) + : in_(in), out_(out), ctx_(ctx) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + + auto* in_begin = in_.data(); + auto numel = in_.numel(); + auto* in_end = in_begin + numel; + auto* out_begin = out_->mutable_data(place); + if (platform::is_cpu_place(place)) { + platform::Transform trans; + auto* context = static_cast(ctx_); + trans(*context, in_begin, in_end, out_begin, + CastDataTypeFunctor()); + } else { + // TODO(dzhwinter): enhance CopyFrom CPU<->GPU with different data type? + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + +struct CastDataLayout { + CastDataLayout(const framework::Tensor& in, framework::Tensor* out, + const platform::DeviceContext* ctx, + const std::vector& axis) + : in_(in), out_(out), ctx_(ctx), axis_(axis) {} + const framework::Tensor in_; + framework::Tensor* out_; + const platform::DeviceContext* ctx_; + const std::vector axis_; + + template + void operator()() { + auto place = ctx_->GetPlace(); + if (platform::is_cpu_place(place)) { + operators::math::Transpose trans4; + auto* context = static_cast(ctx_); + trans4(*context, in_, out_, axis_); + } else { + PADDLE_THROW("Unsupport CPU <-> GPU!"); + } + } +}; + using DataTransformMap = std::unordered_map; diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 5f05e881fa1..5b01c8434b1 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/framework/data_transform.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { @@ -31,16 +32,18 @@ using namespace platform; * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN */ -std::array kDataType = { - {proto::DataType::FP32, proto::DataType::FP64}}; +std::array kDataType = {proto::DataType::FP32, + proto::DataType::FP64}; -std::array kPlace = {{CPUPlace(), CUDAPlace(0)}}; +std::array kPlace = {CPUPlace(), CUDAPlace(0)}; std::array kDataLayout = { - {DataLayout::kNHWC, DataLayout::kNCHW}}; + DataLayout::kNHWC, DataLayout::kNCHW, +}; std::array kLibraryType = { - {LibraryType::kPlain, LibraryType::kMKLDNN}}; + LibraryType::kPlain, LibraryType::kMKLDNN, +}; OpKernelType GenFromBit(const std::vector bits) { return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]], @@ -54,17 +57,20 @@ auto kernel1 = GenFromBit({0, 0, 0, 1}); auto kernel2 = GenFromBit({0, 0, 1, 0}); auto kernel3 = GenFromBit({0, 0, 1, 1}); -void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in, +void TransDataType_t(const platform::DeviceContext* ctx, + const KernelTypePair& p, const Variable& in, Variable* out) { test_value++; } -void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in, +void TransDataLayout_t(const platform::DeviceContext* ctx, + const KernelTypePair& p, const Variable& in, Variable* out) { test_value--; } -void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in, +void TransLibraryType_t(const platform::DeviceContext* ctx, + const KernelTypePair& p, const Variable& in, Variable* out) { test_value += 2; } @@ -83,17 +89,68 @@ TEST(DataTransform, Register) { using namespace paddle::platform; auto& instance = DataTransformFnMap::Instance(); - ASSERT_EQ(instance.Map().size(), 3UL); - DeviceContext* ctx = nullptr; paddle::framework::Variable in; paddle::framework::Variable out; - instance.Get(std::make_pair(frw::kernel0, frw::kernel1))(ctx, in, &out); + DeviceContext* ctx = new CPUDeviceContext(); + auto pair0 = std::make_pair(frw::kernel0, frw::kernel1); + instance.Get(pair0)(ctx, pair0, in, &out); ASSERT_EQ(test_value, 1); - instance.Get(std::make_pair(frw::kernel1, frw::kernel2))(ctx, in, &out); + auto pair1 = std::make_pair(frw::kernel1, frw::kernel2); + instance.Get(pair1)(ctx, pair1, in, &out); ASSERT_EQ(test_value, 0); - instance.Get(std::make_pair(frw::kernel0, frw::kernel2))(ctx, in, &out); + auto pair3 = std::make_pair(frw::kernel0, frw::kernel2); + instance.Get(pair3)(ctx, pair3, in, &out); ASSERT_EQ(test_value, 2); } + +TEST(DataTransform, Layout) { + using namespace paddle::framework; + using namespace paddle::platform; + + auto& instance = DataTransformFnMap::Instance(); + Variable in; + Variable out; + Tensor* src = in.GetMutable(); + src->mutable_data(make_ddim({2, 3, 1, 2}), CPUPlace()); + src->set_layout(DataLayout::kNHWC); + + DeviceContext* ctx = new CPUDeviceContext(); + + { + auto kernel1 = GenFromBit({1, 0, 0, 0}); + auto kernel2 = GenFromBit({1, 0, 1, 0}); + auto pair0 = std::make_pair(kernel1, kernel2); + instance.Get(pair0)(ctx, pair0, in, &out); + } + + Tensor dst = out.Get(); + EXPECT_TRUE(dst.layout() != src->layout()); +} + +TEST(DataTransform, DataType) { + using namespace paddle::framework; + using namespace paddle::platform; + + auto& instance = DataTransformFnMap::Instance(); + DeviceContext* ctx = new CPUDeviceContext(); + + Variable in; + Variable out; + Tensor* src = in.GetMutable(); + float* ptr = src->mutable_data(make_ddim({2, 3}), CPUPlace()); + for (int i = 0; i < 6; ++i) { + ptr[i] = i / 3; + } + + { + auto kernel1 = GenFromBit({0, 0, 0, 0}); + auto kernel2 = GenFromBit({1, 0, 0, 0}); + auto pair0 = std::make_pair(kernel1, kernel2); + instance.Get(pair0)(ctx, pair0, in, &out); + } + Tensor dst = out.Get(); + EXPECT_TRUE(dst.data() != nullptr); +} diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index a3ce96c4096..fc7091f1c89 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -461,7 +461,7 @@ void OperatorWithKernel::Run(const Scope& scope, dev_ctx->Wait(); for (auto var_name : need_trans) { - (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)), + (*trans_fun)(trans_dev_ctx, kernel_pair, *(scope.FindVar(var_name)), scope.FindVar(var_name + framework::KernelTypeToString( expected_kernel_key))); } diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index d4f12f0a106..dcf4b85e1aa 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -245,9 +245,12 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; -#define DEFINE_CPU_TRANS(RANK) \ - template struct Transpose; \ - template struct Transpose; +#define DEFINE_CPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(2); -- GitLab