From a893f156527942d9172d51ab6662748b7000d5bc Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 4 Jan 2018 13:30:51 +0800 Subject: [PATCH] fix layout transform (#7149) * "fix typo" * "fix based on comments" * "follow gogle style" * "fix based on comemnts" --- paddle/framework/data_transform.cc | 36 +++++++++++++++++++------ paddle/framework/data_transform.h | 8 +++--- paddle/framework/data_transform_test.cc | 16 +++++++++-- 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index 9d6a84244..ac6e40a3a 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/framework/data_transform.h" #include "paddle/framework/lod_tensor.h" @@ -74,26 +75,28 @@ void TransDataType(const platform::DeviceContext* ctx, } } -void TransDataLayout(const platform::DeviceContext* ctx, +void TransDataLayout(const std::vector& axis, + const platform::DeviceContext* ctx, const KernelTypePair& kernel_pair, const Variable& in, Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only Support Tensor transform!."); + PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); PADDLE_ENFORCE( platform::places_are_same_class(kernel_pair.first.place_, kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); + "TransDataLayout only support DataLayout transform on same place!"); + PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, + "TransDataLayout only support Datatype are same!"); auto src = in.Get(); auto* dst = out->GetMutable(); PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); - auto src_dim = src.dims(); - dst->Resize(src_dim); auto place = kernel_pair.second.place_; CopyFrom(src, place, *ctx, dst); - const std::vector axis = {0, 2, 3, 1}; + auto src_dim = src.dims(); std::vector dst_dim; + dst_dim.resize(axis.size()); for (size_t i = 0; i < axis.size(); i++) { dst_dim[i] = src_dim[axis[i]]; @@ -102,7 +105,7 @@ void TransDataLayout(const platform::DeviceContext* ctx, dst->Resize(make_ddim(dst_dim)); auto src_type = kernel_pair.first.data_type_; - framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis)); + framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); dst->set_layout(kernel_pair.second.data_layout_); } @@ -111,5 +114,22 @@ void TransDataLayout(const platform::DeviceContext* ctx, } // namespace paddle namespace f = paddle::framework; + +namespace { +std::vector NHWC2NCHW = {0, 3, 1, 2}; +std::vector NCHW2NHWC = {0, 2, 3, 1}; +} + REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType); -REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, f::TransDataLayout); +REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, + std::bind(f::TransDataLayout, NHWC2NCHW, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3, + std::placeholders::_4)); +REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC, + std::bind(f::TransDataLayout, NCHW2NHWC, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3, + std::placeholders::_4)); diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 9abb3c99b..56ebc80f4 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -73,6 +73,7 @@ struct CastDataType { auto numel = in_.numel(); auto* in_end = in_begin + numel; auto* out_begin = out_->mutable_data(place); + if (platform::is_cpu_place(place)) { platform::Transform trans; auto* context = static_cast(ctx_); @@ -86,9 +87,9 @@ struct CastDataType { }; struct CastDataLayout { - CastDataLayout(const framework::Tensor& in, framework::Tensor* out, - const platform::DeviceContext* ctx, - const std::vector& axis) + CastDataLayout(const platform::DeviceContext* ctx, + const std::vector& axis, const framework::Tensor& in, + framework::Tensor* out) : in_(in), out_(out), ctx_(ctx), axis_(axis) {} const framework::Tensor in_; framework::Tensor* out_; @@ -98,6 +99,7 @@ struct CastDataLayout { template void operator()() { auto place = ctx_->GetPlace(); + if (platform::is_cpu_place(place)) { operators::math::Transpose trans4; auto* context = static_cast(ctx_); diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 8665b6248..edd305fd1 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -106,7 +106,7 @@ TEST(DataTransform, Register) { ASSERT_EQ(test_value, 2); } -TEST(DataTransform, Layout) { +TEST(DataTransform, DataLayout) { using namespace paddle::framework; using namespace paddle::platform; @@ -127,7 +127,19 @@ TEST(DataTransform, Layout) { } Tensor dst = out.Get(); - EXPECT_TRUE(dst.layout() != src->layout()); + + EXPECT_TRUE(dst.layout() == DataLayout::kNCHW); + EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1})); + + { + auto kernel1 = GenFromBit({1, 0, 1, 0}); + auto kernel2 = GenFromBit({1, 0, 0, 0}); + auto pair0 = std::make_pair(kernel1, kernel2); + instance.Get(pair0)(ctx, pair0, out, &in); + } + + EXPECT_TRUE(src->layout() == DataLayout::kNHWC); + EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2})); } TEST(DataTransform, DataType) { -- GitLab