diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6c95f4b9c5b87dad13959d3d7678a19b79dd96d2..78ef6f207eadea6799864fe22889103b468d1780 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -284,10 +284,10 @@ op_library(max_sequence_len_op DEPS lod_rank_table) op_library(sequence_conv_op DEPS context_project) op_library(sequence_pool_op DEPS sequence_pooling) if (NOT WIN32) -op_library(lstm_op DEPS sequence2batch lstm_compute) -op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) -op_library(lstmp_op DEPS sequence2batch lstm_compute) -op_library(gru_op DEPS sequence2batch gru_compute) + op_library(lstm_op DEPS sequence2batch lstm_compute) + op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) + op_library(lstmp_op DEPS sequence2batch lstm_compute) + op_library(gru_op DEPS sequence2batch gru_compute) endif(NOT WIN32) op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) @@ -316,7 +316,7 @@ op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) -op_library(concat_op DEPS concat) +op_library(concat_op DEPS concat_and_split) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) @@ -348,6 +348,6 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) if(NOT WIN32) -nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) + nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) endif() nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index b8b8b2290a0f002fd379032e28590b84a1da38e9..6257e04b010d8c580e69e466759e8e80d344c105 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#include #include #include "paddle/fluid/framework/lod_rank_table.h" diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index b2c6495c442cd02679825425becc2160c303dcc6..bd474be0facb349c53a8766412311296383a86c5 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { @@ -89,29 +89,17 @@ class ConcatGradKernel : public framework::OpKernel { outputs.push_back(nullptr); } } + auto& dev_ctx = ctx.template device_context(); // Sometimes direct copies will be faster, this maybe need deeply analysis. if (axis == 0 && outs.size() < 10) { - size_t input_offset = 0; - const auto in_stride = framework::stride_numel(out_grad->dims()); - - for (size_t i = 0; i < outs.size(); ++i) { - auto out_stride = framework::stride_numel(ins[i]->dims()); - auto* out = outputs[i]; - if (out != nullptr) { - StridedNumelCopyWithAxis( - ctx.device_context(), axis, out->data(), out_stride, - out_grad->data() + input_offset, in_stride, out_stride[axis]); - } - input_offset += out_stride[axis]; - } + std::vector ref_shape; + ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end()); + StridedMemcpyWithAxis0(dev_ctx, *out_grad, ref_shape, &outputs); } else { - auto& dev_ctx = ctx.template device_context(); - paddle::operators::math::ConcatGradFunctor - concat_grad_functor; - concat_grad_functor(dev_ctx, *out_grad, - ctx.MultiInput("X"), - static_cast(axis), &outputs); + math::SplitFunctor split_functor; + split_functor(dev_ctx, *out_grad, ctx.MultiInput("X"), + static_cast(axis), &outputs); } } }; diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index d7a53f1bef98ecda3ba7b36323678a11a632a15c..339e63a2be13cec7b641b3a9eeb083480fc4b86e 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/gather.h" -#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index 8eab83fcd247fcd099ae1fa5dab1e67c2081bf9c..e72337a3e6f7884c3a05372e8732647e5910f3e4 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/port.h" @@ -79,7 +79,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor { template template void LoDTensorToArrayFunctorImpl::apply() { - math::ConcatGradFunctor func; + math::SplitFunctor func; func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0, &prev_functor_->outputs_); } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index c7bdec354735773a15b4c99baf9f7798f2d92564..5d0c0b4228d8e2890c8b8d8bd10e0df080251350 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -1,5 +1,5 @@ if (NOT WIN32) -add_subdirectory(detail) + add_subdirectory(detail) endif(NOT WIN32) function(math_library TARGET) @@ -35,7 +35,7 @@ function(math_library TARGET) endfunction() # please add new math_library in alphabetical order -math_library(concat) +math_library(concat_and_split) math_library(context_project DEPS im2col math_function) math_library(cross_entropy) math_library(cos_sim_functor) @@ -43,8 +43,8 @@ math_library(depthwise_conv) math_library(im2col) if (NOT WIN32) # windows do not support avx functions yet. -math_library(gru_compute DEPS activation_functions math_function) -math_library(lstm_compute DEPS activation_functions) + math_library(gru_compute DEPS activation_functions math_function) + math_library(lstm_compute DEPS activation_functions) endif (NOT WIN32) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) @@ -58,7 +58,7 @@ math_library(sequence_pooling DEPS math_function) math_library(sequence_scale) math_library(softmax DEPS math_function) if (NOT WIN32) -math_library(matrix_bit_code) + math_library(matrix_bit_code) endif (NOT WIN32) math_library(unpooling) math_library(vol2col) @@ -72,7 +72,7 @@ if(WITH_GPU) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function) endif() -cc_test(concat_test SRCS concat_test.cc DEPS concat) +cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat_and_split.cc similarity index 95% rename from paddle/fluid/operators/math/concat.cc rename to paddle/fluid/operators/math/concat_and_split.cc index 7b79f10e33d4474e279c6e46208722d6b52277fc..c6e17fd042f19bbeee3507e4cd64f49cff369682 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include namespace paddle { @@ -67,7 +67,7 @@ class ConcatFunctor { * each dimension must be the same, except the axis dimension. */ template -class ConcatGradFunctor { +class SplitFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, @@ -111,7 +111,7 @@ class ConcatGradFunctor { }; #define DEFINE_FUNCTOR(type) \ template class ConcatFunctor; \ - template class ConcatGradFunctor; + template class SplitFunctor; FOR_ALL_TYPES(DEFINE_FUNCTOR); diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat_and_split.cu similarity index 90% rename from paddle/fluid/operators/math/concat.cu rename to paddle/fluid/operators/math/concat_and_split.cu index b59d86e661aff25eba8e770247e85845365d628b..760a065c1081d1e55901774b258ba524471b856b 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" @@ -24,7 +24,7 @@ namespace operators { namespace math { template -__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, +__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size, const int output_rows, const int output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; @@ -50,7 +50,7 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, } template -__global__ void KernelConcat(T** inputs_data, const int fixed_in_col, +__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col, const int out_rows, const int out_cols, T* output_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; @@ -67,9 +67,9 @@ __global__ void KernelConcat(T** inputs_data, const int fixed_in_col, } template -__global__ void KernelConcatGrad(const T* input_data, const int in_row, - const int in_col, const int* out_cols, - int out_cols_size, T** outputs_data) { +__global__ void SplitKernel(const T* input_data, const int in_row, + const int in_col, const int* out_cols, + int out_cols_size, T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int curr_segment = 0; int curr_offset = out_cols[0]; @@ -94,9 +94,9 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row, } template -__global__ void KernelConcatGrad(const T* input_data, const int in_row, - const int in_col, const int fixed_out_col, - T** outputs_data) { +__global__ void SplitKernel(const T* input_data, const int in_row, + const int in_col, const int fixed_out_col, + T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { int split = tid_x / fixed_out_col; @@ -170,11 +170,11 @@ class ConcatFunctor { dim3 grid_size = dim3(grid_cols, grid_rows, 1); if (sameShape) { - KernelConcat<<>>( + ConcatKernel<<>>( dev_ins_data, in_col, out_row, out_col, output->data()); } else { const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace()); - KernelConcat<<>>( + ConcatKernel<<>>( dev_ins_data, dev_ins_col_data, static_cast(inputs_col.size()), out_row, out_col, output->data()); } @@ -189,7 +189,7 @@ class ConcatFunctor { * each dimension must be the same, except the axis dimension. */ template -class ConcatGradFunctor { +class SplitFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, @@ -248,11 +248,11 @@ class ConcatGradFunctor { dim3 grid_size = dim3(grid_cols, grid_rows, 1); if (sameShape) { - KernelConcatGrad<<>>( + SplitKernel<<>>( input.data(), in_row, in_col, out0_col, dev_out_gpu_data); } else { const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace()); - KernelConcatGrad<<>>( + SplitKernel<<>>( input.data(), in_row, in_col, dev_outs_col_data, static_cast(outputs_cols.size()), dev_out_gpu_data); } @@ -264,7 +264,7 @@ class ConcatGradFunctor { #define DEFINE_FUNCTOR(type) \ template class ConcatFunctor; \ - template class ConcatGradFunctor + template class SplitFunctor FOR_ALL_TYPES(DEFINE_FUNCTOR); diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat_and_split.h similarity index 98% rename from paddle/fluid/operators/math/concat.h rename to paddle/fluid/operators/math/concat_and_split.h index 867a84fa873a2e90bdab7a5eecbb1755cb4b02d1..3a5eddcbf4af699a89ae1a21571337155699a1f3 100644 --- a/paddle/fluid/operators/math/concat.h +++ b/paddle/fluid/operators/math/concat_and_split.h @@ -54,7 +54,7 @@ class ConcatFunctor { * Output[1] = [[5,6]] */ template -class ConcatGradFunctor { +class SplitFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor& input, const std::vector& ref_inputs, diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc index a46f2d51ca64501a622b5b48b424dffa16efc5b4..8ba9e8e8ec1344edc3beaf7f4a58f99107cc0e9c 100644 --- a/paddle/fluid/operators/math/concat_test.cc +++ b/paddle/fluid/operators/math/concat_test.cc @@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/concat.h" #include #include #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/concat_and_split.h" template void testConcat() { diff --git a/paddle/fluid/operators/sequence_concat_op.h b/paddle/fluid/operators/sequence_concat_op.h index 33e9babff274af888b84d33c991cc0a5b70333ae..ff035f421c4907ba940b973b3fd2a9421ed2dbae 100644 --- a/paddle/fluid/operators/sequence_concat_op.h +++ b/paddle/fluid/operators/sequence_concat_op.h @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/operators/math/concat_and_split.h" namespace paddle { namespace operators { @@ -106,7 +106,7 @@ class SeqConcatGradKernel : public framework::OpKernel { } } - math::ConcatGradFunctor functor; + math::SplitFunctor functor; std::vector sliced_x_ptr; std::vector sliced_dx_ptr; for (auto &x : sliced_x) { diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index d661b276bc31bf0c3ab181d706ffdccec89f0632..a05582ae09e16ee17194d299d713d321f28ccace 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -111,11 +111,10 @@ Example: } // namespace paddle namespace ops = paddle::operators; -USE_CPU_ONLY_OP(concat); REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker); -REGISTER_OP_CPU_KERNEL(split, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel); +REGISTER_OP_CPU_KERNEL( + split, ops::SplitOpKernel, + ops::SplitOpKernel, + ops::SplitOpKernel, + ops::SplitOpKernel); diff --git a/paddle/fluid/operators/split_op.h b/paddle/fluid/operators/split_op.h index f0c417c70521b1bb3816f884d6ab7393473999e4..6f4a25ab5ed86937f2f5db532a9eba22b5a2c5be 100644 --- a/paddle/fluid/operators/split_op.h +++ b/paddle/fluid/operators/split_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include // NOLINT #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { @@ -28,18 +29,22 @@ class SplitOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto outs = ctx.MultiOutput("Out"); - auto in_stride = framework::stride_numel(in->dims()); - int64_t axis = static_cast(ctx.Attr("axis")); + int axis = ctx.Attr("axis"); auto place = ctx.GetPlace(); - size_t input_offset = 0; - for (auto& out : outs) { - out->mutable_data(ctx.GetPlace()); - auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis(ctx.device_context(), axis, out->data(), - out_stride, in->data() + input_offset, - in_stride, out_stride[axis]); - input_offset += out_stride[axis]; + std::vector shape_refer; + for (size_t j = 0; j < outs.size(); ++j) { + outs[j]->mutable_data(ctx.GetPlace()); + shape_refer.emplace_back(outs[j]); + } + + auto& dev_ctx = ctx.template device_context(); + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + StridedMemcpyWithAxis0(dev_ctx, *in, shape_refer, &outs); + } else { + math::SplitFunctor functor; + functor(dev_ctx, *in, shape_refer, axis, &outs); } } }; diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 7a10218e1556698f3e0a1828db5de8851dd1c90b..c3d83a06f23a34ec8cf27d585863135ebfd56a4f 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/detail/strided_memcpy.h" - namespace paddle { namespace operators { @@ -98,5 +99,26 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, } } +template +inline void StridedMemcpyWithAxis0( + const platform::DeviceContext& dev_ctx, const framework::Tensor& input, + const std::vector& shape_refer, + std::vector* outputs) { + const framework::DDim in_stride = stride_numel(input.dims()); + const int axis = 0; + size_t input_offset = 0; + + for (size_t i = 0; i < outputs->size(); ++i) { + auto out_stride = stride_numel(shape_refer[i]->dims()); + auto out = outputs->at(i); + if (out != nullptr) { + StridedNumelCopyWithAxis(dev_ctx, axis, out->data(), out_stride, + input.data() + input_offset, in_stride, + out_stride[axis]); + } + input_offset += out_stride[axis]; + } +} + } // namespace operators } // namespace paddle