未验证 提交 21aa3adc 编写于 作者: 王明冬 提交者: GitHub

move fc_functor from fluid to phi.test=develop (#41856)

上级 f3753b7f
...@@ -166,7 +166,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project ...@@ -166,7 +166,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling executor device_memory_aligment generator) sequence_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc_functor matrix_inverse matrix_solve)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h" #include "paddle/fluid/operators/attention_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -377,7 +377,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> { ...@@ -377,7 +377,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1 // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data, fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data,
atten_b_data); atten_b_data);
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/fc.h" #include "paddle/phi/kernels/funcs/fc_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -80,7 +80,7 @@ class FCOpKernel : public framework::OpKernel<T> { ...@@ -80,7 +80,7 @@ class FCOpKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, M, w_dims1, w_dims0, input_data, w_data, output_data, fc(dev_ctx, M, w_dims1, w_dims0, input_data, w_data, output_data,
bias ? bias->data<T>() : NULL, with_relu, padding_weights); bias ? bias->data<T>() : NULL, with_relu, padding_weights);
} }
......
...@@ -18,8 +18,8 @@ limitations under the License. */ ...@@ -18,8 +18,8 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
#include "paddle/phi/kernels/funcs/sequence2batch.h" #include "paddle/phi/kernels/funcs/sequence2batch.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -298,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -298,7 +298,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx); auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data, fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
bias ? bias->data<T>() : nullptr); bias ? bias->data<T>() : nullptr);
...@@ -370,7 +370,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -370,7 +370,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx); auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
if (M > D3) { if (M > D3) {
fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data, fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
bias ? bias->data<T>() : nullptr); bias ? bias->data<T>() : nullptr);
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_lstm_op.h" #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
#include "paddle/phi/kernels/funcs/sequence2batch.h" #include "paddle/phi/kernels/funcs/sequence2batch.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -346,7 +346,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -346,7 +346,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx); auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>()); fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
int xx_offset = D4; int xx_offset = D4;
...@@ -424,7 +424,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -424,7 +424,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; phi::funcs::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx); auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
if (M > D4) { if (M > D4) {
fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>()); fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
to_batch(dev_ctx, *xx, batched_input, true, is_reverse); to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
......
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h" #include "paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.h"
#include <algorithm> // for min, max #include <algorithm> // for min, max
#include <string> #include <string>
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -244,7 +244,7 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> { ...@@ -244,7 +244,7 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
} }
} }
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data, fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data,
b_data, true); b_data, true);
} }
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h" #include "paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/cpu_vec.h" #include "paddle/phi/kernels/funcs/cpu_vec.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -212,7 +212,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { ...@@ -212,7 +212,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx); auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
math::FCFunctor<DeviceContext, T> fc; phi::funcs::FCFunctor<DeviceContext, T> fc;
fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data, fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data,
b ? b->data<T>() : NULL); b ? b->data<T>() : NULL);
w_data = w_data + M0 * D; w_data = w_data + M0 * D;
......
...@@ -18,8 +18,8 @@ limitations under the License. */ ...@@ -18,8 +18,8 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
#include "paddle/phi/kernels/funcs/sequence2batch.h" #include "paddle/phi/kernels/funcs/sequence2batch.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
......
...@@ -36,7 +36,6 @@ if (WITH_ASCEND_CL) ...@@ -36,7 +36,6 @@ if (WITH_ASCEND_CL)
else() else()
math_library(beam_search DEPS math_function) math_library(beam_search DEPS math_function)
endif() endif()
math_library(fc DEPS blas jit_kernel_helper)
math_library(matrix_bit_code) math_library(matrix_bit_code)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class FCFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = nullptr, bool relu = false,
bool padding_weights = false) {
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
framework::Tensor Y1;
T* Y1_data = nullptr;
if (padding_weights) {
const int NN = N + 4;
const int KK = K + 4;
framework::Tensor X1;
T* X1_data = X1.mutable_data<T>({M * KK}, platform::CPUPlace());
Y1_data = Y1.mutable_data<T>({M * (N + 4)}, platform::CPUPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
}
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X1_data, KK, W, NN,
static_cast<T>(0.0), Y1_data, NN);
} else {
blas.MatMul(M, N, K, X, W, Y);
}
if (B == NULL) {
if (padding_weights) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
}
}
PADDLE_ENFORCE_EQ(relu, false,
platform::errors::PermissionDenied(
"When bias is NULL, relu can not be true."));
return;
}
auto compute =
relu
? jit::KernelFuncs<jit::VAddReluTuple<T>,
platform::CPUPlace>::Cache()
.At(N)
: jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache()
.At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
}
};
template class FCFunctor<platform::CPUDeviceContext, float>;
template class FCFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -5,6 +5,7 @@ add_subdirectory(detail) ...@@ -5,6 +5,7 @@ add_subdirectory(detail)
math_library(deformable_conv_functor DEPS dense_tensor) math_library(deformable_conv_functor DEPS dense_tensor)
math_library(concat_and_split_functor DEPS dense_tensor) math_library(concat_and_split_functor DEPS dense_tensor)
math_library(fc_functor DEPS blas jit_kernel_helper)
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions) math_library(lstm_compute DEPS activation_functions)
math_library(math_function DEPS blas dense_tensor tensor) math_library(math_function DEPS blas dense_tensor tensor)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/fc_functor.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
namespace funcs {
template <typename DeviceContext, typename T>
void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
const int M,
const int N,
const int K,
const T* X,
const T* W,
T* Y,
const T* B,
bool relu,
bool padding_weights) {
auto blas = GetBlas<DeviceContext, T>(context);
paddle::framework::Tensor Y1;
T* Y1_data = nullptr;
if (padding_weights) {
const int NN = N + 4;
const int KK = K + 4;
paddle::framework::Tensor X1;
T* X1_data = X1.mutable_data<T>({M * KK}, paddle::platform::CPUPlace());
Y1_data = Y1.mutable_data<T>({M * (N + 4)}, paddle::platform::CPUPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(X1_data + i * KK, X + i * K, K * sizeof(T));
}
blas.GEMM(false,
false,
M,
N,
K,
static_cast<T>(1.0),
X1_data,
KK,
W,
NN,
static_cast<T>(0.0),
Y1_data,
NN);
} else {
blas.MatMul(M, N, K, X, W, Y);
}
if (B == NULL) {
if (padding_weights) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
}
}
PADDLE_ENFORCE_EQ(
relu,
false,
errors::PermissionDenied("When bias is NULL, relu can not be true."));
return;
}
auto compute = relu
? paddle::operators::jit::KernelFuncs<
paddle::operators::jit::VAddReluTuple<T>,
paddle::platform::CPUPlace>::Cache()
.At(N)
: paddle::operators::jit::KernelFuncs<
paddle::operators::jit::VAddTuple<T>,
paddle::platform::CPUPlace>::Cache()
.At(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst;
compute(B, src, dst, N);
}
}
template class FCFunctor<paddle::platform::CPUDeviceContext, float>;
template class FCFunctor<paddle::platform::CPUDeviceContext, double>;
template class FCFunctor<CPUContext, float>;
template class FCFunctor<CPUContext, double>;
} // namespace funcs
} // namespace phi
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <typename T> template <typename T>
struct FcTypeTraits; struct FcTypeTraits;
...@@ -74,20 +74,35 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) { ...@@ -74,20 +74,35 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
} }
} }
template <typename T> template <typename DeviceContext, typename T>
class FCFunctor<platform::CUDADeviceContext, T> { void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
public: const int M,
void operator()(const platform::CUDADeviceContext& context, const int M, const int N,
const int N, const int K, const T* X, const T* W, T* Y, const int K,
const T* B = nullptr, bool relu = false, const T* X,
bool padding_weights = false) { const T* W,
PADDLE_ENFORCE_EQ( T* Y,
padding_weights, false, const T* B,
platform::errors::PermissionDenied( bool relu,
bool padding_weights) {
PADDLE_ENFORCE_EQ(padding_weights,
false,
errors::PermissionDenied(
"Weight padding in fc can not be used in GPU scope.")); "Weight padding in fc can not be used in GPU scope."));
auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context); auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N, blas.GEMM(false,
static_cast<T>(0.0), Y, N); false,
M,
N,
K,
static_cast<T>(1.0),
X,
K,
W,
N,
static_cast<T>(0.0),
Y,
N);
if (B == NULL) { if (B == NULL) {
return; return;
} }
...@@ -101,33 +116,34 @@ class FCFunctor<platform::CUDADeviceContext, T> { ...@@ -101,33 +116,34 @@ class FCFunctor<platform::CUDADeviceContext, T> {
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B); auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y); auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) { if (relu) {
bias_relu_v4<trans_type, bias_relu_v4<trans_type, true><<<blocks, threads, 0, context.stream()>>>(
true><<<blocks, threads, 0, context.stream()>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4); num, bias_ptr_v4, data_ptr_v4, N / 4);
} else { } else {
bias_relu_v4<trans_type, bias_relu_v4<trans_type, false><<<blocks, threads, 0, context.stream()>>>(
false><<<blocks, threads, 0, context.stream()>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4); num, bias_ptr_v4, data_ptr_v4, N / 4);
} }
} else { } else {
const int threads = 256; const int threads = 256;
const int blocks = M; const int blocks = M;
if (relu) { if (relu) {
InplaceAddReluKernel<T, true, InplaceAddReluKernel<T,
true,
threads><<<blocks, threads, 0, context.stream()>>>( threads><<<blocks, threads, 0, context.stream()>>>(
N, B, Y); N, B, Y);
} else { } else {
InplaceAddReluKernel<T, false, InplaceAddReluKernel<T,
false,
threads><<<blocks, threads, 0, context.stream()>>>( threads><<<blocks, threads, 0, context.stream()>>>(
N, B, Y); N, B, Y);
} }
} }
} }
};
template class FCFunctor<paddle::platform::CUDADeviceContext, float>;
template class FCFunctor<paddle::platform::CUDADeviceContext, double>;
template class FCFunctor<platform::CUDADeviceContext, float>; template class FCFunctor<GPUContext, float>;
template class FCFunctor<platform::CUDADeviceContext, double>; template class FCFunctor<GPUContext, double>;
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -17,19 +17,23 @@ limitations under the License. */ ...@@ -17,19 +17,23 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FCFunctor { class FCFunctor {
public: public:
void operator()(const DeviceContext& context, const int M, const int N, void operator()(const DeviceContext& context,
const int K, const T* X, const T* W, T* Y, const int M,
const T* B = nullptr, bool relu = false, const int N,
const int K,
const T* X,
const T* W,
T* Y,
const T* B = nullptr,
bool relu = false,
bool weight_pass = false); bool weight_pass = false);
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册