diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index f063ee2e6dd81e66b7b74aa23e9967d865c4d297..98f1fa9d209eb23627949a4ccb235566652076a5 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -314,5 +314,15 @@ DDim stride(const DDim& ddim) { } return framework::make_ddim(strides); } + +DDim stride_numel(const framework::DDim& ddim) { + std::vector strides(ddim.size()); + strides[ddim.size() - 1] = ddim[ddim.size() - 1]; + for (int i = ddim.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * ddim[i]; + } + return framework::make_ddim(strides); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 750ab787abb72fa3f2984caf58354e327750aa3d..405d9af7028d74e52ed015a269770c9b4ae23a5f 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -125,6 +125,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src); DDim stride(const DDim& ddim); + +DDim stride_numel(const DDim& ddim); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 72b3e225bf64f889804eb5e4fab9df4653f5452b..878e53058567500aeb9fe854a1a65ed5380572a8 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto* out = ctx.Output("Out"); int64_t axis = static_cast(ctx.Attr("axis")); - const size_t n = ins.size(); + auto place = ctx.GetPlace(); + out->mutable_data(place); + + auto out_stride = framework::stride_numel(out->dims()); + size_t output_offset = 0; - out->mutable_data(ctx.GetPlace()); - auto out_stride = framework::stride(out->dims()); - for (size_t i = 0; i < n; i++) { - auto& in = ins[i]; - auto axis_dim = in->dims()[axis]; - auto in_stride = framework::stride(in->dims()); - StridedMemcpy(ctx.device_context(), in->data(), in_stride, - in->dims(), out_stride, out->data() + output_offset); - output_offset += axis_dim * in_stride[axis]; + for (auto* in : ins) { + auto in_stride = framework::stride_numel(in->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, + out->data() + output_offset, out_stride, + in->data(), in_stride); + output_offset += in_stride[axis]; } } }; @@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel { auto* in = ctx.Input(framework::GradVarName("Out")); auto outs = ctx.MultiOutput(framework::GradVarName("X")); int64_t axis = static_cast(ctx.Attr("axis")); - const size_t n = outs.size(); size_t input_offset = 0; - auto in_stride = framework::stride(in->dims()); - for (size_t i = 0; i < n; i++) { - auto& out = outs[i]; + auto in_stride = framework::stride_numel(in->dims()); + + for (auto& out : outs) { out->mutable_data(ctx.GetPlace()); - size_t axis_dim = out->dims()[axis]; - auto out_stride = framework::stride(out->dims()); - StridedMemcpy(ctx.device_context(), in->data() + input_offset, - in_stride, out->dims(), out_stride, out->data()); - input_offset += axis_dim * in_stride[axis]; + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride); + input_offset += out_stride[axis]; } } }; diff --git a/paddle/fluid/operators/split_op.h b/paddle/fluid/operators/split_op.h index e78218f2fb108dac5bce717e03ce0aba1ed88195..06bcf82620bec57346c30b029d23ad8417252248 100644 --- a/paddle/fluid/operators/split_op.h +++ b/paddle/fluid/operators/split_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/strided_memcpy.h" @@ -27,18 +28,18 @@ class SplitOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto outs = ctx.MultiOutput("Out"); - auto in_stride = framework::stride(in->dims()); + auto in_stride = framework::stride_numel(in->dims()); int64_t axis = static_cast(ctx.Attr("axis")); - const size_t n = outs.size(); + auto place = ctx.GetPlace(); + size_t input_offset = 0; - for (size_t i = 0; i < n; i++) { - auto& out = outs[i]; + for (auto& out : outs) { out->mutable_data(ctx.GetPlace()); - size_t axis_dim = out->dims()[axis]; - auto out_stride = framework::stride(out->dims()); - StridedMemcpy(ctx.device_context(), in->data() + input_offset, - in_stride, out->dims(), out_stride, out->data()); - input_offset += axis_dim * in_stride[axis]; + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), + out_stride, in->data() + input_offset, + in_stride); + input_offset += out_stride[axis]; } } }; diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 8a99b405e266da48427fa23e9a3e67f2bc54c5a0..385124305e2d9afd62313ca46178b4916cd6405d 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src, StridedCopyDimVisitor func(dev_ctx, src, src_stride, dst_stride, dst); boost::apply_visitor(func, dst_dim); } + +// Strided numel memory copy from src to dst by the specified axis +// +// For example, for a tensor dims [4, 20, 100], the strieded numel is +// [8000, 2000, 100] +// +// NOTE: The src and dst tensor should have the same elements +// except the specified axis. +template +inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, + int64_t axis, T* dst, + const framework::DDim& dst_stride_numel, + const T* src, + const framework::DDim& src_stride_numel) { + int64_t before = dst_stride_numel[0] / dst_stride_numel[axis]; + int64_t src_after = src_stride_numel[axis]; + int64_t dst_after = dst_stride_numel[axis]; + auto place = ctx.GetPlace(); + + PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(), + "src and dst tensor should have the same dims size."); + + for (int64_t i = 0; i < axis; ++i) { + if (i < axis) { + PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis], + dst_stride_numel[i] / dst_stride_numel[axis], + "src and dst should have the same elements " + "except the specified axis."); + } else if (i == axis) { + continue; + } else { + PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i], + "src and dst should have the same elements " + "except the specified axis."); + } + } + + for (int64_t i = 0; i < before; ++i) { + if (platform::is_cpu_place(place)) { + auto& cpu_place = boost::get(place); + memory::Copy(cpu_place, dst + i * dst_after, cpu_place, + src + i * src_after, sizeof(T) * src_after); + } else { +#ifdef PADDLE_WITH_CUDA + auto& gpu_place = boost::get(place); + auto& cuda_ctx = + reinterpret_cast(ctx); + memory::Copy(gpu_place, dst + i * dst_after, gpu_place, + src + i * src_after, sizeof(T) * src_after, + cuda_ctx.stream()); +#else + PADDLE_THROW("Paddle is not compiled with GPU"); +#endif + } + } +} + } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_split_op.py b/python/paddle/v2/fluid/tests/test_split_op.py index b80b64c41be4a45e9d3725297a526e93b399f00d..50347d2df48e3f163054d627fd4624c330690c67 100644 --- a/python/paddle/v2/fluid/tests/test_split_op.py +++ b/python/paddle/v2/fluid/tests/test_split_op.py @@ -20,11 +20,11 @@ from op_test import OpTest class TestSplitOp(OpTest): def setUp(self): self.op_type = "split" - axis = 0 - x = np.random.random((4, 2, 5)).astype('float32') - out = np.split(x, [1, 3], axis) + axis = 1 + x = np.random.random((4, 5, 6)).astype('float32') + out = np.split(x, [2, 3], axis) self.inputs = {'X': x} - self.attrs = {'axis': axis, 'sections': [1, 2, 1]} + self.attrs = {'axis': axis, 'sections': [2, 1, 2]} self.outputs = {'Out': [('out%d' % i, out[i]) \ for i in xrange(len(out))]}