diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 63bf3ab6a0382be4764976eedac0ca5314bcd584..3112d0d8205a86326b7ecd9b86ffac486291f6b3 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -166,7 +166,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment generator)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
-set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc_functor matrix_inverse matrix_solve)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 78ea8b6b6fbebd7e0ca5ce14cc2cba6ff197177f..bf7d609370a8d06064db6a2d621be77dc72c188f 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -14,10 +14,10 @@ limitations under the License. */

 #include "paddle/fluid/operators/attention_lstm_op.h"
 #include <string>
-#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/cpu_vec.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"

 namespace paddle {
 namespace operators {
@@ -377,7 +377,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {

     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    math::FCFunctor<platform::CPUDeviceContext, T> fc;
+    phi::funcs::FCFunctor<platform::CPUDeviceContext, T> fc;
     fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data,
        atten_b_data);
diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h
index dfa10e6de72e895aecec55528e7ab8fbfa4fcd5c..6d3b531ce0aa63188f489e5f5179540315352b6d 100644
--- a/paddle/fluid/operators/fc_op.h
+++ b/paddle/fluid/operators/fc_op.h
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/fc.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"

 namespace paddle {
 namespace operators {
@@ -80,7 +80,7 @@ class FCOpKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, M, w_dims1, w_dims0, input_data, w_data, output_data,
        bias ? bias->data<T>() : NULL, with_relu, padding_weights);
   }
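The hunks above capture the whole caller-side change: only the namespace moves, while the call shape `fc(ctx, M, N, K, X, W, Y, B, relu, padding_weights)` stays intact. For reference, a minimal sketch of how a caller drives the relocated functor; the wrapper name `RunFC` and its arguments are illustrative, not part of this diff:

```cpp
#include "paddle/phi/kernels/funcs/fc_functor.h"

// Hypothetical caller-side sketch: computes Y = X * W (+ B, optional ReLU)
// for an M x K input against a K x N weight, mirroring the call sites
// updated in this diff.
template <typename DeviceContext, typename T>
void RunFC(const DeviceContext& dev_ctx, int M, int N, int K,
           const T* x, const T* w, T* y, const T* bias) {
  phi::funcs::FCFunctor<DeviceContext, T> fc;
  // relu = false, padding_weights = false: plain GEMM plus optional bias.
  fc(dev_ctx, M, N, K, x, w, y, bias, /*relu=*/false,
     /*padding_weights=*/false);
}
```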
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 3311e3b4ebc9e21d0a033e54ba162e72a80326d0..afbd5380a8301e408ae338cddda5edf3f4916bc8 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -18,8 +18,8 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"
 #include "paddle/phi/kernels/funcs/sequence2batch.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -298,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {

     auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
        bias ? bias->data<T>() : nullptr);

@@ -370,7 +370,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
     phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     if (M > D3) {
       fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
          bias ? bias->data<T>() : nullptr);
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index 00be8b09d1296018f36c0299f415b7c27f0fad14..3dada660aeffe38d9b4c64d00cc2eaf89653d084 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -15,8 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
 #include <string>
 #include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"
 #include "paddle/phi/kernels/funcs/sequence2batch.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -346,7 +346,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {

     auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());

     int xx_offset = D4;
@@ -424,7 +424,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     if (M > D4) {
       fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
       to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
index f71cf1fd43374caa4e43605eeb79888a6497df82..ee28a5480565303bd993b7da02f45b400c70477f 100644
--- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
@@ -15,8 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h"
 #include <algorithm>  // for min, max
 #include <string>
-#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"

 namespace paddle {
 namespace operators {
@@ -244,7 +244,7 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
       }
     }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data,
        b_data, true);
   }
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index 1000d0488dc3ffcf6cde977be47ce77d2bc947a7..58613173ad212e31035657d9518a7eaeb6aa7573 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -14,10 +14,10 @@ limitations under the License. */

 #include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h"
 #include <string>
-#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/cpu_vec.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"

 namespace paddle {
 namespace operators {
@@ -212,7 +212,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {

     auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::FCFunctor<DeviceContext, T> fc;
+    phi::funcs::FCFunctor<DeviceContext, T> fc;
     fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data,
        b ? b->data<T>() : NULL);
     w_data = w_data + M0 * D;
diff --git a/paddle/fluid/operators/fused/multi_gru_op.cc b/paddle/fluid/operators/fused/multi_gru_op.cc
index c2260c53b2edd09dd69d126bc5e61b995fb20467..e7d697767fcace462b02c133beb2e74ecf84bcb0 100644
--- a/paddle/fluid/operators/fused/multi_gru_op.cc
+++ b/paddle/fluid/operators/fused/multi_gru_op.cc
@@ -18,8 +18,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"
 #include "paddle/phi/kernels/funcs/sequence2batch.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index df8150b192b6c9c01ef8703d02c9a6384701bdca..913ce07ec673c6d106f2cba3f94e2db0c572eafd 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -36,7 +36,6 @@ if (WITH_ASCEND_CL)
 else()
     math_library(beam_search DEPS math_function)
 endif()

-math_library(fc DEPS blas jit_kernel_helper)
 math_library(matrix_bit_code)
diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc
deleted file mode 100644
index 4599177fc13aac3f9dc0205658963e9a07cc9a1d..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/math/fc.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/math/fc.h"
-
-#include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename T>
-class FCFunctor<platform::CPUDeviceContext, T> {
- public:
-  void operator()(const platform::CPUDeviceContext& context, const int M,
-                  const int N, const int K, const T* X, const T* W, T* Y,
-                  const T* B = nullptr, bool relu = false,
-                  bool padding_weights = false) {
-    auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
-    framework::Tensor Y1;
-    T* Y1_data = nullptr;
-    if (padding_weights) {
-      const int NN = N + 4;
-      const int KK = K + 4;
-      framework::Tensor X1;
-      T* X1_data = X1.mutable_data<T>({M * KK}, platform::CPUPlace());
-      Y1_data = Y1.mutable_data<T>({M * (N + 4)}, platform::CPUPlace());
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-      for (int i = 0; i < M; i++) {
-        memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
-      }
-      blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK, W, NN,
-                static_cast<T>(0.0), Y1_data, NN);
-    } else {
-      blas.MatMul(M, N, K, X, W, Y);
-    }
-    if (B == NULL) {
-      if (padding_weights) {
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-        for (int i = 0; i < M; i++) {
-          memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
-        }
-      }
-      PADDLE_ENFORCE_EQ(relu, false,
-                        platform::errors::PermissionDenied(
-                            "When bias is NULL, relu can not be true."));
-      return;
-    }
-    auto compute =
-        relu
-            ? jit::KernelFuncs<jit::VAddReluTuple<T>,
-                               platform::CPUPlace>::Cache()
-                  .At(N)
-            : jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache()
-                  .At(N);
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-    for (int i = 0; i < M; i++) {
-      T* dst = Y + i * N;
-      T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
-      compute(B, src, dst, N);
-    }
-  }
-};
-
-template class FCFunctor<platform::CPUDeviceContext, float>;
-template class FCFunctor<platform::CPUDeviceContext, double>;
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/math/fc.cu b/paddle/fluid/operators/math/fc.cu
deleted file mode 100644
index 2f94eef34a320e0fdd4fe93e7a6700bf0154c387..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/math/fc.cu
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <algorithm>
-#include "paddle/fluid/operators/math/fc.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename T>
-struct FcTypeTraits;
-
-template <>
-struct FcTypeTraits<float> {
-  typedef float4 Type;
-};
-
-template <>
-struct FcTypeTraits<double> {
-  typedef double4 Type;
-};
-
-template <typename T, bool DoRelu>
-__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < num) {
-    int bias_idx = tid % K;
-    const T bias_ptr = bias[bias_idx];
-    const T in_ptr = data[tid];
-    T packed_val;
-    packed_val.x = in_ptr.x + bias_ptr.x;
-    packed_val.y = in_ptr.y + bias_ptr.y;
-    packed_val.z = in_ptr.z + bias_ptr.z;
-    packed_val.w = in_ptr.w + bias_ptr.w;
-    if (DoRelu) {
-      packed_val.x = fmaxf(0.f, packed_val.x);
-      packed_val.y = fmaxf(0.f, packed_val.y);
-      packed_val.z = fmaxf(0.f, packed_val.z);
-      packed_val.w = fmaxf(0.f, packed_val.w);
-    }
-    data[tid] = packed_val;
-  }
-}
-
-template <typename T, bool DoRelu, int BlockDim>
-__global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
-  int offset = blockIdx.x * N;
-
-  for (int i = threadIdx.x; i < N; i += BlockDim) {
-    T temp;
-#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
-    temp = __ldg(data + offset + i) + __ldg(bias + i);
-#else
-    temp = data[offset + i] + bias[i];
-#endif
-    if (DoRelu) {
-      data[offset + i] = static_cast<T>(temp > 0) * temp;
-    } else {
-      data[offset + i] = temp;
-    }
-  }
-}
-
-template <typename T>
-class FCFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context, const int M,
-                  const int N, const int K, const T* X, const T* W, T* Y,
-                  const T* B = nullptr, bool relu = false,
-                  bool padding_weights = false) {
-    PADDLE_ENFORCE_EQ(
-        padding_weights, false,
-        platform::errors::PermissionDenied(
-            "Weight padding in fc can not be used in GPU scope."));
-    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context);
-    blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
-              static_cast<T>(0.0), Y, N);
-    if (B == NULL) {
-      return;
-    }
-
-    // M * N
-    if (N % 4 == 0) {
-      const int threads = 256;
-      const int num = M * N / 4;
-      const int blocks = (num + threads - 1) / threads;
-      typedef typename FcTypeTraits<T>::Type trans_type;
-      auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
-      auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
-      if (relu) {
-        bias_relu_v4<trans_type,
-                     true><<<blocks, threads, 0, context.stream()>>>(
-            num, bias_ptr_v4, data_ptr_v4, N / 4);
-      } else {
-        bias_relu_v4<trans_type,
-                     false><<<blocks, threads, 0, context.stream()>>>(
-            num, bias_ptr_v4, data_ptr_v4, N / 4);
-      }
-    } else {
-      const int threads = 256;
-      const int blocks = M;
-      if (relu) {
-        InplaceAddReluKernel<T, true,
-                             threads><<<blocks, threads, 0, context.stream()>>>(
-            N, B, Y);
-      } else {
-        InplaceAddReluKernel<T, false,
-                             threads><<<blocks, threads, 0, context.stream()>>>(
-            N, B, Y);
-      }
-    }
-  }
-};
-
-template class FCFunctor<platform::CUDADeviceContext, float>;
-template class FCFunctor<platform::CUDADeviceContext, double>;
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt
index b1f010cdff10304407fd9bf7341f6395cc140766..6d16fc8f818957fd42c3988ee7ca0190fa42e86a 100644
--- a/paddle/phi/kernels/funcs/CMakeLists.txt
+++ b/paddle/phi/kernels/funcs/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(detail)

 math_library(deformable_conv_functor DEPS dense_tensor)
 math_library(concat_and_split_functor DEPS dense_tensor)
+math_library(fc_functor DEPS blas jit_kernel_helper)
 math_library(gru_compute DEPS activation_functions math_function)
 math_library(lstm_compute DEPS activation_functions)
 math_library(math_function DEPS blas dense_tensor tensor)
diff --git a/paddle/phi/kernels/funcs/fc_functor.cc b/paddle/phi/kernels/funcs/fc_functor.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e14f8522c969a411a63c751ed2c1c0a2896ff0c6
--- /dev/null
+++ b/paddle/phi/kernels/funcs/fc_functor.cc
@@ -0,0 +1,106 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/funcs/fc_functor.h"
+
+#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename DeviceContext, typename T>
+void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
+                                             const int M,
+                                             const int N,
+                                             const int K,
+                                             const T* X,
+                                             const T* W,
+                                             T* Y,
+                                             const T* B,
+                                             bool relu,
+                                             bool padding_weights) {
+  auto blas = GetBlas<DeviceContext, T>(context);
+  paddle::framework::Tensor Y1;
+  T* Y1_data = nullptr;
+  if (padding_weights) {
+    const int NN = N + 4;
+    const int KK = K + 4;
+    paddle::framework::Tensor X1;
+    T* X1_data = X1.mutable_data<T>({M * KK}, paddle::platform::CPUPlace());
+    Y1_data = Y1.mutable_data<T>({M * (N + 4)}, paddle::platform::CPUPlace());
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+    for (int i = 0; i < M; i++) {
+      memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
+    }
+    blas.GEMM(false,
+              false,
+              M,
+              N,
+              K,
+              static_cast<T>(1.0),
+              X1_data,
+              KK,
+              W,
+              NN,
+              static_cast<T>(0.0),
+              Y1_data,
+              NN);
+  } else {
+    blas.MatMul(M, N, K, X, W, Y);
+  }
+  if (B == NULL) {
+    if (padding_weights) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+      for (int i = 0; i < M; i++) {
+        memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
+      }
+    }
+    PADDLE_ENFORCE_EQ(
+        relu,
+        false,
+        errors::PermissionDenied("When bias is NULL, relu can not be true."));
+    return;
+  }
+  auto compute = relu
+                     ? paddle::operators::jit::KernelFuncs<
+                           paddle::operators::jit::VAddReluTuple<T>,
+                           paddle::platform::CPUPlace>::Cache()
+                           .At(N)
+                     : paddle::operators::jit::KernelFuncs<
+                           paddle::operators::jit::VAddTuple<T>,
+                           paddle::platform::CPUPlace>::Cache()
+                           .At(N);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int i = 0; i < M; i++) {
+    T* dst = Y + i * N;
+    T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
+    compute(B, src, dst, N);
+  }
+}
+
+template class FCFunctor<paddle::platform::CPUDeviceContext, float>;
+template class FCFunctor<paddle::platform::CPUDeviceContext, double>;
+template class FCFunctor<CPUContext, float>;
+template class FCFunctor<CPUContext, double>;
+
+}  // namespace funcs
+}  // namespace phi
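The CPU path above is a straight move of the old fc.cc logic: with `padding_weights` set, `W` is assumed to carry 4 pad columns in each dimension, so the input is repacked to a K+4 row stride, the GEMM runs against the padded strides, and the pad columns are stripped when writing the final output. A standalone sketch of that staging, using a naive GEMM stand-in (`naive_gemm` and `padded_fc` are illustrative names, not Paddle APIs):

```cpp
#include <cstring>
#include <vector>

// Reference GEMM with explicit row strides: C(M x N) = A(M x K) * B(K x N).
void naive_gemm(int M, int N, int K, const float* A, int lda,
                const float* B, int ldb, float* C, int ldc) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[i * lda + k] * B[k * ldb + j];
      C[i * ldc + j] = acc;
    }
}

// pack -> GEMM -> unpack, mirroring the padding_weights branch above.
void padded_fc(int M, int N, int K, const float* X, const float* W_padded,
               float* Y) {
  const int KK = K + 4, NN = N + 4;  // strides implied by the padded weight
  std::vector<float> X1(M * KK, 0.f), Y1(M * NN, 0.f);
  for (int i = 0; i < M; ++i)        // pack: widen each input row to KK
    std::memcpy(&X1[i * KK], X + i * K, K * sizeof(float));
  naive_gemm(M, N, K, X1.data(), KK, W_padded, NN, Y1.data(), NN);
  for (int i = 0; i < M; ++i)        // unpack: drop the 4 pad columns per row
    std::memcpy(Y + i * N, &Y1[i * NN], N * sizeof(float));
}
```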
diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a26f0edcab2723e0d36230656d2f3519304becb7
--- /dev/null
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -0,0 +1,149 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/fc_functor.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename T>
+struct FcTypeTraits;
+
+template <>
+struct FcTypeTraits<float> {
+  typedef float4 Type;
+};
+
+template <>
+struct FcTypeTraits<double> {
+  typedef double4 Type;
+};
+
+template <typename T, bool DoRelu>
+__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < num) {
+    int bias_idx = tid % K;
+    const T bias_ptr = bias[bias_idx];
+    const T in_ptr = data[tid];
+    T packed_val;
+    packed_val.x = in_ptr.x + bias_ptr.x;
+    packed_val.y = in_ptr.y + bias_ptr.y;
+    packed_val.z = in_ptr.z + bias_ptr.z;
+    packed_val.w = in_ptr.w + bias_ptr.w;
+    if (DoRelu) {
+      packed_val.x = fmaxf(0.f, packed_val.x);
+      packed_val.y = fmaxf(0.f, packed_val.y);
+      packed_val.z = fmaxf(0.f, packed_val.z);
+      packed_val.w = fmaxf(0.f, packed_val.w);
+    }
+    data[tid] = packed_val;
+  }
+}
+
+template <typename T, bool DoRelu, int BlockDim>
+__global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
+  int offset = blockIdx.x * N;
+
+  for (int i = threadIdx.x; i < N; i += BlockDim) {
+    T temp;
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
+    temp = __ldg(data + offset + i) + __ldg(bias + i);
+#else
+    temp = data[offset + i] + bias[i];
+#endif
+    if (DoRelu) {
+      data[offset + i] = static_cast<T>(temp > 0) * temp;
+    } else {
+      data[offset + i] = temp;
+    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
+                                             const int M,
+                                             const int N,
+                                             const int K,
+                                             const T* X,
+                                             const T* W,
+                                             T* Y,
+                                             const T* B,
+                                             bool relu,
+                                             bool padding_weights) {
+  PADDLE_ENFORCE_EQ(padding_weights,
+                    false,
+                    errors::PermissionDenied(
+                        "Weight padding in fc can not be used in GPU scope."));
+  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
+  blas.GEMM(false,
+            false,
+            M,
+            N,
+            K,
+            static_cast<T>(1.0),
+            X,
+            K,
+            W,
+            N,
+            static_cast<T>(0.0),
+            Y,
+            N);
+  if (B == NULL) {
+    return;
+  }
+
+  // M * N
+  if (N % 4 == 0) {
+    const int threads = 256;
+    const int num = M * N / 4;
+    const int blocks = (num + threads - 1) / threads;
+    typedef typename FcTypeTraits<T>::Type trans_type;
+    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
+    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
+    if (relu) {
+      bias_relu_v4<trans_type, true><<<blocks, threads, 0, context.stream()>>>(
+          num, bias_ptr_v4, data_ptr_v4, N / 4);
+    } else {
+      bias_relu_v4<trans_type,
+                   false><<<blocks, threads, 0, context.stream()>>>(
+          num, bias_ptr_v4, data_ptr_v4, N / 4);
+    }
+  } else {
+    const int threads = 256;
+    const int blocks = M;
+    if (relu) {
+      InplaceAddReluKernel<T,
+                           true,
+                           threads><<<blocks, threads, 0, context.stream()>>>(
+          N, B, Y);
+    } else {
+      InplaceAddReluKernel<T,
+                           false,
+                           threads><<<blocks, threads, 0, context.stream()>>>(
+          N, B, Y);
+    }
+  }
+}
+
+template class FCFunctor<paddle::platform::CUDADeviceContext, float>;
+template class FCFunctor<paddle::platform::CUDADeviceContext, double>;
+
+template class FCFunctor<GPUContext, float>;
+template class FCFunctor<GPUContext, double>;
+
+}  // namespace funcs
+}  // namespace phi
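On the GPU side the GEMM epilogue is dispatched on the shape of `N`: when `N` is a multiple of 4, the bias/ReLU pass loads `float4`/`double4` vectors (one thread per four outputs); otherwise one thread block walks each output row in strides of the block size. The grid arithmetic, restated as a host-side sketch with illustrative sizes:

```cpp
#include <cstdio>

int main() {
  const int M = 32, N = 128, threads = 256;
  if (N % 4 == 0) {
    // Vectorized path: M*N outputs become M*N/4 vector4 elements,
    // one per thread, with a ceil-divide for the grid size.
    const int num = M * N / 4;                         // 1024 vector elements
    const int blocks = (num + threads - 1) / threads;  // -> 4 blocks
    std::printf("bias_relu_v4: <<<%d, %d>>>\n", blocks, threads);
  } else {
    // Fallback path: one block per output row; each thread strides the
    // row in steps of BlockDim (= threads).
    std::printf("InplaceAddReluKernel: <<<%d, %d>>>\n", M, threads);
  }
  return 0;
}
```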
diff --git a/paddle/fluid/operators/math/fc.h b/paddle/phi/kernels/funcs/fc_functor.h
similarity index 62%
rename from paddle/fluid/operators/math/fc.h
rename to paddle/phi/kernels/funcs/fc_functor.h
index 02f81587c739f2b47ef70a92f01d083c932deae3..3c759acb194b00d67bce524439ce6d25229ec284 100644
--- a/paddle/fluid/operators/math/fc.h
+++ b/paddle/phi/kernels/funcs/fc_functor.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,19 +17,23 @@ limitations under the License. */
 #include <memory>
 #include "paddle/fluid/platform/device_context.h"

-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {

 template <typename DeviceContext, typename T>
 class FCFunctor {
  public:
-  void operator()(const DeviceContext& context, const int M, const int N,
-                  const int K, const T* X, const T* W, T* Y,
-                  const T* B = nullptr, bool relu = false,
+  void operator()(const DeviceContext& context,
+                  const int M,
+                  const int N,
+                  const int K,
+                  const T* X,
+                  const T* W,
+                  T* Y,
+                  const T* B = nullptr,
+                  bool relu = false,
                   bool weight_pass = false);
 };

-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
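Taken together, the renamed header keeps the old contract: `B` defaults to `nullptr` (bias is skipped), `relu` without a bias is rejected on the CPU path, and `padding_weights` is rejected outright on the GPU path. A usage summary as a sketch; `ctx` and the buffers are placeholders supplied by the surrounding kernel:

```cpp
#include "paddle/phi/kernels/funcs/fc_functor.h"

// Sketch only: illustrates the three valid call shapes of the functor.
void FcUsage(const phi::CPUContext& ctx, int M, int N, int K,
             const float* X, const float* W, float* Y, const float* B) {
  phi::funcs::FCFunctor<phi::CPUContext, float> fc;
  fc(ctx, M, N, K, X, W, Y);                    // GEMM only (B defaults to nullptr)
  fc(ctx, M, N, K, X, W, Y, B);                 // GEMM + bias
  fc(ctx, M, N, K, X, W, Y, B, /*relu=*/true);  // GEMM + bias + ReLU
  // relu == true with B == nullptr -> PermissionDenied (CPU path);
  // padding_weights == true        -> PermissionDenied (GPU path).
}
```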