Commit e0e54534 authored by Yancey1989

refine the code

Parent c976fac1
--- a/paddle/operators/concat_op.h
+++ b/paddle/operators/concat_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"
 namespace paddle {
 namespace operators {
@@ -32,34 +33,13 @@ class ConcatKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(place);
     auto out_stride = framework::stride_numel(out->dims());
-    int64_t before = out_stride[0] / out_stride[axis];
-    int64_t out_after = out_stride[axis];
     size_t output_offset = 0;
     for (auto* in : ins) {
       auto in_stride = framework::stride_numel(in->dims());
-      int64_t in_after = in_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(
-              cpu_place, out->data<T>() + output_offset + i * out_after,
-              cpu_place, in->data<T>() + i * in_after, sizeof(T) * in_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place,
-                       out->data<T>() + output_offset + i * out_after,
-                       gpu_place, in->data<T>() + i * in_after,
-                       sizeof(T) * in_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      output_offset += in_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>() + output_offset,
+                                  out_stride, in->data<T>(), in_stride);
+      output_offset += in_stride[axis];
     }
   }
 };
@@ -73,35 +53,13 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     size_t input_offset = 0;
     auto in_stride = framework::stride_numel(in->dims());
-    auto place = ctx.GetPlace();
-    // numel before the specified axis
-    int64_t before = in_stride[0] / in_stride[axis];
-    int64_t in_after = in_stride[axis];
     for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
       auto out_stride = framework::stride_numel(out->dims());
-      int64_t out_after = out_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(cpu_place, out->data<T>() + i * out_after, cpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place, out->data<T>() + i * out_after, gpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      input_offset += out_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
+                                  in->data<T>() + input_offset, in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };
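With this change, both concat kernels reduce to one `StridedNumelCopyWithAxis` call per input plus an offset update. A minimal standalone sketch (not part of this patch) of the copy pattern the helper performs for ConcatKernel, using hypothetical shapes [2, 2, 3] and [2, 3, 3] concatenated along axis 1 with plain CPU `memcpy`:

```cpp
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int64_t before = 2;          // product of dims before axis 1
  const int64_t dst_after = 5 * 3;   // output stride-numel at axis 1

  std::vector<float> src1(2 * 2 * 3, 1.f);  // shape [2, 2, 3]
  std::vector<float> src2(2 * 3 * 3, 2.f);  // shape [2, 3, 3]
  std::vector<float> dst(2 * 5 * 3, 0.f);   // shape [2, 5, 3]

  // Copy src1 into dst at offset 0, then src2 at the next offset,
  // mirroring `output_offset += in_stride[axis]` in the kernel.
  int64_t output_offset = 0;
  for (const auto* src : {&src1, &src2}) {
    const int64_t src_after = static_cast<int64_t>(src->size()) / before;
    for (int64_t i = 0; i < before; ++i) {
      std::memcpy(dst.data() + output_offset + i * dst_after,
                  src->data() + i * src_after, sizeof(float) * src_after);
    }
    output_offset += src_after;
  }
  std::cout << dst[0] << " " << dst[6] << "\n";  // prints: 1 2
}
```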
--- a/paddle/operators/split_op.h
+++ b/paddle/operators/split_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/strided_memcpy.h"
 namespace paddle {
 namespace operators {
@@ -26,41 +27,19 @@ template <typename DeviceContext, typename T>
 class SplitOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    // auto start = std::chrono::steady_clock::now();
     auto* in = ctx.Input<framework::Tensor>("X");
     auto outs = ctx.MultiOutput<framework::Tensor>("Out");
     auto in_stride = framework::stride_numel(in->dims());
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     auto place = ctx.GetPlace();
-    // numel before the specified axis
-    int64_t before = in_stride[0] / in_stride[axis];
-    int64_t in_after = in_stride[axis];
     size_t input_offset = 0;
     for (auto& out : outs) {
       out->mutable_data<T>(ctx.GetPlace());
       auto out_stride = framework::stride_numel(out->dims());
-      int64_t out_after = out_stride[axis];
-      for (int64_t i = 0; i < before; ++i) {
-        if (platform::is_cpu_place(place)) {
-          auto& cpu_place = boost::get<platform::CPUPlace>(place);
-          memory::Copy(cpu_place, out->data<T>() + i * out_after, cpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after);
-        } else {
-#ifdef PADDLE_WITH_CUDA
-          auto& gpu_place = boost::get<platform::CUDAPlace>(place);
-          auto& cuda_ctx =
-              reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-          memory::Copy(gpu_place, out->data<T>() + i * out_after, gpu_place,
-                       in->data<T>() + input_offset + i * in_after,
-                       sizeof(T) * out_after, cuda_ctx.stream());
-#else
-          PADDLE_THROW("Paddle is not compiled with GPU");
-#endif
-        }
-      }
-      input_offset += out_after;
+      StridedNumelCopyWithAxis<T>(ctx, axis, out->data<T>(), out_stride,
+                                  in->data<T>() + input_offset, in_stride);
+      input_offset += out_stride[axis];
     }
   }
 };
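SplitOpKernel is the mirror image of ConcatKernel: it reads the input at a growing `input_offset` and writes each output from offset zero. A standalone sketch under the same assumptions (hypothetical shapes, CPU-only, not part of this patch), splitting [2, 5, 3] back into [2, 2, 3] and [2, 3, 3] along axis 1:

```cpp
#include <cstring>
#include <numeric>
#include <vector>

int main() {
  const int64_t before = 2;         // product of dims before axis 1
  const int64_t in_after = 5 * 3;   // input stride-numel at axis 1
  std::vector<float> in(2 * 5 * 3); // shape [2, 5, 3]
  std::iota(in.begin(), in.end(), 0.f);

  std::vector<float> out1(2 * 2 * 3), out2(2 * 3 * 3);
  int64_t input_offset = 0;
  for (auto* out : {&out1, &out2}) {
    const int64_t out_after = static_cast<int64_t>(out->size()) / before;
    for (int64_t i = 0; i < before; ++i) {
      std::memcpy(out->data() + i * out_after,
                  in.data() + input_offset + i * in_after,
                  sizeof(float) * out_after);
    }
    input_offset += out_after;  // mirrors `input_offset += out_stride[axis]`
  }
  return 0;
}
```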
--- a/paddle/operators/strided_memcpy.h
+++ b/paddle/operators/strided_memcpy.h
@@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
   StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
   boost::apply_visitor(func, dst_dim);
 }
+
+// Strided numel memory copy from src to dst along the specified axis.
+//
+// For example, for a tensor with dims [4, 20, 100], the strided numel is
+// [8000, 2000, 100].
+//
+// NOTE: src and dst must have the same number of elements in every
+// dimension except along the specified axis.
+template <typename T>
+inline void StridedNumelCopyWithAxis(const framework::ExecutionContext& ctx,
+                                     int64_t axis, T* dst,
+                                     const framework::DDim& dst_stride_numel,
+                                     const T* src,
+                                     const framework::DDim& src_stride_numel) {
+  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
+  int64_t src_after = src_stride_numel[axis];
+  int64_t dst_after = dst_stride_numel[axis];
+  auto place = ctx.GetPlace();
+
+  PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
+                    "src and dst tensor should have the same dims size.");
+
+  // Walk every dimension, skipping `axis`, and require src and dst to agree.
+  for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
+    if (i < axis) {
+      PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis],
+                        dst_stride_numel[i] / dst_stride_numel[axis],
+                        "src and dst should have the same elements "
+                        "except the specified axis.");
+    } else if (i == axis) {
+      continue;
+    } else {
+      PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i],
+                        "src and dst should have the same elements "
+                        "except the specified axis.");
+    }
+  }
+
+  // Copy `before` contiguous blocks; on GPU, use the stream of the
+  // kernel's device context rather than casting the ExecutionContext.
+  for (int64_t i = 0; i < before; ++i) {
+    if (platform::is_cpu_place(place)) {
+      auto& cpu_place = boost::get<platform::CPUPlace>(place);
+      memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
+                   src + i * src_after, sizeof(T) * src_after);
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      auto& gpu_place = boost::get<platform::CUDAPlace>(place);
+      auto& cuda_ctx = reinterpret_cast<const platform::CUDADeviceContext&>(
+          ctx.device_context());
+      memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
+                   src + i * src_after, sizeof(T) * src_after,
+                   cuda_ctx.stream());
+#else
+      PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+    }
+  }
+}
 } // namespace operators
 } // namespace paddle
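For reference, the "strided numel" described in the comment above is the vector of suffix products of the dims, and `before` counts the contiguous blocks to copy. A small sketch of that arithmetic using a hypothetical `StrideNumel` helper (not Paddle's `framework::stride_numel`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Suffix products: out[i] = dims[i] * dims[i+1] * ... * dims[n-1].
std::vector<int64_t> StrideNumel(const std::vector<int64_t>& dims) {
  std::vector<int64_t> out(dims.size());
  int64_t acc = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    acc *= dims[i];
    out[i] = acc;  // elements covered from dim i to the end
  }
  return out;
}

int main() {
  auto sn = StrideNumel({4, 20, 100});
  for (auto v : sn) std::cout << v << " ";          // 8000 2000 100
  std::cout << "\nbefore(axis=1) = " << sn[0] / sn[1] << "\n";  // 4
}
```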