diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 597ea959f230d88350796cef05b7d6f2a42e594a..8e5a956061f459a4e08acea0e83f7719a44fb014 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -33,8 +33,13 @@ cc_library(scope SRCS scope.cc DEPS glog threadpool)
 cc_test(scope_test SRCS scope_test.cc DEPS scope)
 
 cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
+nv_test(data_device_transform_test SRCS data_device_transform_test.cu
+        DEPS operator op_registry init math_function)
+
 cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
+
 cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
+cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
 
 cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
         framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
@@ -82,5 +87,3 @@ cc_test(init_test SRCS init_test.cc DEPS init)
 cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
 cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
 
-nv_test(data_device_transform_test SRCS data_device_transform_test.cu
-        DEPS operator op_registry init math_function)
diff --git a/paddle/framework/data_device_transform_test.cu b/paddle/framework/data_device_transform_test.cu
index 5d89f5546fa87241dec6364d86a100ca51bce687..efc05b3106b40bdaa6cd03ce707c677dd58b0730 100644
--- a/paddle/framework/data_device_transform_test.cu
+++ b/paddle/framework/data_device_transform_test.cu
@@ -150,6 +150,7 @@ TEST(Operator, CPUtoGPU) {
   // get output
   auto* output2 = scope.Var("OUT2");
   gpu_op->Run(scope, cuda_place);
+  VLOG(3) << "after gpu_op run";
 
   // auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
   DeviceContextPool& pool = DeviceContextPool::Instance();
diff --git a/paddle/framework/data_layout_transform.cc b/paddle/framework/data_layout_transform.cc
index 96794cae97d460e86fe83ac1395e1dfc7e371e3b..1059bd976180621780dccd2b0b58a3c9ad5a09c9 100644
--- a/paddle/framework/data_layout_transform.cc
+++ b/paddle/framework/data_layout_transform.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,12 +14,23 @@ limitations under the License. */
*/ #include "paddle/framework/data_layout_transform.h" -#include "paddle/framework/tensor.h" #include "paddle/operators/math/math_function.h" namespace paddle { namespace framework { +std::vector GetAxis(const DataLayout& from, const DataLayout& to) { + PADDLE_ENFORCE_NE(from, to, + "layout transform should transform different layout"); + if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) { + return {0, 2, 3, 1}; + } else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) { + return {0, 3, 1, 2}; + } else { + PADDLE_THROW("unsupported transform"); + } +} + struct CastDataLayout { CastDataLayout(const platform::DeviceContext* ctx, const std::vector& axis, const framework::Tensor& in, @@ -44,38 +55,36 @@ struct CastDataLayout { } }; -void TransDataLayout(const std::vector& axis, - const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType(), "Only support Tensor transform!."); +void TransDataLayout(const OpKernelType& kernel_type_for_var, + const OpKernelType& expected_kernel_type, const Tensor& in, + Tensor* out) { PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), + platform::places_are_same_class(kernel_type_for_var.place_, + expected_kernel_type.place_), "TransDataLayout only support DataLayout transform on same place!"); - PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_, - "TransDataLayout only support Datatype are same!"); - auto src = in.Get(); - auto* dst = out->GetMutable(); - PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); + PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!"); + + auto& pool = platform::DeviceContextPool::Instance(); - auto src_dim = src.dims(); + auto src_dim = in.dims(); std::vector dst_dim; + auto axis = GetAxis(kernel_type_for_var.data_layout_, + expected_kernel_type.data_layout_); dst_dim.resize(axis.size()); for (size_t i = 0; i < axis.size(); i++) { dst_dim[i] = src_dim[axis[i]]; } - dst->Resize(make_ddim(dst_dim)); - auto place = kernel_pair.second.place_; - dst->mutable_data(place, src.type()); + out->Resize(make_ddim(dst_dim)); + out->mutable_data(expected_kernel_type.place_, in.type()); - auto src_type = kernel_pair.first.data_type_; - framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst)); + framework::VisitDataType( + framework::ToDataType(in.type()), + CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out)); - dst->set_layout(kernel_pair.second.data_layout_); + out->set_layout(expected_kernel_type.data_layout_); } } // namespace framework diff --git a/paddle/framework/data_layout_transform.h b/paddle/framework/data_layout_transform.h index befae1f63616a4c21d998c6b784b8ef288d00617..ec87257d7020425b9f8deadc9da144d3d5631b86 100644 --- a/paddle/framework/data_layout_transform.h +++ b/paddle/framework/data_layout_transform.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,17 +15,17 @@ limitations under the License. 
 
 #pragma once
 
 #include "paddle/framework/op_kernel_type.h"
+#include "paddle/framework/tensor.h"
 #include "paddle/framework/variable.h"
 
 namespace paddle {
 namespace framework {
 
-using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
+std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
 
-void TransDataLayout(const std::vector<int>& axis,
-                     const platform::DeviceContext* ctx,
-                     const KernelTypePair& kernel_pair, const Variable& in,
-                     Variable* out);
+void TransDataLayout(const OpKernelType& kernel_type_for_var,
+                     const OpKernelType& expected_kernel_type, const Tensor& in,
+                     Tensor* out);
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/data_layout_transform_test.cc b/paddle/framework/data_layout_transform_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2eb99fa04a1daca71f51e047fef0bb87114510aa
--- /dev/null
+++ b/paddle/framework/data_layout_transform_test.cc
@@ -0,0 +1,44 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/data_layout_transform.h"
+
+#include "gtest/gtest.h"
+#include "paddle/platform/device_context.h"
+
+TEST(DataTransform, DataLayoutFunction) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+
+  auto place = CPUPlace();
+  Tensor in = Tensor();
+  Tensor out = Tensor();
+  in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
+  in.set_layout(DataLayout::kNHWC);
+
+  auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
+                                  DataLayout::kNHWC, LibraryType::kPlain);
+  auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
+                                  DataLayout::kNCHW, LibraryType::kPlain);
+
+  TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
+
+  EXPECT_TRUE(out.layout() == DataLayout::kNCHW);
+  EXPECT_TRUE(out.dims() == make_ddim({2, 2, 3, 1}));
+
+  TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out);
+
+  EXPECT_TRUE(in.layout() == DataLayout::kNHWC);
+  EXPECT_TRUE(in.dims() == make_ddim({2, 3, 1, 2}));
+}
\ No newline at end of file
diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc
index d826f0edace6d5afee5cd83f6e65d6dbaefae874..e28b2e015d62729f29a4cd04db9d3480a6237e51 100644
--- a/paddle/framework/data_transform.cc
+++ b/paddle/framework/data_transform.cc
@@ -15,18 +15,43 @@ limitations under the License. */
*/ #include "paddle/framework/data_transform.h" #include "paddle/framework/data_device_transform.h" +#include "paddle/framework/data_layout_transform.h" namespace paddle { namespace framework { +static void PassTensorData(Tensor* from, Tensor* to) { + to->ShareDataWith(*from); + *from = Tensor(); +} + void DataTransform(const OpKernelType& expected_kernel_type, const OpKernelType& kernel_type_for_var, - const Tensor& input_tensor, Tensor* out) { + const Tensor& input_tensor, Tensor* output_tensor) { + bool transformed = false; + Tensor in; + in.ShareDataWith(input_tensor); + Tensor out; + + // do layout transform + if (NeedTransformLayout(expected_kernel_type.data_layout_, + kernel_type_for_var.data_layout_)) { + TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); + transformed = true; + PassTensorData(&out, &in); + } + + // do device transform if (!platform::is_same_place(kernel_type_for_var.place_, expected_kernel_type.place_)) { - DeviceTransform(input_tensor, expected_kernel_type.place_, out); + DeviceTransform(in, expected_kernel_type.place_, &out); + transformed = true; + PassTensorData(&out, &in); } - PADDLE_ENFORCE_NOT_NULL(out, "out should not be null"); + + PADDLE_ENFORCE(transformed, "no transform is done, please check!"); + // get output data + output_tensor->ShareDataWith(in); } void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor, diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h index 312bd5f892ac23c847c87388c9cadf2161028d3e..44adb94d2a8feb79a5ff93c6e32cdff52333166e 100644 --- a/paddle/framework/op_kernel_type.h +++ b/paddle/framework/op_kernel_type.h @@ -85,9 +85,14 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) { return stream.str(); } +inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) { + return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r; +} + inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) { return (!platform::places_are_same_class(l.place_, r.place_)) || - (l.data_type_ != r.data_type_) || (l.data_layout_ != r.data_layout_); + (l.data_type_ != r.data_type_) || + NeedTransformLayout(l.data_layout_, r.data_layout_); } } // namespace framework