From 3498434bccaf65fdb8cf59ccc07d2a6900d38188 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Sat, 10 Feb 2018 10:17:40 +0800 Subject: [PATCH] fix ci --- paddle/operators/concat_op.h | 10 ++++++---- paddle/operators/split_op.h | 5 +++-- paddle/operators/strided_memcpy.h | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h index 2ee9912a3b..ae10b81ea7 100644 --- a/paddle/operators/concat_op.h +++ b/paddle/operators/concat_op.h @@ -37,8 +37,9 @@ class ConcatKernel : public framework::OpKernel { size_t output_offset = 0; for (auto* in : ins) { auto in_stride = framework::stride_numel(in->dims()); - StridedNumelCopyWithAxis(ctx, axis, out->data() + output_offset, - out_stride, in->data(), in_stride); + StridedNumelCopyWithAxis(ctx.device_context(), axis, + out->data() + output_offset, out_stride, + in->data(), in_stride); output_offset += in_stride[axis]; } } @@ -57,8 +58,9 @@ class ConcatGradKernel : public framework::OpKernel { for (auto& out : outs) { out->mutable_data(ctx.GetPlace()); auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis(ctx, axis, out->data(), out_stride, - in->data() + input_offset, in_stride); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride); input_offset += out_stride[axis]; } } diff --git a/paddle/operators/split_op.h b/paddle/operators/split_op.h index e239c9cf30..b956808ef9 100644 --- a/paddle/operators/split_op.h +++ b/paddle/operators/split_op.h @@ -37,8 +37,9 @@ class SplitOpKernel : public framework::OpKernel { for (auto& out : outs) { out->mutable_data(ctx.GetPlace()); auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis(ctx, axis, out->data(), out_stride, - in->data() + input_offset, in_stride); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride); input_offset += out_stride[axis]; } } diff --git a/paddle/operators/strided_memcpy.h b/paddle/operators/strided_memcpy.h index 49795db91d..ddecfd76dd 100644 --- a/paddle/operators/strided_memcpy.h +++ b/paddle/operators/strided_memcpy.h @@ -50,7 +50,7 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src, // NOTE: The src and dst tensor should have the same elements // except the specified axis. template -inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx, +inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, int64_t axis, T* dst, const framework::DDim& dst_stride_numel, const T* src, @@ -88,7 +88,7 @@ inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx, auto& gpu_place = boost::get(place); auto& cuda_ctx = reinterpret_cast(ctx); - memory::Copy(cpu_place, dst + i * dst_after, cpu_place, + memory::Copy(gpu_place, dst + i * dst_after, gpu_place, src + i * src_after, sizeof(T) * src_after, cuda_ctx.stream()); #else -- GitLab