Unverified · Commit 0071b5f7, authored by Qiao Longfei, committed by GitHub

complete data layout transform (#7440)

* add data layout transform and optimize the implementation of data_transform
Parent: 9e17c46c
paddle/framework/CMakeLists.txt
@@ -33,8 +33,13 @@ cc_library(scope SRCS scope.cc DEPS glog threadpool)
 cc_test(scope_test SRCS scope_test.cc DEPS scope)
 cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
+nv_test(data_device_transform_test SRCS data_device_transform_test.cu
+        DEPS operator op_registry init math_function)
 cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
 cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
+cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
 cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
            framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
@@ -82,5 +87,3 @@ cc_test(init_test SRCS init_test.cc DEPS init)
 cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
 cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
-nv_test(data_device_transform_test SRCS data_device_transform_test.cu
-        DEPS operator op_registry init math_function)
paddle/framework/data_device_transform_test.cu
@@ -150,6 +150,7 @@ TEST(Operator, CPUtoGPU) {
   // get output
   auto* output2 = scope.Var("OUT2");
   gpu_op->Run(scope, cuda_place);
+  VLOG(3) << "after gpu_op run";
   // auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
   DeviceContextPool& pool = DeviceContextPool::Instance();
...
paddle/framework/data_layout_transform.cc
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,12 +14,23 @@ limitations under the License. */
 #include "paddle/framework/data_layout_transform.h"

+#include "paddle/framework/tensor.h"
 #include "paddle/operators/math/math_function.h"

 namespace paddle {
 namespace framework {

+std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
+  PADDLE_ENFORCE_NE(from, to,
+                    "layout transform should transform different layout");
+  if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) {
+    return {0, 2, 3, 1};
+  } else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) {
+    return {0, 3, 1, 2};
+  } else {
+    PADDLE_THROW("unsupported transform");
+  }
+}
+
 struct CastDataLayout {
   CastDataLayout(const platform::DeviceContext* ctx,
                  const std::vector<int>& axis, const framework::Tensor& in,
@@ -44,38 +55,36 @@ struct CastDataLayout {
   }
 };

-void TransDataLayout(const std::vector<int>& axis,
-                     const platform::DeviceContext* ctx,
-                     const KernelTypePair& kernel_pair, const Variable& in,
-                     Variable* out) {
-  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
+void TransDataLayout(const OpKernelType& kernel_type_for_var,
+                     const OpKernelType& expected_kernel_type, const Tensor& in,
+                     Tensor* out) {
   PADDLE_ENFORCE(
-      platform::places_are_same_class(kernel_pair.first.place_,
-                                      kernel_pair.second.place_),
+      platform::places_are_same_class(kernel_type_for_var.place_,
+                                      expected_kernel_type.place_),
       "TransDataLayout only support DataLayout transform on same place!");
-  PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
-                 "TransDataLayout only support Datatype are same!");

-  auto src = in.Get<Tensor>();
-  auto* dst = out->GetMutable<Tensor>();
-  PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
+  PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!");

-  auto src_dim = src.dims();
+  auto& pool = platform::DeviceContextPool::Instance();
+
+  auto src_dim = in.dims();
   std::vector<int64_t> dst_dim;
+
+  auto axis = GetAxis(kernel_type_for_var.data_layout_,
+                      expected_kernel_type.data_layout_);
   dst_dim.resize(axis.size());
   for (size_t i = 0; i < axis.size(); i++) {
     dst_dim[i] = src_dim[axis[i]];
   }

-  dst->Resize(make_ddim(dst_dim));
-  auto place = kernel_pair.second.place_;
-  dst->mutable_data(place, src.type());
+  out->Resize(make_ddim(dst_dim));
+  out->mutable_data(expected_kernel_type.place_, in.type());

-  auto src_type = kernel_pair.first.data_type_;
-  framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
+  framework::VisitDataType(
+      framework::ToDataType(in.type()),
+      CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));

-  dst->set_layout(kernel_pair.second.data_layout_);
+  out->set_layout(expected_kernel_type.data_layout_);
 }

 }  // namespace framework
...
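Note: the axis vector returned by GetAxis is a dimension permutation, applied as dst_dim[i] = src_dim[axis[i]] in TransDataLayout above. A minimal standalone sketch of that rule (plain C++ with a hypothetical PermuteDims helper, not part of this patch):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Apply a layout permutation to a dims vector: dst_dim[i] = src_dim[axis[i]].
  std::vector<int64_t> PermuteDims(const std::vector<int64_t>& src_dim,
                                   const std::vector<int>& axis) {
    std::vector<int64_t> dst_dim(axis.size());
    for (size_t i = 0; i < axis.size(); i++) {
      dst_dim[i] = src_dim[axis[i]];
    }
    return dst_dim;
  }

  int main() {
    // NCHW -> NHWC uses axis {0, 2, 3, 1}: a 2x3x4x5 tensor becomes 2x4x5x3.
    assert((PermuteDims({2, 3, 4, 5}, {0, 2, 3, 1}) ==
            std::vector<int64_t>{2, 4, 5, 3}));
    // NHWC -> NCHW uses the inverse permutation {0, 3, 1, 2}.
    assert((PermuteDims({2, 4, 5, 3}, {0, 3, 1, 2}) ==
            std::vector<int64_t>{2, 3, 4, 5}));
    return 0;
  }

The two permutations are inverses of each other, which is why a round trip NCHW -> NHWC -> NCHW restores the original dims.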
paddle/framework/data_layout_transform.h
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,17 +15,17 @@ limitations under the License. */
 #pragma once

 #include "paddle/framework/op_kernel_type.h"
+#include "paddle/framework/tensor.h"
 #include "paddle/framework/variable.h"

 namespace paddle {
 namespace framework {

-using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
+std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);

-void TransDataLayout(const std::vector<int>& axis,
-                     const platform::DeviceContext* ctx,
-                     const KernelTypePair& kernel_pair, const Variable& in,
-                     Variable* out);
+void TransDataLayout(const OpKernelType& kernel_type_for_var,
+                     const OpKernelType& expected_kernel_type, const Tensor& in,
+                     Tensor* out);

 }  // namespace framework
 }  // namespace paddle
paddle/framework/data_layout_transform_test.cc (new file)
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/data_layout_transform.h"
+
+#include "gtest/gtest.h"
+#include "paddle/platform/device_context.h"
+
+TEST(DataTransform, DataLayoutFunction) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+
+  auto place = CPUPlace();
+  Tensor in = Tensor();
+  Tensor out = Tensor();
+  in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
+  in.set_layout(DataLayout::kNHWC);
+
+  auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
+                                  DataLayout::kNHWC, LibraryType::kPlain);
+  auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
+                                  DataLayout::kNCHW, LibraryType::kPlain);
+
+  TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
+
+  EXPECT_TRUE(out.layout() == DataLayout::kNCHW);
+  EXPECT_TRUE(out.dims() == make_ddim({2, 2, 3, 1}));
+
+  TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out);
+
+  EXPECT_TRUE(in.layout() == DataLayout::kNHWC);
+  EXPECT_TRUE(in.dims() == make_ddim({2, 3, 1, 2}));
+}
\ No newline at end of file
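Note: the expected dims in the test follow from reading the {2, 3, 1, 2} NHWC input as N=2, H=3, W=1, C=2, which gives {2, 2, 3, 1} in NCHW. A standalone sketch of the element-level remapping a transpose with axis {0, 3, 1, 2} performs, assuming plain row-major storage (not Paddle code):

  #include <cassert>
  #include <vector>

  int main() {
    const int N = 2, H = 3, W = 1, C = 2;

    // Fill an NHWC buffer with values equal to their flat offset.
    std::vector<double> nhwc(N * H * W * C);
    for (size_t i = 0; i < nhwc.size(); ++i) nhwc[i] = static_cast<double>(i);

    // Transpose to NCHW: dst(n, c, h, w) = src(n, h, w, c).
    std::vector<double> nchw(nhwc.size());
    for (int n = 0; n < N; ++n)
      for (int c = 0; c < C; ++c)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w)
            nchw[((n * C + c) * H + h) * W + w] =
                nhwc[((n * H + h) * W + w) * C + c];

    // Element (n=0, h=1, w=0, c=1) sits at NHWC offset 3 and moves to
    // NCHW offset ((0*2 + 1)*3 + 1)*1 + 0 = 4.
    assert(nchw[4] == 3.0);
    return 0;
  }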
paddle/framework/data_transform.cc
@@ -15,18 +15,43 @@ limitations under the License. */
 #include "paddle/framework/data_transform.h"

 #include "paddle/framework/data_device_transform.h"
+#include "paddle/framework/data_layout_transform.h"

 namespace paddle {
 namespace framework {

+static void PassTensorData(Tensor* from, Tensor* to) {
+  to->ShareDataWith(*from);
+  *from = Tensor();
+}
+
 void DataTransform(const OpKernelType& expected_kernel_type,
                    const OpKernelType& kernel_type_for_var,
-                   const Tensor& input_tensor, Tensor* out) {
+                   const Tensor& input_tensor, Tensor* output_tensor) {
+  bool transformed = false;
+  Tensor in;
+  in.ShareDataWith(input_tensor);
+  Tensor out;
+
+  // do layout transform
+  if (NeedTransformLayout(expected_kernel_type.data_layout_,
+                          kernel_type_for_var.data_layout_)) {
+    TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
+    transformed = true;
+    PassTensorData(&out, &in);
+  }
+
+  // do device transform
   if (!platform::is_same_place(kernel_type_for_var.place_,
                                expected_kernel_type.place_)) {
-    DeviceTransform(input_tensor, expected_kernel_type.place_, out);
+    DeviceTransform(in, expected_kernel_type.place_, &out);
+    transformed = true;
+    PassTensorData(&out, &in);
   }
-  PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
+
+  PADDLE_ENFORCE(transformed, "no transform is done, please check!");
+  // get output data
+  output_tensor->ShareDataWith(in);
 }

 void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
...
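Note: DataTransform chains transforms by ping-ponging between in and out: each stage reads in, writes out, and PassTensorData hands the result back to in so the next stage composes on top of it. A minimal sketch of the same pattern with std::vector<float> standing in for Tensor (hypothetical stages, not Paddle code):

  #include <cassert>
  #include <utility>
  #include <vector>

  using Buf = std::vector<float>;

  // Stand-in for PassTensorData: hand the freshly written result back to
  // the "in" slot and leave "from" empty for the next stage to refill.
  static void PassData(Buf* from, Buf* to) {
    *to = std::move(*from);
    from->clear();
  }

  int main() {
    Buf in = {1.f, 2.f, 3.f};
    Buf out;
    bool transformed = false;

    const bool need_scale = true;  // stands in for the layout check
    const bool need_shift = true;  // stands in for the device check

    if (need_scale) {  // stage 1: scale by 2
      for (float v : in) out.push_back(v * 2.f);
      transformed = true;
      PassData(&out, &in);
    }
    if (need_shift) {  // stage 2: add 1, composed on top of stage 1
      for (float v : in) out.push_back(v + 1.f);
      transformed = true;
      PassData(&out, &in);
    }

    assert(transformed);
    assert((in == Buf{3.f, 5.f, 7.f}));
    return 0;
  }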
paddle/framework/op_kernel_type.h
@@ -85,9 +85,14 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
   return stream.str();
 }

+inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
+  return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
+}
+
 inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) {
   return (!platform::places_are_same_class(l.place_, r.place_)) ||
-         (l.data_type_ != r.data_type_) || (l.data_layout_ != r.data_layout_);
+         (l.data_type_ != r.data_type_) ||
+         NeedTransformLayout(l.data_layout_, r.data_layout_);
 }

 }  // namespace framework
...
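Note: NeedTransformLayout treats kAnyLayout as a wildcard, so a kernel that declares kAnyLayout accepts whatever layout the variable already has and never triggers a transpose. A self-contained sketch of the predicate over a local enum (mirrors the header above, but is not the Paddle header itself):

  #include <cassert>

  enum class DataLayout { kNHWC, kNCHW, kAnyLayout };

  // Same predicate as in the header above: a transform is needed only when
  // both sides commit to a concrete layout and those layouts differ.
  inline bool NeedTransformLayout(DataLayout l, DataLayout r) {
    return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
  }

  int main() {
    assert(NeedTransformLayout(DataLayout::kNHWC, DataLayout::kNCHW));
    assert(!NeedTransformLayout(DataLayout::kNHWC, DataLayout::kNHWC));
    // kAnyLayout on either side suppresses the transform.
    assert(!NeedTransformLayout(DataLayout::kAnyLayout, DataLayout::kNCHW));
    assert(!NeedTransformLayout(DataLayout::kNHWC, DataLayout::kAnyLayout));
    return 0;
  }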