未验证 提交 525a4fda 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #8270 from Yancey1989/improve_concat_split_op

Improve split and concat op
...@@ -314,5 +314,15 @@ DDim stride(const DDim& ddim) { ...@@ -314,5 +314,15 @@ DDim stride(const DDim& ddim) {
} }
return framework::make_ddim(strides); return framework::make_ddim(strides);
} }
// Returns, for each axis i, the total number of elements spanned by
// dimensions i..rank-1, i.e. the running suffix product of the dims.
// For example, dims [4, 20, 100] -> [8000, 2000, 100].
DDim stride_numel(const framework::DDim& ddim) {
  std::vector<int64_t> strides(ddim.size());
  if (ddim.size() == 0) {
    // Rank-0 shape: return an empty DDim instead of writing
    // strides[ddim.size() - 1], which would index out of bounds.
    return framework::make_ddim(strides);
  }
  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return framework::make_ddim(strides);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -125,6 +125,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims); ...@@ -125,6 +125,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
DDim flatten_to_1d(const DDim& src); DDim flatten_to_1d(const DDim& src);
DDim stride(const DDim& ddim); DDim stride(const DDim& ddim);
DDim stride_numel(const DDim& ddim);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = ins.size(); auto place = ctx.GetPlace();
out->mutable_data<T>(place);
auto out_stride = framework::stride_numel(out->dims());
size_t output_offset = 0; size_t output_offset = 0;
out->mutable_data<T>(ctx.GetPlace()); for (auto* in : ins) {
auto out_stride = framework::stride(out->dims()); auto in_stride = framework::stride_numel(in->dims());
for (size_t i = 0; i < n; i++) { StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
auto& in = ins[i]; out->data<T>() + output_offset, out_stride,
auto axis_dim = in->dims()[axis]; in->data<T>(), in_stride);
auto in_stride = framework::stride(in->dims()); output_offset += in_stride[axis];
StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
in->dims(), out_stride, out->data<T>() + output_offset);
output_offset += axis_dim * in_stride[axis];
} }
} }
}; };
...@@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size();
size_t input_offset = 0; size_t input_offset = 0;
auto in_stride = framework::stride(in->dims()); auto in_stride = framework::stride_numel(in->dims());
for (size_t i = 0; i < n; i++) {
auto& out = outs[i]; for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis]; auto out_stride = framework::stride_numel(out->dims());
auto out_stride = framework::stride(out->dims()); StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset, out_stride, in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>()); in_stride);
input_offset += axis_dim * in_stride[axis]; input_offset += out_stride[axis];
} }
} }
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <chrono>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
...@@ -27,18 +28,18 @@ class SplitOpKernel : public framework::OpKernel<T> { ...@@ -27,18 +28,18 @@ class SplitOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out"); auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto in_stride = framework::stride(in->dims()); auto in_stride = framework::stride_numel(in->dims());
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size(); auto place = ctx.GetPlace();
size_t input_offset = 0; size_t input_offset = 0;
for (size_t i = 0; i < n; i++) { for (auto& out : outs) {
auto& out = outs[i];
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis]; auto out_stride = framework::stride_numel(out->dims());
auto out_stride = framework::stride(out->dims()); StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset, out_stride, in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>()); in_stride);
input_offset += axis_dim * in_stride[axis]; input_offset += out_stride[axis];
} }
} }
}; };
......
...@@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src, ...@@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst); StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
boost::apply_visitor(func, dst_dim); boost::apply_visitor(func, dst_dim);
} }
// Strided numel memory copy from src to dst by the specified axis.
//
// For example, for a tensor with dims [4, 20, 100], the strided numel is
// [8000, 2000, 100] (see framework::stride_numel).
//
// NOTE: The src and dst tensor should have the same elements
// except along the specified axis.
template <typename T>
inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
                                     int64_t axis, T* dst,
                                     const framework::DDim& dst_stride_numel,
                                     const T* src,
                                     const framework::DDim& src_stride_numel) {
  // Validate ranks before indexing either DDim below.
  PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
                    "src and dst tensor should have the same dims size.");
  // Number of contiguous runs to copy: the product of all dims before
  // `axis` (total elements divided by elements from `axis` onward).
  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
  int64_t src_after = src_stride_numel[axis];
  int64_t dst_after = dst_stride_numel[axis];
  auto place = ctx.GetPlace();

  // Check that src and dst agree on every dimension other than `axis`.
  // NOTE(review): the previous loop bound was `i < axis`, which made the
  // `i == axis` and `i > axis` branches unreachable and left the trailing
  // dimensions unchecked; iterate over the full rank instead.
  for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
    if (i < axis) {
      // stride_numel[i] / stride_numel[axis] is the product of dims
      // i..axis-1, so this compares the leading shape slice.
      PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis],
                        dst_stride_numel[i] / dst_stride_numel[axis],
                        "src and dst should have the same elements "
                        "except the specified axis.");
    } else if (i == axis) {
      continue;
    } else {
      // After the axis, the suffix element counts must match exactly.
      PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i],
                        "src and dst should have the same elements "
                        "except the specified axis.");
    }
  }

  // Copy `before` runs of `src_after` elements each; consecutive runs are
  // `src_after` apart in src and `dst_after` apart in dst.
  for (int64_t i = 0; i < before; ++i) {
    if (platform::is_cpu_place(place)) {
      auto& cpu_place = boost::get<platform::CPUPlace>(place);
      memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
                   src + i * src_after, sizeof(T) * src_after);
    } else {
#ifdef PADDLE_WITH_CUDA
      // Asynchronous device-to-device copy on the context's stream.
      auto& gpu_place = boost::get<platform::CUDAPlace>(place);
      auto& cuda_ctx =
          reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
      memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
                   src + i * src_after, sizeof(T) * src_after,
                   cuda_ctx.stream());
#else
      PADDLE_THROW("Paddle is not compiled with GPU");
#endif
    }
  }
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -20,11 +20,11 @@ from op_test import OpTest ...@@ -20,11 +20,11 @@ from op_test import OpTest
class TestSplitOp(OpTest): class TestSplitOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "split" self.op_type = "split"
axis = 0 axis = 1
x = np.random.random((4, 2, 5)).astype('float32') x = np.random.random((4, 5, 6)).astype('float32')
out = np.split(x, [1, 3], axis) out = np.split(x, [2, 3], axis)
self.inputs = {'X': x} self.inputs = {'X': x}
self.attrs = {'axis': axis, 'sections': [1, 2, 1]} self.attrs = {'axis': axis, 'sections': [2, 1, 2]}
self.outputs = {'Out': [('out%d' % i, out[i]) \ self.outputs = {'Out': [('out%d' % i, out[i]) \
for i in xrange(len(out))]} for i in xrange(len(out))]}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册