Commit 85671b8a authored by Qiao Longfei, committed by dzhwinter

Data type transform (#7653)

* init complete data layout transform

* can compile

* test passed

* optimize code

* fix while_grad_op first step loss lod problem

* optimize in out ptr for transform

* add check

* update copyright

* clean code

* add NeedTransformLayout

* add comment

* change the interface of data_type_transform

* init data_type_transform_test

* complete data_type_transform_test

* add TransDataType to data_transform
Parent 02add30c
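The core interface change ("change the interface of data_type_transform" above): TransDataType now takes the source and destination OpKernelType plus plain Tensors, and looks up the DeviceContext itself, rather than receiving a DeviceContext, a KernelTypePair, and Variables. A minimal sketch of the new call pattern, with setup copied from the data_type_transform_test.cc added at the end of this commit (abbreviated here):

  using namespace paddle::framework;
  using namespace paddle::platform;

  Tensor in, out;
  auto place = CPUPlace();
  float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
  for (int i = 0; i < 6; ++i) ptr[i] = i / 3;  // fill source data

  auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);
  auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);

  // New signature: (kernel_type_for_var, expected_kernel_type, in, &out).
  // The DeviceContext is fetched internally from DeviceContextPool based
  // on in.place(); no ctx argument and no Variable wrapping needed.
  TransDataType(kernel_fp32, kernel_fp64, in, &out);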
paddle/framework/CMakeLists.txt
@@ -37,6 +37,7 @@
 nv_test(data_device_transform_test SRCS data_device_transform_test.cu
         DEPS operator op_registry init math_function)
 cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
+cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
 cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
 cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
......
paddle/framework/data_device_transform.cc
@@ -31,7 +31,7 @@ static const platform::DeviceContext* GetDeviceContext(
   }
 }
 
-void DeviceTransform(const Tensor& in, const platform::Place& dst_place,
+void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
                      Tensor* out) {
   VLOG(3) << "DeviceTransform in, src_place " << in.place()
           << " dst_place: " << dst_place;
......
paddle/framework/data_device_transform.h
@@ -21,7 +21,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-void DeviceTransform(const Tensor& in, const platform::Place& dst_place,
+void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
                      Tensor* out);
 
 }  // namespace framework
......
paddle/framework/data_transform.cc
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/framework/data_device_transform.h"
 #include "paddle/framework/data_layout_transform.h"
+#include "paddle/framework/data_type_transform.h"
 
 namespace paddle {
 namespace framework {
@@ -41,15 +42,21 @@ void DataTransform(const OpKernelType& expected_kernel_type,
     PassTensorData(&out, &in);
   }
 
+  if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
+    TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
+    transformed = true;
+    PassTensorData(&out, &in);
+  }
+
   // do device transform
   if (!platform::is_same_place(kernel_type_for_var.place_,
                                expected_kernel_type.place_)) {
-    DeviceTransform(in, expected_kernel_type.place_, &out);
+    TransDataDevice(in, expected_kernel_type.place_, &out);
     transformed = true;
     PassTensorData(&out, &in);
   }
 
-  PADDLE_ENFORCE(transformed, "no transform is done, please check!");
+  PADDLE_ENFORCE(transformed, "No transform is applied, please check!");
   // get output data
   output_tensor->ShareDataWith(in);
 }
......
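With the new block in place, DataTransform can chain up to three independent passes -- layout, then data type, then device -- where PassTensorData hands each pass's output back as the next pass's input. A self-contained sketch of that ping-pong pattern, with a plain struct standing in for Tensor (not Paddle's actual classes):

  #include <cassert>
  #include <utility>
  #include <vector>

  struct Buf { std::vector<float> data; };  // stand-in for framework::Tensor

  // Analogue of PassTensorData: the finished output becomes the next input.
  void PassData(Buf* out, Buf* in) { *in = std::move(*out); }

  int main() {
    Buf in{{0.f, 0.f, 0.f, 1.f, 1.f, 1.f}}, out;
    bool transformed = false;

    // Each pass is independent; any subset of the three may run.
    bool need_layout = false, need_dtype = true, need_device = true;

    if (need_layout) { out = in; /* layout pass  */ transformed = true; PassData(&out, &in); }
    if (need_dtype)  { out = in; /* dtype cast   */ transformed = true; PassData(&out, &in); }
    if (need_device) { out = in; /* device copy  */ transformed = true; PassData(&out, &in); }

    assert(transformed);  // mirrors PADDLE_ENFORCE(transformed, ...)
    return in.data.size() == 6 ? 0 : 1;  // the result lives in `in` at the end
  }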
...@@ -38,14 +38,11 @@ struct CastDataType { ...@@ -38,14 +38,11 @@ struct CastDataType {
template <typename OutType> template <typename OutType>
void operator()() { void operator()() {
auto place = ctx_->GetPlace();
auto* in_begin = in_.data<InType>(); auto* in_begin = in_.data<InType>();
auto numel = in_.numel(); auto* in_end = in_begin + in_.numel();
auto* in_end = in_begin + numel; auto* out_begin = out_->mutable_data<OutType>(in_.place());
auto* out_begin = out_->mutable_data<OutType>(place);
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(in_.place())) {
platform::Transform<platform::CPUDeviceContext> trans; platform::Transform<platform::CPUDeviceContext> trans;
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_); auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans(*context, in_begin, in_end, out_begin, trans(*context, in_begin, in_end, out_begin,
...@@ -57,38 +54,31 @@ struct CastDataType { ...@@ -57,38 +54,31 @@ struct CastDataType {
} }
}; };
void TransDataType(const platform::DeviceContext* ctx, void TransDataType(const OpKernelType& kernel_type_for_var,
const KernelTypePair& kernel_pair, const Variable& in, const OpKernelType& expected_kernel_type, const Tensor& in,
Variable* out) { Tensor* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!."); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataType Only Support DataType transform on same place!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
auto dims = src.dims(); out->Resize(in.dims());
dst->Resize(dims); auto src_type = kernel_type_for_var.data_type_;
auto dst_type = kernel_pair.second.data_type_; auto dst_type = expected_kernel_type.data_type_;
auto src_type = kernel_pair.first.data_type_; auto ctx = pool.Get(in.place());
switch (src_type) { switch (src_type) {
case proto::DataType::FP32: case proto::DataType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx)); framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break; break;
case proto::DataType::FP64: case proto::DataType::FP64:
framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx)); framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx));
break; break;
case proto::DataType::INT32: case proto::DataType::INT32:
framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx)); framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx));
break; break;
case proto::DataType::INT64: case proto::DataType::INT64:
framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx)); framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx));
break; break;
case proto::DataType::BOOL: case proto::DataType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx)); framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break; break;
default: default:
PADDLE_THROW("Not support type %d", src_type); PADDLE_THROW("Not support type %d", src_type);
......
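The cast itself is a two-level dispatch: the switch above resolves the runtime source type to a CastDataType<InType> functor, and framework::VisitDataType resolves the runtime destination type by instantiating operator()<OutType>. A self-contained sketch of the pattern with a toy enum and visitor (not Paddle's actual VisitDataType):

  #include <cstdio>
  #include <vector>

  enum class DType { FP32, FP64, INT32 };

  // Analogue of CastDataType<InType>: the source element type is fixed
  // statically; the destination type arrives via operator()<OutType>.
  template <typename InType>
  struct Cast {
    const std::vector<InType>& in;
    std::vector<double>& out;  // widened storage, for brevity
    template <typename OutType>
    void operator()() {
      out.clear();
      for (InType v : in) out.push_back(static_cast<OutType>(v));
    }
  };

  // Analogue of framework::VisitDataType: maps the runtime dst type
  // to a compile-time template argument.
  template <typename Visitor>
  void VisitDType(DType dst, Visitor visitor) {
    switch (dst) {
      case DType::FP32:  visitor.template operator()<float>(); break;
      case DType::FP64:  visitor.template operator()<double>(); break;
      case DType::INT32: visitor.template operator()<int>(); break;
    }
  }

  int main() {
    std::vector<float> src{0.9f, 1.5f, 2.1f};
    std::vector<double> dst;
    // Outer dispatch on src type (fixed to float here), inner on dst type.
    VisitDType(DType::INT32, Cast<float>{src, dst});
    for (double v : dst) std::printf("%g ", v);  // prints "0 1 2 "
    return 0;
  }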
paddle/framework/data_type_transform.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/framework/op_kernel_type.h"
+#include "paddle/framework/tensor.h"
 #include "paddle/framework/variable.h"
 #include "paddle/platform/device_context.h"
 
@@ -23,9 +24,9 @@ namespace framework {
 using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
 
-void TransDataType(const platform::DeviceContext* ctx,
-                   const KernelTypePair& kernel_pair, const Variable& in,
-                   Variable* out);
+void TransDataType(const OpKernelType& kernel_type_for_var,
+                   const OpKernelType& expected_kernel_type, const Tensor& in,
+                   Tensor* out);
 
 }  // namespace framework
 }  // namespace paddle
......
paddle/framework/data_type_transform_test.cc (new file)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/data_type_transform.h"

#include "gtest/gtest.h"

TEST(DataTypeTransform, CPUTransform) {
  using namespace paddle::framework;
  using namespace paddle::platform;

  auto place = CPUPlace();

  Tensor in;
  Tensor out;

  float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
  int data_number = 2 * 3;

  for (int i = 0; i < data_number; ++i) {
    ptr[i] = i / 3;
  }

  auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);
  auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place,
                                  DataLayout::kAnyLayout, LibraryType::kPlain);
  auto kernel_int32 = OpKernelType(proto::DataType::INT32, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);

  TransDataType(kernel_fp32, kernel_fp64, in, &out);
  double* out_data_double = out.data<double>();
  for (int i = 0; i < data_number; ++i) {
    ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
  }

  TransDataType(kernel_fp32, kernel_int32, in, &out);
  int* out_data_int = out.data<int>();
  for (int i = 0; i < data_number; ++i) {
    ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
  }
}