提交 3498434b 编写于 作者: Y Yancey1989

fix ci: pass platform::DeviceContext instead of framework::ExecutionContext to StridedNumelCopyWithAxis (updating the concat/concat-grad/split kernel call sites), and fix the CUDA branch of StridedNumelCopyWithAxis to use the GPU place rather than the CPU place in memory::Copy

上级 31f598fc
......@@ -37,8 +37,9 @@ class ConcatKernel : public framework::OpKernel<T> {
size_t output_offset = 0;
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>() + output_offset,
out_stride, in->data<T>(), in_stride);
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride);
output_offset += in_stride[axis];
}
}
......@@ -57,8 +58,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
in->data<T>() + input_offset, in_stride);
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride);
input_offset += out_stride[axis];
}
}
......
......@@ -37,8 +37,9 @@ class SplitOpKernel : public framework::OpKernel<T> {
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
in->data<T>() + input_offset, in_stride);
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride);
input_offset += out_stride[axis];
}
}
......
......@@ -50,7 +50,7 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
// NOTE: The src and dst tensor should have the same elements
// except the specified axis.
template <typename T>
inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx,
inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
int64_t axis, T* dst,
const framework::DDim& dst_stride_numel,
const T* src,
......@@ -88,7 +88,7 @@ inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx,
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
src + i * src_after, sizeof(T) * src_after,
cuda_ctx.stream());
#else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册