未验证 提交 0071b5f7 编写于 作者: Q Qiao Longfei 提交者: GitHub

complete data layout transform (#7440)

* add data layout transform and optimize the implementation of data_transform
上级 9e17c46c
......@@ -33,8 +33,13 @@ cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry init math_function)
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
......@@ -82,5 +87,3 @@ cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry init math_function)
......@@ -150,6 +150,7 @@ TEST(Operator, CPUtoGPU) {
// get output
auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place);
VLOG(3) << "after gpu_op run";
// auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
DeviceContextPool& pool = DeviceContextPool::Instance();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -14,12 +14,23 @@ limitations under the License. */
#include "paddle/framework/data_layout_transform.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace framework {
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
PADDLE_ENFORCE_NE(from, to,
"layout transform should transform different layout");
if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) {
return {0, 2, 3, 1};
} else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) {
return {0, 3, 1, 2};
} else {
PADDLE_THROW("unsupported transform");
}
}
struct CastDataLayout {
CastDataLayout(const platform::DeviceContext* ctx,
const std::vector<int>& axis, const framework::Tensor& in,
......@@ -44,38 +55,36 @@ struct CastDataLayout {
}
};
void TransDataLayout(const std::vector<int>& axis,
const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
void TransDataLayout(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const Tensor& in,
Tensor* out) {
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
platform::places_are_same_class(kernel_type_for_var.place_,
expected_kernel_type.place_),
"TransDataLayout only support DataLayout transform on same place!");
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
"TransDataLayout only support Datatype are same!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!");
auto& pool = platform::DeviceContextPool::Instance();
auto src_dim = src.dims();
auto src_dim = in.dims();
std::vector<int64_t> dst_dim;
auto axis = GetAxis(kernel_type_for_var.data_layout_,
expected_kernel_type.data_layout_);
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}
dst->Resize(make_ddim(dst_dim));
auto place = kernel_pair.second.place_;
dst->mutable_data(place, src.type());
out->Resize(make_ddim(dst_dim));
out->mutable_data(expected_kernel_type.place_, in.type());
auto src_type = kernel_pair.first.data_type_;
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
framework::VisitDataType(
framework::ToDataType(in.type()),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
dst->set_layout(kernel_pair.second.data_layout_);
out->set_layout(expected_kernel_type.data_layout_);
}
} // namespace framework
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -15,17 +15,17 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
namespace paddle {
namespace framework {
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const std::vector<int>& axis,
const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out);
void TransDataLayout(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const Tensor& in,
Tensor* out);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_layout_transform.h"
#include "gtest/gtest.h"
#include "paddle/platform/device_context.h"
TEST(DataTransform, DataLayoutFunction) {
using namespace paddle::framework;
using namespace paddle::platform;
auto place = CPUPlace();
Tensor in = Tensor();
Tensor out = Tensor();
in.mutable_data<double>(make_ddim({2, 3, 1, 2}), place);
in.set_layout(DataLayout::kNHWC);
auto kernel_nhwc = OpKernelType(proto::DataType::FP32, place,
DataLayout::kNHWC, LibraryType::kPlain);
auto kernel_ncwh = OpKernelType(proto::DataType::FP32, place,
DataLayout::kNCHW, LibraryType::kPlain);
TransDataLayout(kernel_nhwc, kernel_ncwh, in, &out);
EXPECT_TRUE(out.layout() == DataLayout::kNCHW);
EXPECT_TRUE(out.dims() == make_ddim({2, 2, 3, 1}));
TransDataLayout(kernel_ncwh, kernel_nhwc, in, &out);
EXPECT_TRUE(in.layout() == DataLayout::kNHWC);
EXPECT_TRUE(in.dims() == make_ddim({2, 3, 1, 2}));
}
\ No newline at end of file
......@@ -15,18 +15,43 @@ limitations under the License. */
#include "paddle/framework/data_transform.h"
#include "paddle/framework/data_device_transform.h"
#include "paddle/framework/data_layout_transform.h"
namespace paddle {
namespace framework {
static void PassTensorData(Tensor* from, Tensor* to) {
to->ShareDataWith(*from);
*from = Tensor();
}
void DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor, Tensor* out) {
const Tensor& input_tensor, Tensor* output_tensor) {
bool transformed = false;
Tensor in;
in.ShareDataWith(input_tensor);
Tensor out;
// do layout transform
if (NeedTransformLayout(expected_kernel_type.data_layout_,
kernel_type_for_var.data_layout_)) {
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
transformed = true;
PassTensorData(&out, &in);
}
// do device transform
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
DeviceTransform(input_tensor, expected_kernel_type.place_, out);
DeviceTransform(in, expected_kernel_type.place_, &out);
transformed = true;
PassTensorData(&out, &in);
}
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
PADDLE_ENFORCE(transformed, "no transform is done, please check!");
// get output data
output_tensor->ShareDataWith(in);
}
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
......
......@@ -85,9 +85,14 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
return stream.str();
}
inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
return l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r;
}
inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) {
return (!platform::places_are_same_class(l.place_, r.place_)) ||
(l.data_type_ != r.data_type_) || (l.data_layout_ != r.data_layout_);
(l.data_type_ != r.data_type_) ||
NeedTransformLayout(l.data_layout_, r.data_layout_);
}
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册