From 9a05c9075043345e34b4461ded2ce92ba6501ae4 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 12 Feb 2018 10:38:31 +0800 Subject: [PATCH] fix StridedNumelCopyWithAxis --- paddle/fluid/operators/strided_memcpy.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 385124305e..4036d1091d 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -58,6 +58,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; int64_t src_after = src_stride_numel[axis]; int64_t dst_after = dst_stride_numel[axis]; + int64_t copy_size = std::min(src_after, dst_after); auto place = ctx.GetPlace(); PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(), @@ -82,14 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, if (platform::is_cpu_place(place)) { auto& cpu_place = boost::get(place); memory::Copy(cpu_place, dst + i * dst_after, cpu_place, - src + i * src_after, sizeof(T) * src_after); + src + i * src_after, sizeof(T) * copy_size); } else { #ifdef PADDLE_WITH_CUDA auto& gpu_place = boost::get(place); auto& cuda_ctx = reinterpret_cast(ctx); memory::Copy(gpu_place, dst + i * dst_after, gpu_place, - src + i * src_after, sizeof(T) * src_after, + src + i * src_after, sizeof(T) * copy_size, cuda_ctx.stream()); #else PADDLE_THROW("Paddle is not compiled with GPU"); -- GitLab