From fba6a10dd99edf6110280754555af78889f19dd3 Mon Sep 17 00:00:00 2001 From: QI JUN Date: Tue, 2 Jan 2018 21:00:09 +0800 Subject: [PATCH] fix bug in TransDataLayout (#7137) --- paddle/framework/data_transform.cc | 11 ++++++++++- paddle/framework/data_transform_test.cc | 14 +++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/paddle/framework/data_transform.cc b/paddle/framework/data_transform.cc index 58780e3863..9d6a842442 100644 --- a/paddle/framework/data_transform.cc +++ b/paddle/framework/data_transform.cc @@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx, auto* dst = out->GetMutable(); PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); - dst->Resize(src.dims()); + auto src_dim = src.dims(); + dst->Resize(src_dim); auto place = kernel_pair.second.place_; CopyFrom(src, place, *ctx, dst); const std::vector axis = {0, 2, 3, 1}; + std::vector dst_dim; + dst_dim.resize(axis.size()); + for (size_t i = 0; i < axis.size(); i++) { + dst_dim[i] = src_dim[axis[i]]; + } + + dst->Resize(make_ddim(dst_dim)); + auto src_type = kernel_pair.first.data_type_; framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis)); diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 5b01c8434b..8665b6248f 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -32,18 +32,18 @@ using namespace platform; * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN */ -std::array kDataType = {proto::DataType::FP32, - proto::DataType::FP64}; +std::array kDataType = { + {proto::DataType::FP32, proto::DataType::FP64}}; -std::array kPlace = {CPUPlace(), CUDAPlace(0)}; +std::array kPlace = {{CPUPlace(), CUDAPlace(0)}}; -std::array kDataLayout = { +std::array kDataLayout = {{ DataLayout::kNHWC, DataLayout::kNCHW, -}; +}}; -std::array kLibraryType = { +std::array kLibraryType = {{ LibraryType::kPlain, LibraryType::kMKLDNN, -}; +}}; OpKernelType GenFromBit(const std::vector bits) { return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]], -- GitLab