未验证 提交 fba6a10d 编写于 作者: Q QI JUN 提交者: GitHub

fix bug in TransDataLayout (#7137)

上级 06888bb0
...@@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx, ...@@ -87,11 +87,20 @@ void TransDataLayout(const platform::DeviceContext* ctx,
auto* dst = out->GetMutable<Tensor>(); auto* dst = out->GetMutable<Tensor>();
PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!"); PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
dst->Resize(src.dims()); auto src_dim = src.dims();
dst->Resize(src_dim);
auto place = kernel_pair.second.place_; auto place = kernel_pair.second.place_;
CopyFrom(src, place, *ctx, dst); CopyFrom(src, place, *ctx, dst);
const std::vector<int> axis = {0, 2, 3, 1}; const std::vector<int> axis = {0, 2, 3, 1};
std::vector<int64_t> dst_dim;
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}
dst->Resize(make_ddim(dst_dim));
auto src_type = kernel_pair.first.data_type_; auto src_type = kernel_pair.first.data_type_;
framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis)); framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis));
......
...@@ -32,18 +32,18 @@ using namespace platform; ...@@ -32,18 +32,18 @@ using namespace platform;
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN * 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
*/ */
std::array<proto::DataType, 2> kDataType = {proto::DataType::FP32, std::array<proto::DataType, 2> kDataType = {
proto::DataType::FP64}; {proto::DataType::FP32, proto::DataType::FP64}};
std::array<Place, 2> kPlace = {CPUPlace(), CUDAPlace(0)}; std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
std::array<DataLayout, 2> kDataLayout = { std::array<DataLayout, 2> kDataLayout = {{
DataLayout::kNHWC, DataLayout::kNCHW, DataLayout::kNHWC, DataLayout::kNCHW,
}; }};
std::array<LibraryType, 2> kLibraryType = { std::array<LibraryType, 2> kLibraryType = {{
LibraryType::kPlain, LibraryType::kMKLDNN, LibraryType::kPlain, LibraryType::kMKLDNN,
}; }};
OpKernelType GenFromBit(const std::vector<bool> bits) { OpKernelType GenFromBit(const std::vector<bool> bits) {
return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]], return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册