Unverified · Commit a7497653 · Authored by chengduo · Committed by GitHub

Refine Split op (#13967)

* speedup split_op
test=develop

* speedup split_op
test=develop

* rename ConcatGrad to Split

* refine concat and split
test=develop

* fix compile error
Parent 96e9b658
@@ -284,10 +284,10 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
 op_library(sequence_conv_op DEPS context_project)
 op_library(sequence_pool_op DEPS sequence_pooling)
 if (NOT WIN32)
   op_library(lstm_op DEPS sequence2batch lstm_compute)
   op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
   op_library(lstmp_op DEPS sequence2batch lstm_compute)
   op_library(gru_op DEPS sequence2batch gru_compute)
 endif(NOT WIN32)
 op_library(recurrent_op DEPS executor)
 op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
@@ -316,7 +316,7 @@ op_library(save_op DEPS lod_tensor)
 op_library(load_op DEPS lod_tensor)
 op_library(save_combine_op DEPS lod_tensor)
 op_library(load_combine_op DEPS lod_tensor)
-op_library(concat_op DEPS concat)
+op_library(concat_op DEPS concat_and_split)
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
@@ -348,6 +348,6 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
 cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
 cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
 if(NOT WIN32)
   nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
 endif()
 nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <paddle/fluid/operators/math/concat.h>
+#include <paddle/fluid/operators/math/concat_and_split.h>
 #include <numeric>
 #include "paddle/fluid/framework/lod_rank_table.h"
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/strided_memcpy.h"

 namespace paddle {
@@ -89,29 +89,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
         outputs.push_back(nullptr);
       }
     }
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     // Sometimes direct copies will be faster, this maybe need deeply analysis.
     if (axis == 0 && outs.size() < 10) {
-      size_t input_offset = 0;
-      const auto in_stride = framework::stride_numel(out_grad->dims());
-
-      for (size_t i = 0; i < outs.size(); ++i) {
-        auto out_stride = framework::stride_numel(ins[i]->dims());
-        auto* out = outputs[i];
-        if (out != nullptr) {
-          StridedNumelCopyWithAxis<T>(
-              ctx.device_context(), axis, out->data<T>(), out_stride,
-              out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
-        }
-        input_offset += out_stride[axis];
-      }
+      std::vector<const framework::Tensor*> ref_shape;
+      ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end());
+      StridedMemcpyWithAxis0<T>(dev_ctx, *out_grad, ref_shape, &outputs);
     } else {
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
-          concat_grad_functor;
-      concat_grad_functor(dev_ctx, *out_grad,
-                          ctx.MultiInput<framework::Tensor>("X"),
-                          static_cast<int>(axis), &outputs);
+      math::SplitFunctor<DeviceContext, T> split_functor;
+      split_functor(dev_ctx, *out_grad, ctx.MultiInput<framework::Tensor>("X"),
+                    static_cast<int>(axis), &outputs);
     }
   }
 };
...
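A note on the fast path kept above: when axis == 0 and there are fewer than 10 outputs, every output slice is one contiguous block of the (row-major) gradient tensor, so a few direct copies tend to beat dispatching a functor. Below is a minimal standalone sketch of that idea, assuming row-major float data; the function name and signature are illustrative only, not the Paddle API.

#include <cstring>
#include <vector>

// Minimal sketch, assuming row-major storage and axis == 0: each output
// then occupies one contiguous run of the input, so splitting reduces to
// sequential memcpy calls. It also mirrors the nullptr-skipping above,
// where an output may be absent but the offset must still advance.
void SplitAlongAxis0(const float* in, const std::vector<size_t>& out_numels,
                     const std::vector<float*>& outs) {
  size_t offset = 0;
  for (size_t i = 0; i < outs.size(); ++i) {
    if (outs[i] != nullptr) {
      std::memcpy(outs[i], in + offset, out_numels[i] * sizeof(float));
    }
    offset += out_numels[i];
  }
}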
@@ -16,7 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detection/bbox_util.h"
 #include "paddle/fluid/operators/gather.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/port.h"
@@ -79,7 +79,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
 template <typename DeviceContext>
 template <typename T>
 void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() {
-  math::ConcatGradFunctor<DeviceContext, T> func;
+  math::SplitFunctor<DeviceContext, T> func;
   func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0,
        &prev_functor_->outputs_);
 }
...
 if (NOT WIN32)
   add_subdirectory(detail)
 endif(NOT WIN32)

 function(math_library TARGET)
@@ -35,7 +35,7 @@ function(math_library TARGET)
 endfunction()

 # please add new math_library in alphabetical order
-math_library(concat)
+math_library(concat_and_split)
 math_library(context_project DEPS im2col math_function)
 math_library(cross_entropy)
 math_library(cos_sim_functor)
@@ -43,8 +43,8 @@ math_library(depthwise_conv)
 math_library(im2col)
 if (NOT WIN32) # windows do not support avx functions yet.
   math_library(gru_compute DEPS activation_functions math_function)
   math_library(lstm_compute DEPS activation_functions)
 endif (NOT WIN32)
 cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
@@ -58,7 +58,7 @@ math_library(sequence_pooling DEPS math_function)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function)
 if (NOT WIN32)
   math_library(matrix_bit_code)
 endif (NOT WIN32)
 math_library(unpooling)
 math_library(vol2col)
@@ -72,7 +72,7 @@ if(WITH_GPU)
   nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
   nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
 endif()
-cc_test(concat_test SRCS concat_test.cc DEPS concat)
+cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel
     SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include <vector>

 namespace paddle {
@@ -67,7 +67,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
  * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
-class ConcatGradFunctor<platform::CPUDeviceContext, T> {
+class SplitFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,
@@ -111,7 +111,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
 };

 #define DEFINE_FUNCTOR(type)                                      \
   template class ConcatFunctor<platform::CPUDeviceContext, type>; \
-  template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
+  template class SplitFunctor<platform::CPUDeviceContext, type>;

 FOR_ALL_TYPES(DEFINE_FUNCTOR);
...
@@ -15,7 +15,7 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
 #include "paddle/fluid/framework/mixed_vector.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
@@ -24,7 +24,7 @@ namespace operators {
 namespace math {

 template <typename T>
-__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
+__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
                              const int output_rows, const int output_cols,
                              T* output) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
@@ -50,7 +50,7 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
 }

 template <typename T>
-__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
+__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
                              const int out_rows, const int out_cols,
                              T* output_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
@@ -67,9 +67,9 @@ __global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
 }

 template <typename T>
-__global__ void KernelConcatGrad(const T* input_data, const int in_row,
-                                 const int in_col, const int* out_cols,
-                                 int out_cols_size, T** outputs_data) {
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int* out_cols,
+                            int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   int curr_segment = 0;
   int curr_offset = out_cols[0];
@@ -94,9 +94,9 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
 }

 template <typename T>
-__global__ void KernelConcatGrad(const T* input_data, const int in_row,
-                                 const int in_col, const int fixed_out_col,
-                                 T** outputs_data) {
+__global__ void SplitKernel(const T* input_data, const int in_row,
+                            const int in_col, const int fixed_out_col,
+                            T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
     int split = tid_x / fixed_out_col;
@@ -170,11 +170,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
     if (sameShape) {
-      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
           dev_ins_data, in_col, out_row, out_col, output->data<T>());
     } else {
       const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
-      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
+      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
           dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
           out_row, out_col, output->data<T>());
     }
@@ -189,7 +189,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
  * each dimension must be the same, except the axis dimension.
  */
 template <typename T>
-class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+class SplitFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
@@ -248,11 +248,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size = dim3(grid_cols, grid_rows, 1);
     if (sameShape) {
-      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
     } else {
       const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
-      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
+      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
           static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
     }
@@ -264,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 #define DEFINE_FUNCTOR(type)                                       \
   template class ConcatFunctor<platform::CUDADeviceContext, type>; \
-  template class ConcatGradFunctor<platform::CUDADeviceContext, type>
+  template class SplitFunctor<platform::CUDADeviceContext, type>

 FOR_ALL_TYPES(DEFINE_FUNCTOR);
...
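Both the renamed ConcatKernel and SplitKernel rely on the grid-stride loop visible in the fragments above: a thread starts at blockIdx.x * blockDim.x + threadIdx.x and advances by blockDim.x * gridDim.x, so correctness never depends on the launch covering all columns in a single pass. A self-contained sketch of the pattern follows; the kernel itself is a hypothetical example, not part of this commit.

// Grid-stride sketch: each thread handles column tid_x, then strides by the
// total thread count (blockDim.x * gridDim.x), so the kernel is correct for
// any grid/block configuration.
__global__ void ScaleColumns(const float* in, float* out, int in_row,
                             int in_col, float alpha) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    for (int row = 0; row < in_row; ++row) {
      out[row * in_col + tid_x] = alpha * in[row * in_col + tid_x];
    }
  }
}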
@@ -54,7 +54,7 @@ class ConcatFunctor {
  *    Output[1] = [[5,6]]
  */
 template <typename DeviceContext, typename T>
-class ConcatGradFunctor {
+class SplitFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector<const framework::Tensor*>& ref_inputs,
...
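The header's doc comment (partially visible above, ending with Output[1] = [[5,6]]) pins down the contract: ref_inputs supplies the target shapes, and input is sliced along axis to match them. A plain host-side illustration of those semantics, independent of the Paddle tensor types:

#include <cstdio>

// Splitting the 3x2 input {{1,2},{3,4},{5,6}} on axis 0 into a 2x2 and a
// 1x2 piece yields Output[0] = [[1,2],[3,4]] and Output[1] = [[5,6]],
// matching the example in the header comment.
int main() {
  const int input[3][2] = {{1, 2}, {3, 4}, {5, 6}};
  int out0[2][2], out1[1][2];
  for (int r = 0; r < 2; ++r)
    for (int c = 0; c < 2; ++c) out0[r][c] = input[r][c];
  for (int c = 0; c < 2; ++c) out1[0][c] = input[2][c];
  std::printf("Output[0] = [[%d,%d],[%d,%d]]\n", out0[0][0], out0[0][1],
              out0[1][0], out0[1][1]);
  std::printf("Output[1] = [[%d,%d]]\n", out1[0][0], out1[0][1]);
  return 0;
}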
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/math/concat.h"
 #include <gtest/gtest.h>
 #include <vector>
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"

 template <typename DeviceContext, typename Place>
 void testConcat() {
...
@@ -17,7 +17,7 @@
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
-#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"

 namespace paddle {
 namespace operators {
@@ -106,7 +106,7 @@ class SeqConcatGradKernel : public framework::OpKernel<T> {
       }
     }

-    math::ConcatGradFunctor<DeviceContext, T> functor;
+    math::SplitFunctor<DeviceContext, T> functor;
     std::vector<const framework::Tensor *> sliced_x_ptr;
     std::vector<framework::Tensor *> sliced_dx_ptr;
     for (auto &x : sliced_x) {
...
@@ -111,11 +111,10 @@ Example:
 }  // namespace paddle

 namespace ops = paddle::operators;
-USE_CPU_ONLY_OP(concat);
 REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker);
-REGISTER_OP_CPU_KERNEL(split,
-                       ops::SplitOpKernel<paddle::platform::CPUPlace, double>,
-                       ops::SplitOpKernel<paddle::platform::CPUPlace, float>,
-                       ops::SplitOpKernel<paddle::platform::CPUPlace, int64_t>,
-                       ops::SplitOpKernel<paddle::platform::CPUPlace, int>);
+REGISTER_OP_CPU_KERNEL(
+    split, ops::SplitOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SplitOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SplitOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::SplitOpKernel<paddle::platform::CPUDeviceContext, int>);
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <chrono>  // NOLINT
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/strided_memcpy.h"

 namespace paddle {
@@ -28,18 +29,22 @@ class SplitOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in = ctx.Input<framework::Tensor>("X");
     auto outs = ctx.MultiOutput<framework::Tensor>("Out");
-    auto in_stride = framework::stride_numel(in->dims());
-    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    int axis = ctx.Attr<int>("axis");
     auto place = ctx.GetPlace();

-    size_t input_offset = 0;
-    for (auto& out : outs) {
-      out->mutable_data<T>(ctx.GetPlace());
-      auto out_stride = framework::stride_numel(out->dims());
-      StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                  out_stride, in->data<T>() + input_offset,
-                                  in_stride, out_stride[axis]);
-      input_offset += out_stride[axis];
+    std::vector<const framework::Tensor*> shape_refer;
+    for (size_t j = 0; j < outs.size(); ++j) {
+      outs[j]->mutable_data<T>(ctx.GetPlace());
+      shape_refer.emplace_back(outs[j]);
+    }
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    // Sometimes direct copies will be faster, this maybe need deeply analysis.
+    if (axis == 0 && outs.size() < 10) {
+      StridedMemcpyWithAxis0<T>(dev_ctx, *in, shape_refer, &outs);
+    } else {
+      math::SplitFunctor<DeviceContext, T> functor;
+      functor(dev_ctx, *in, shape_refer, axis, &outs);
     }
   }
 };
...
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <vector>
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/detail/strided_memcpy.h"

 namespace paddle {
 namespace operators {
@@ -98,5 +99,26 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
   }
 }

+template <typename T>
+inline void StridedMemcpyWithAxis0(
+    const platform::DeviceContext& dev_ctx, const framework::Tensor& input,
+    const std::vector<const framework::Tensor*>& shape_refer,
+    std::vector<framework::Tensor*>* outputs) {
+  const framework::DDim in_stride = stride_numel(input.dims());
+  const int axis = 0;
+  size_t input_offset = 0;
+  for (size_t i = 0; i < outputs->size(); ++i) {
+    auto out_stride = stride_numel(shape_refer[i]->dims());
+    auto out = outputs->at(i);
+    if (out != nullptr) {
+      StridedNumelCopyWithAxis<T>(dev_ctx, axis, out->data<T>(), out_stride,
+                                  input.data<T>() + input_offset, in_stride,
+                                  out_stride[axis]);
+    }
+    input_offset += out_stride[axis];
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
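For reference, the "numel strides" consumed by StridedNumelCopyWithAxis above hold, at index i, the number of elements spanned by dimensions i..N-1; out_stride[axis] at axis 0 is therefore the whole output's element count, which is why the axis-0 path degenerates to contiguous copies. A standalone sketch of that computation, stated as an assumption about framework::stride_numel's layout rather than the Paddle helper itself:

#include <vector>

// Sketch of "numel strides": entry i is the element count spanned by
// dimensions i..N-1. For dims {2, 3, 4} this gives {24, 12, 4}; at axis 0
// the stride equals the tensor's total element count.
std::vector<long long> StrideNumel(const std::vector<long long>& dims) {
  std::vector<long long> strides(dims.size());
  long long acc = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    acc *= dims[i];
    strides[i] = acc;
  }
  return strides;
}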