未验证 提交 377424bf 编写于 作者: Q Qiao Longfei 提交者: GitHub

reorganize data transform related code (#7391)

* init data_type_transform

* split data_layout_transform

* tmp rm data_transform_test

* change device_data_transform to data_device_transform

* clean code

* clean code
上级 6cff3c96
......@@ -32,10 +32,12 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)
cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
......@@ -80,5 +82,5 @@ cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry init math_function)
......@@ -11,7 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/data_device_transform.h"
namespace paddle {
namespace framework {
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/enforce.h"
#include <iostream>
#include "paddle/platform/enforce.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_layout_transform.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace framework {
struct CastDataLayout {
CastDataLayout(const platform::DeviceContext* ctx,
const std::vector<int>& axis, const framework::Tensor& in,
framework::Tensor* out)
: in_(in), out_(out), ctx_(ctx), axis_(axis) {}
const framework::Tensor in_;
framework::Tensor* out_;
const platform::DeviceContext* ctx_;
const std::vector<int> axis_;
template <typename T>
void operator()() {
auto place = ctx_->GetPlace();
if (platform::is_cpu_place(place)) {
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans4(*context, in_, out_, axis_);
} else {
PADDLE_THROW("Unsupport CPU <-> GPU!");
}
}
};
void TransDataLayout(const std::vector<int>& axis,
const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataLayout only support DataLayout transform on same place!");
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
"TransDataLayout only support Datatype are same!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
auto src_dim = src.dims();
std::vector<int64_t> dst_dim;
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}
dst->Resize(make_ddim(dst_dim));
auto place = kernel_pair.second.place_;
dst->mutable_data(place, src.type());
auto src_type = kernel_pair.first.data_type_;
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
dst->set_layout(kernel_pair.second.data_layout_);
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/variable.h"
namespace paddle {
namespace framework {
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
void TransDataLayout(const std::vector<int>& axis,
const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out);
} // namespace framework
} // namespace paddle
......@@ -11,22 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <functional>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
#include "paddle/framework/data_device_transform.h"
namespace paddle {
namespace framework {
DataTransformFnMap& DataTransformFnMap::Instance() {
static DataTransformFnMap data_transform_map;
return data_transform_map;
}
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor) {
......@@ -58,134 +50,5 @@ void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
}
}
auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);
auto KernelFP64 = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);
auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);
auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
DataLayout::kNCHW, LibraryType::kPlain);
// TODO(dzhwinter): Only for testing multiple op kernel.
// Dummy transform function for library_type
// should be removed.
auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
DataLayout::kAnyLayout, LibraryType::kPlain);
auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
DataLayout::kAnyLayout, LibraryType::kCUDNN);
void DummyTrans(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataType Only Support DataType transform on same place!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
*dst = src;
}
void TransDataType(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataType Only Support DataType transform on same place!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
auto dims = src.dims();
dst->Resize(dims);
auto dst_type = kernel_pair.second.data_type_;
auto src_type = kernel_pair.first.data_type_;
switch (src_type) {
case proto::DataType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx));
break;
case proto::DataType::FP64:
framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx));
break;
case proto::DataType::INT32:
framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
break;
case proto::DataType::INT64:
framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
break;
case proto::DataType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx));
break;
default:
PADDLE_THROW("Not support type %d", src_type);
}
}
void TransDataLayout(const std::vector<int>& axis,
const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataLayout only support DataLayout transform on same place!");
PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
"TransDataLayout only support Datatype are same!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
auto src_dim = src.dims();
std::vector<int64_t> dst_dim;
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}
dst->Resize(make_ddim(dst_dim));
auto place = kernel_pair.second.place_;
dst->mutable_data(place, src.type());
auto src_type = kernel_pair.first.data_type_;
framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
dst->set_layout(kernel_pair.second.data_layout_);
}
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
namespace {
std::vector<int> NHWC2NCHW = {0, 3, 1, 2};
std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
}
REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans);
REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans);
REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW,
std::bind(f::TransDataLayout, NHWC2NCHW,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3,
std::placeholders::_4));
REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC,
std::bind(f::TransDataLayout, NCHW2NHWC,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3,
std::placeholders::_4));
......@@ -30,26 +30,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
using DataTransformFn =
std::function<void(const platform::DeviceContext*, const KernelTypePair&,
const Variable&, Variable*)>;
struct KernelTypePairHash {
static void HashCombine(const OpKernelType& t, std::size_t* seed) {
OpKernelType::Hash kernel_type_hasher;
(*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
size_t operator()(const KernelTypePair& kernel_pair) const {
std::size_t seed = 0;
HashCombine(kernel_pair.first, &seed);
HashCombine(kernel_pair.second, &seed);
return seed;
}
};
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);
......@@ -57,125 +37,5 @@ Tensor* DataTransform(const OpKernelType& expected_kernel_type,
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
return static_cast<OutType>(in);
}
};
template <typename InType>
struct CastDataType {
CastDataType(const framework::Tensor& in, framework::Tensor* out,
const platform::DeviceContext* ctx)
: in_(in), out_(out), ctx_(ctx) {}
const framework::Tensor in_;
framework::Tensor* out_;
const platform::DeviceContext* ctx_;
template <typename OutType>
void operator()() {
auto place = ctx_->GetPlace();
auto* in_begin = in_.data<InType>();
auto numel = in_.numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutType>(place);
if (platform::is_cpu_place(place)) {
platform::Transform<platform::CPUDeviceContext> trans;
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
} else {
// TODO(dzhwinter): enhance Copy CPU<->GPU with different data type?
PADDLE_THROW("Unsupport CPU <-> GPU!");
}
}
};
struct CastDataLayout {
CastDataLayout(const platform::DeviceContext* ctx,
const std::vector<int>& axis, const framework::Tensor& in,
framework::Tensor* out)
: in_(in), out_(out), ctx_(ctx), axis_(axis) {}
const framework::Tensor in_;
framework::Tensor* out_;
const platform::DeviceContext* ctx_;
const std::vector<int> axis_;
template <typename T>
void operator()() {
auto place = ctx_->GetPlace();
if (platform::is_cpu_place(place)) {
operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans4(*context, in_, out_, axis_);
} else {
PADDLE_THROW("Unsupport CPU <-> GPU!");
}
}
};
using DataTransformMap =
std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
class DataTransformFnMap {
public:
static DataTransformFnMap& Instance();
bool Has(const KernelTypePair& key_pair) const {
return map_.find(key_pair) != map_.end();
}
void Insert(const OpKernelType& left, const OpKernelType& right,
const DataTransformFn& data_tranform_fn) {
Insert(std::make_pair(left, right), data_tranform_fn);
}
void Insert(const KernelTypePair& kernel_type_pair,
const DataTransformFn& data_tranform_fn) {
PADDLE_ENFORCE(!Has(kernel_type_pair),
"KernelTypePair %s has been registered", "");
map_.insert({kernel_type_pair, data_tranform_fn});
}
const DataTransformFn& Get(const KernelTypePair& key_pair) const {
auto data_transformer = GetNullable(key_pair);
PADDLE_ENFORCE_NOT_NULL(data_transformer,
"DataTransformFn should not be NULL");
return *data_transformer;
}
const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
auto it = map_.find(key_pair);
if (it == map_.end()) {
return nullptr;
} else {
return &(it->second);
}
}
const DataTransformMap& Map() const { return map_; }
private:
DataTransformFnMap() = default;
DataTransformMap map_;
DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
};
// generate unique name with __LINE__
// refs https://stackoverflow.com/questions/1597007
#define TOKENPASTE(x, y) x##y
#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
static int TOKENPASTE2(fn_, __LINE__)() { \
::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
return 0; \
} \
static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
TOKENPASTE2(fn_, __LINE__)()
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <array>
#include <vector>
#include <gtest/gtest.h>
#include "paddle/framework/data_transform.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
using namespace platform;
/**
* @brief cross validation of different kernel type transform
* We use four bit map represent different combination.
* If the field has multiple possible value, only choose two of them.
* For DataType, only test the FP32(float), FP64(double).
* e.g. 0000 -> FP32, CPUPlace, kNHWC, kPlain
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
*/
std::array<proto::DataType, 2> kDataType = {
{proto::DataType::FP32, proto::DataType::FP64}};
std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
std::array<DataLayout, 2> kDataLayout = {{
DataLayout::kNHWC, DataLayout::kNCHW,
}};
std::array<LibraryType, 2> kLibraryType = {{
LibraryType::kPlain, LibraryType::kMKLDNN,
}};
OpKernelType GenFromBit(const std::vector<bool> bits) {
return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
kLibraryType[bits[3]]);
}
int test_value = 0;
auto kernel0 = GenFromBit({0, 0, 0, 0});
auto kernel1 = GenFromBit({0, 0, 0, 1});
auto kernel2 = GenFromBit({0, 0, 1, 0});
auto kernel3 = GenFromBit({0, 0, 1, 1});
void TransDataType_t(const platform::DeviceContext* ctx,
const KernelTypePair& p, const Variable& in,
Variable* out) {
test_value++;
}
void TransDataLayout_t(const platform::DeviceContext* ctx,
const KernelTypePair& p, const Variable& in,
Variable* out) {
test_value--;
}
void TransLibraryType_t(const platform::DeviceContext* ctx,
const KernelTypePair& p, const Variable& in,
Variable* out) {
test_value += 2;
}
} // namespace framework
} // namespace paddle
namespace frw = paddle::framework;
REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel1, frw::TransDataType_t);
REGISTER_DATA_TRANSFORM_FN(frw::kernel1, frw::kernel2, frw::TransDataLayout_t);
REGISTER_DATA_TRANSFORM_FN(frw::kernel0, frw::kernel2, frw::TransLibraryType_t);
TEST(DataTransform, Register) {
using namespace paddle::framework;
using namespace paddle::platform;
auto& instance = DataTransformFnMap::Instance();
paddle::framework::Variable in;
paddle::framework::Variable out;
DeviceContext* ctx = new CPUDeviceContext();
auto pair0 = std::make_pair(frw::kernel0, frw::kernel1);
instance.Get(pair0)(ctx, pair0, in, &out);
ASSERT_EQ(test_value, 1);
auto pair1 = std::make_pair(frw::kernel1, frw::kernel2);
instance.Get(pair1)(ctx, pair1, in, &out);
ASSERT_EQ(test_value, 0);
auto pair3 = std::make_pair(frw::kernel0, frw::kernel2);
instance.Get(pair3)(ctx, pair3, in, &out);
ASSERT_EQ(test_value, 2);
}
TEST(DataTransform, DataLayout) {
using namespace paddle::framework;
using namespace paddle::platform;
auto& instance = DataTransformFnMap::Instance();
Variable in;
Variable out;
Tensor* src = in.GetMutable<Tensor>();
src->mutable_data<double>(make_ddim({2, 3, 1, 2}), CPUPlace());
src->set_layout(DataLayout::kNHWC);
DeviceContext* ctx = new CPUDeviceContext();
{
auto kernel1 = GenFromBit({1, 0, 0, 0});
auto kernel2 = GenFromBit({1, 0, 1, 0});
auto pair0 = std::make_pair(kernel1, kernel2);
instance.Get(pair0)(ctx, pair0, in, &out);
}
Tensor dst = out.Get<Tensor>();
EXPECT_TRUE(dst.layout() == DataLayout::kNCHW);
EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1}));
{
auto kernel1 = GenFromBit({1, 0, 1, 0});
auto kernel2 = GenFromBit({1, 0, 0, 0});
auto pair0 = std::make_pair(kernel1, kernel2);
instance.Get(pair0)(ctx, pair0, out, &in);
}
EXPECT_TRUE(src->layout() == DataLayout::kNHWC);
EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2}));
}
TEST(DataTransform, DataType) {
using namespace paddle::framework;
using namespace paddle::platform;
auto& instance = DataTransformFnMap::Instance();
DeviceContext* ctx = new CPUDeviceContext();
Variable in;
Variable out;
Tensor* src = in.GetMutable<Tensor>();
float* ptr = src->mutable_data<float>(make_ddim({2, 3}), CPUPlace());
for (int i = 0; i < 6; ++i) {
ptr[i] = i / 3;
}
{
auto kernel1 = GenFromBit({0, 0, 0, 0});
auto kernel2 = GenFromBit({1, 0, 0, 0});
auto pair0 = std::make_pair(kernel1, kernel2);
instance.Get(pair0)(ctx, pair0, in, &out);
}
Tensor dst = out.Get<Tensor>();
EXPECT_TRUE(dst.data<double>() != nullptr);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_type_transform.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/transform.h"
namespace paddle {
namespace framework {
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
return static_cast<OutType>(in);
}
};
template <typename InType>
struct CastDataType {
CastDataType(const framework::Tensor& in, framework::Tensor* out,
const platform::DeviceContext* ctx)
: in_(in), out_(out), ctx_(ctx) {}
const framework::Tensor in_;
framework::Tensor* out_;
const platform::DeviceContext* ctx_;
template <typename OutType>
void operator()() {
auto place = ctx_->GetPlace();
auto* in_begin = in_.data<InType>();
auto numel = in_.numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutType>(place);
if (platform::is_cpu_place(place)) {
platform::Transform<platform::CPUDeviceContext> trans;
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
} else {
// TODO(dzhwinter): enhance Copy CPU<->GPU with different data type?
PADDLE_THROW("Unsupport CPU <-> GPU!");
}
}
};
void TransDataType(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataType Only Support DataType transform on same place!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
auto dims = src.dims();
dst->Resize(dims);
auto dst_type = kernel_pair.second.data_type_;
auto src_type = kernel_pair.first.data_type_;
switch (src_type) {
case proto::DataType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(src, dst, ctx));
break;
case proto::DataType::FP64:
framework::VisitDataType(dst_type, CastDataType<double>(src, dst, ctx));
break;
case proto::DataType::INT32:
framework::VisitDataType(dst_type, CastDataType<int>(src, dst, ctx));
break;
case proto::DataType::INT64:
framework::VisitDataType(dst_type, CastDataType<int64_t>(src, dst, ctx));
break;
case proto::DataType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(src, dst, ctx));
break;
default:
PADDLE_THROW("Not support type %d", src_type);
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/variable.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
void TransDataType(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out);
} // namespace framework
} // namespace paddle
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <algorithm>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/shape_inference.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册