From 85671b8acbefe538845896148425450498890b7d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 21 Jan 2018 22:50:22 +0800 Subject: [PATCH] Data type transform (#7653) * init complete data layout transform * can compile * test passed * optimize code * fix while_grad_op first step loss lod problem * optimize in out ptr for transform * add check * update copyright * clean code * add NeedTransformLayout * add comment * change the interface of data_type_transform * init data_type_transform_test * complete data_type_transform_test * add TransDataType to data_transform --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/data_device_transform.cc | 2 +- paddle/framework/data_device_transform.h | 2 +- paddle/framework/data_transform.cc | 11 +++- paddle/framework/data_type_transform.cc | 42 ++++++---------- paddle/framework/data_type_transform.h | 7 +-- paddle/framework/data_type_transform_test.cc | 53 ++++++++++++++++++++ 7 files changed, 85 insertions(+), 33 deletions(-) create mode 100644 paddle/framework/data_type_transform_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8e5a956061..afb55bdaaa 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -37,6 +37,7 @@ nv_test(data_device_transform_test SRCS data_device_transform_test.cu DEPS operator op_registry init math_function) cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor) +cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform) cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function) cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform) diff --git a/paddle/framework/data_device_transform.cc b/paddle/framework/data_device_transform.cc index d38d87927f..5daf5a4e0a 100644 --- a/paddle/framework/data_device_transform.cc +++ b/paddle/framework/data_device_transform.cc @@ -31,7 +31,7 @@ static const 
platform::DeviceContext* GetDeviceContext( } } -void DeviceTransform(const Tensor& in, const platform::Place& dst_place, +void TransDataDevice(const Tensor& in, const platform::Place& dst_place, Tensor* out) { VLOG(3) << "DeviceTransform in, src_place " << in.place() << " dst_place: " << dst_place; diff --git a/paddle/framework/data_device_transform.h b/paddle/framework/data_device_transform.h index b21ed0be34..39750a85f2 100644 --- a/paddle/framework/data_device_transform.h +++ b/paddle/framework/data_device_transform.h @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace framework { -void DeviceTransform(const Tensor& in, const platform::Place& dst_place, +void TransDataDevice(const Tensor& in, const platform::Place& dst_place, Tensor* out); } // namespace framework diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index e28b2e015d..b6fd46401f 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include "paddle/framework/data_device_transform.h" #include "paddle/framework/data_layout_transform.h" +#include "paddle/framework/data_type_transform.h" namespace paddle { namespace framework { @@ -41,15 +42,21 @@ void DataTransform(const OpKernelType& expected_kernel_type, PassTensorData(&out, &in); } + if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) { + TransDataType(kernel_type_for_var, expected_kernel_type, in, &out); + transformed = true; + PassTensorData(&out, &in); + } + // do device transform if (!platform::is_same_place(kernel_type_for_var.place_, expected_kernel_type.place_)) { - DeviceTransform(in, expected_kernel_type.place_, &out); + TransDataDevice(in, expected_kernel_type.place_, &out); transformed = true; PassTensorData(&out, &in); } - PADDLE_ENFORCE(transformed, "no transform is done, please check!"); + PADDLE_ENFORCE(transformed, "No transform is applied, please check!"); // get output data output_tensor->ShareDataWith(in); } diff --git a/paddle/framework/data_type_transform.cc b/paddle/framework/data_type_transform.cc index 63373232e9..7df1cc6b75 100644 --- a/paddle/framework/data_type_transform.cc +++ b/paddle/framework/data_type_transform.cc @@ -38,14 +38,11 @@ struct CastDataType { template <typename OutType> void operator()() { - auto place = ctx_->GetPlace(); - auto* in_begin = in_.data<InType>(); - auto numel = in_.numel(); - auto* in_end = in_begin + numel; - auto* out_begin = out_->mutable_data<OutType>(place); + auto* in_end = in_begin + in_.numel(); + auto* out_begin = out_->mutable_data<OutType>(in_.place()); - if (platform::is_cpu_place(place)) { + if (platform::is_cpu_place(in_.place())) { platform::Transform<platform::CPUDeviceContext> trans; auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_); trans(*context, in_begin, in_end, out_begin, @@ -57,38 +54,31 @@ struct CastDataType { } }; -void TransDataType(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out) { - PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!."); - 
PADDLE_ENFORCE( - platform::places_are_same_class(kernel_pair.first.place_, - kernel_pair.second.place_), - "TransDataType Only Support DataType transform on same place!"); - - auto src = in.Get<Tensor>(); - auto* dst = out->GetMutable<Tensor>(); +void TransDataType(const OpKernelType& kernel_type_for_var, + const OpKernelType& expected_kernel_type, const Tensor& in, + Tensor* out) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto dims = src.dims(); - dst->Resize(dims); - auto dst_type = kernel_pair.second.data_type_; - auto src_type = kernel_pair.first.data_type_; + out->Resize(in.dims()); + auto src_type = kernel_type_for_var.data_type_; + auto dst_type = expected_kernel_type.data_type_; + auto ctx = pool.Get(in.place()); switch (src_type) { case proto::DataType::FP32: - framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx)); + framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx)); break; case proto::DataType::FP64: - framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx)); + framework::VisitDataType(dst_type, CastDataType<double>(in, out, ctx)); break; case proto::DataType::INT32: - framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx)); + framework::VisitDataType(dst_type, CastDataType<int>(in, out, ctx)); break; case proto::DataType::INT64: - framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx)); + framework::VisitDataType(dst_type, CastDataType<int64_t>(in, out, ctx)); break; case proto::DataType::BOOL: - framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx)); + framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx)); break; default: PADDLE_THROW("Not support type %d", src_type); diff --git a/paddle/framework/data_type_transform.h b/paddle/framework/data_type_transform.h index 8ec9074225..067c0c2a5b 100644 --- a/paddle/framework/data_type_transform.h +++ b/paddle/framework/data_type_transform.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/framework/op_kernel_type.h" +#include "paddle/framework/tensor.h" #include "paddle/framework/variable.h" #include "paddle/platform/device_context.h" @@ -23,9 +24,9 @@ namespace framework { using KernelTypePair = std::pair<OpKernelType, OpKernelType>; -void TransDataType(const platform::DeviceContext* ctx, - const KernelTypePair& kernel_pair, const Variable& in, - Variable* out); +void TransDataType(const OpKernelType& kernel_type_for_var, + const OpKernelType& expected_kernel_type, const Tensor& in, + Tensor* out); } // namespace framework } // namespace paddle diff --git a/paddle/framework/data_type_transform_test.cc b/paddle/framework/data_type_transform_test.cc new file mode 100644 index 0000000000..89d32f5283 --- /dev/null +++ b/paddle/framework/data_type_transform_test.cc @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/framework/data_type_transform.h" + +#include "gtest/gtest.h" + +TEST(DataTypeTransform, CPUTransform) { + using namespace paddle::framework; + using namespace paddle::platform; + + auto place = CPUPlace(); + + Tensor in; + Tensor out; + + float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place); + int data_number = 2 * 3; + + for (int i = 0; i < data_number; ++i) { + ptr[i] = i / 3; + } + + auto kernel_fp32 = OpKernelType(proto::DataType::FP32, place, + DataLayout::kAnyLayout, LibraryType::kPlain); + auto kernel_fp64 = OpKernelType(proto::DataType::FP64, place, + DataLayout::kAnyLayout, LibraryType::kPlain); + auto kernel_int32 = OpKernelType(proto::DataType::INT32, place, + DataLayout::kAnyLayout, LibraryType::kPlain); + + TransDataType(kernel_fp32, kernel_fp64, in, &out); + double* out_data_double = out.data<double>(); + for (int i = 0; i < data_number; ++i) { + ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3)); + } + + TransDataType(kernel_fp32, kernel_int32, in, &out); + int* out_data_int = out.data<int>(); + for (int i = 0; i < data_number; ++i) { + ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3)); + } +} -- GitLab