diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h
index 437496e0aca0af074680b37fddb2088acc73f6cf..55e9cc4a98bd6d36ce5d6bb4116039d0ec18b485 100644
--- a/paddle/operators/gru_op.h
+++ b/paddle/operators/gru_op.h
@@ -205,14 +205,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
     }
     if (bias_grad) {
       bias_grad->mutable_data<T>(context.GetPlace());
-      int m = static_cast<int>(batch_gate_grad.dims()[0]);
-      int n = static_cast<int>(batch_gate_grad.dims()[1]);
-      Tensor ones;
-      ones.mutable_data<T>({m}, context.GetPlace());
-      math::SetConstant<Place, T> set;
-      set(dev_ctx, &ones, static_cast<T>(1));
-      math::gemv<Place, T>(dev_ctx, true, m, n, 1., batch_gate_grad.data<T>(),
-                           ones.data<T>(), 0., bias_grad->data<T>());
+      math::ColwiseSum<Place, T> col_sum;
+      col_sum(dev_ctx, batch_gate_grad, bias_grad);
     }
   }
 
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index 58fedaee9a861ce2d8237ff4b105dfef79017de9..721aa42c92f2926aabbc13d0a9027b2b4e573225 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -341,16 +341,11 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     }
     if (bias && bias_g) {
       /* backward bias */
-      int m = static_cast<int>(batch_gate_g.dims()[0]);
-      int n = static_cast<int>(batch_gate_g.dims()[1]);
-
-      Tensor ones;
-      ones.mutable_data<T>({m}, ctx.GetPlace());
-      math::SetConstant<Place, T> set;
-      set(device_ctx, &ones, static_cast<T>(1.0));
-
-      math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(),
-                           ones.data<T>(), 0., bias_g->data<T>());
+      Tensor b_g = *bias_g;
+      b_g.Resize({bias_g->numel(), 1});
+      Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size);
+      math::ColwiseSum<Place, T> col_sum;
+      col_sum(device_ctx, batch_gate_g, &gate_bias_g);
     }
 
     if (h0 && h0_g) {
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index a137ffe57f9866d92d081c6aceecaf997c5df032..5ee091788687133f6eaef7229d9f95e2025a2daf 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -308,6 +308,11 @@ void set_constant(const
platform::DeviceContext& context,
 #endif
 }
 
+template struct RowwiseAdd<platform::CPUPlace, float>;
+template struct RowwiseAdd<platform::CPUPlace, double>;
+template struct ColwiseSum<platform::CPUPlace, float>;
+template struct ColwiseSum<platform::CPUPlace, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 6daec3797e9b91b67c95a878256275d4ead237de..38c04b97f9d07b9cca938b09f46ea81328a35322 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -292,6 +292,11 @@ void set_constant_with_place(
       TensorSetConstantGPU(context, tensor, value));
 }
 
+template struct RowwiseAdd<platform::GPUPlace, float>;
+template struct RowwiseAdd<platform::GPUPlace, double>;
+template struct ColwiseSum<platform::GPUPlace, float>;
+template struct ColwiseSum<platform::GPUPlace, double>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index 6b40a08375c21dd82f1284e6dd32c52be0599ee8..ffb99f53808c4316ede96b04e57aec4dae4134de 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -117,6 +117,19 @@ void set_constant_with_place(const platform::DeviceContext& context,
 void set_constant(const platform::DeviceContext& context,
                   framework::Tensor* tensor, float value);
 
+template <typename Place, typename T>
+struct RowwiseAdd {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, const framework::Tensor& vec,
+                  framework::Tensor* output);
+};
+
+template <typename Place, typename T>
+struct ColwiseSum {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor* vec);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h
index dba2d02c270a461f4b1a5ae448a2eb2527405e12..4dc17a4e525c52b8f696277274a7ad00a6b00a08 100644
--- a/paddle/operators/math/math_function_impl.h
+++ b/paddle/operators/math/math_function_impl.h
@@ -43,6 +43,41 @@ void
Transpose<Place, T, Rank>::operator()(
   auto* dev = context.GetEigenDevice<Place>();
   eigen_out.device(*dev) = eigen_in.shuffle(permute);
 }
+
+template <typename Place, typename T>
+void RowwiseAdd<Place, T>::operator()(const platform::DeviceContext& context,
+                                      const framework::Tensor& input,
+                                      const framework::Tensor& vector,
+                                      framework::Tensor* output) {
+  auto in_dims = input.dims();
+  auto size = input.numel() / in_dims[0];
+  PADDLE_ENFORCE_EQ(vector.numel(), size);
+  PADDLE_ENFORCE_EQ(output->dims(), in_dims);
+
+  auto in = framework::EigenMatrix<T>::From(input);
+  auto vec = framework::EigenMatrix<T>::From(vector);
+  auto out = framework::EigenMatrix<T>::From(*output);
+  Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
+  Eigen::array<int, 2> bcast({{static_cast<int>(in_dims[0]), 1}});
+  out.device(*context.GetEigenDevice<Place>()) =
+      in + vec.reshape(shape).broadcast(bcast);
+}
+
+template <typename Place, typename T>
+void ColwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
+                                      const framework::Tensor& input,
+                                      framework::Tensor* vector) {
+  auto in_dims = input.dims();
+  auto size = input.numel() / in_dims[0];
+  PADDLE_ENFORCE_EQ(vector->numel(), size);
+
+  auto vec = framework::EigenMatrix<T>::From(*vector);
+  auto in = framework::EigenMatrix<T>::From(input);
+  Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
+  vec.reshape(shape).device(*context.GetEigenDevice<Place>()) =
+      in.sum(Eigen::array<int, 1>({{0}})).reshape(shape);
+}
-}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc
index 5170b595e675aa4f222011c383899e6837182447..5b3bde02fbf981772759caa3d0054fac4a8520f9 100644
--- a/paddle/operators/math/sequence2batch.cc
+++ b/paddle/operators/math/sequence2batch.cc
@@ -56,29 +56,6 @@ template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
 template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
 template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
 
-template <typename T>
-struct RowwiseAdd<platform::CPUPlace, T> {
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, const framework::Tensor& bias,
-                  framework::Tensor* output) {
-    auto in_dims = input.dims();
-    auto size = input.numel() / in_dims[0];
-    PADDLE_ENFORCE_EQ(bias.numel(), size);
-    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
-
-    auto in = EigenMatrix<T>::From(input);
-    auto b = EigenMatrix<T>::From(bias);
-    auto out = EigenMatrix<T>::From(*output);
-    Eigen::array<int, 2> bshape({{1, static_cast<int>(size)}});
-    Eigen::array<int, 2> bcast({{static_cast<int>(in_dims[0]), 1}});
-    out.device(*context.GetEigenDevice<platform::CPUPlace>()) =
-        in + b.reshape(bshape).broadcast(bcast);
-  }
-};
-
-template struct RowwiseAdd<platform::CPUPlace, float>;
-template struct RowwiseAdd<platform::CPUPlace, double>;
-
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu
index e386e63a9a6131af07b1b756ed18636b4ee88716..c5d968aeb216bbb3e0e17f138b9e891494d99f75 100644
--- a/paddle/operators/math/sequence2batch.cu
+++ b/paddle/operators/math/sequence2batch.cu
@@ -74,37 +74,6 @@ template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
 template class Batch2LoDTensorFunctor<platform::GPUPlace, float>;
 template class Batch2LoDTensorFunctor<platform::GPUPlace, double>;
 
-template <typename T>
-__global__ void RowwiseAddKernel(const T* src, const T* b, T* dst,
-                                 int64_t height, int64_t width) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < height * width;
-       i += blockDim.x * gridDim.x) {
-    int64_t h = i / width;
-    int64_t w = i % width;
-    dst[h * width + w] = src[h * width + w] + b[w];
-  }
-}
-
-template <typename T>
-struct RowwiseAdd<platform::GPUPlace, T> {
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, const framework::Tensor& bias,
-                  framework::Tensor* output) {
-    auto in_dims = input.dims();
-    auto size = input.numel() / in_dims[0];
-    PADDLE_ENFORCE_EQ(bias.numel(), size);
-    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
-    int block = 512;
-    int grid = (input.numel() + block - 1) / block;
-    auto stream =
-        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
-    RowwiseAddKernel<T><<<grid, block, 0, stream>>>(
-        input.data<T>(), bias.data<T>(), output->data<T>(), in_dims[0], size);
-  }
-};
-
-template struct RowwiseAdd<platform::GPUPlace, float>;
-template struct RowwiseAdd<platform::GPUPlace, double>;
 }  // namespace math
 }  //
namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h
index 9e7d8630814d887cb8b66423ddeff039fddbc77b..73295ddbcb73fe80be08e732790f0ec75e94b415 100644
--- a/paddle/operators/math/sequence2batch.h
+++ b/paddle/operators/math/sequence2batch.h
@@ -164,13 +164,6 @@ class Batch2LoDTensorFunctor {
   }
 };
 
-template <typename Place, typename T>
-struct RowwiseAdd {
-  void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, const framework::Tensor& bias,
-                  framework::Tensor* output);
-};
-
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle