Unverified commit 25b4ba7f, authored by Sonder, committed by GitHub

Move fused feedforward (#53166)

* transfer fused_feedforward Compute function to phi

* add register info

* remove maxfunctor

* move fused_feedforward to phi

* remove sig file

* remove fluid include

* add include

* add include

* add sig file

* add output register info

* fix sig file

* Update fused_feedforward_sig.cc

* fix grad kernel

* update output register info

* fix

* open fused_feedforward static build

* add optional and fix code style

* fix output info for fused attention

* add optional param

* merge
Parent 18e9dcdc
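Most of the squashed commits above are phi-migration mechanics: moving the Compute functions, registering the kernels, and adding an argument-mapping "sig" file (fused_feedforward_sig.cc). A sig file translates the legacy fluid operator's named inputs, attributes, and outputs into a phi KernelSignature. Below is a minimal sketch of that pattern; the argument lists are heavily abbreviated for illustration, and the real file enumerates the op's full input/attribute/output sets.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Sketch only: the actual mapping lists every input ("X", "Linear1Weight",
// "Linear1Bias", ...), attribute, and output of fused_feedforward.
KernelSignature FusedFeedForwardOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fused_feedforward",
                         {"X", "Linear1Weight", "Linear1Bias"},
                         {"pre_layer_norm"},
                         {"Out", "Dropout1Mask", "Dropout2Mask"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(fused_feedforward,
                           phi::FusedFeedForwardOpArgumentMapping);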
@@ -813,6 +813,8 @@ PD_REGISTER_KERNEL(fused_attention,
                    phi::dtype::float16,
                    double,
                    float) {
+  kernel->OutputAt(9).SetDataType(phi::DataType::UINT8);
+  kernel->OutputAt(14).SetDataType(phi::DataType::UINT8);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
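The two added lines pin fused_attention outputs 9 and 14 (zero-indexed; the dropout mask outputs, which hold byte masks) to UINT8, so the static-build pass can infer output dtypes from the registration instead of assuming every output shares the kernel's compute dtype; the existing FLOAT16 branch likewise keeps layer-norm statistics in FLOAT32. A plausible sketch of the matching fused_feedforward registration follows; the output indices here are illustrative, not taken from the diff.

// Illustrative sketch: the same pattern pins the dropout mask outputs
// of fused_feedforward to UINT8 (indices invented for the example).
PD_REGISTER_KERNEL(fused_feedforward,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedFeedForwardKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);  // Dropout1Mask
  kernel->OutputAt(2).SetDataType(phi::DataType::UINT8);  // Dropout2Mask
}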
@@ -25,11 +25,6 @@ struct MulGradFunctor {
   inline HOSTDEVICE T Dy(T x, T y) { return x; }
 };
 
-template <typename T>
-struct MaxFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? b : a; }
-};
-
 template <typename T>
 struct AddGradFunctor {
   inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
@@ -24,13 +24,13 @@ void FusedFeedForwardGradKernel(
     const DenseTensor& out_grad,
     const DenseTensor& x,
     const DenseTensor& linear1_weight,
-    const DenseTensor& linear1_bias,
+    const paddle::optional<DenseTensor>& linear1_bias,
     const DenseTensor& linear2_weight,
     const DenseTensor& dropout1_mask,
     const DenseTensor& dropout2_mask,
     const DenseTensor& linear1_out,
     const DenseTensor& dropout1_out,
-    const DenseTensor& dropout2_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
     const paddle::optional<DenseTensor>& ln1_scale,
     const paddle::optional<DenseTensor>& ln1_bias,
     const paddle::optional<DenseTensor>& ln1_out,
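linear1_bias and dropout2_out become paddle::optional<DenseTensor> because callers may legitimately omit them, matching the "add optional param" commits above. Consumers then fetch a raw pointer that is null when the input was not fed, exactly as the kernel body further below does for dropout2_out via get_ptr(). A minimal sketch of that consumption pattern; the function name and branch bodies are illustrative:

// Hypothetical helper, for illustration only.
void UseOptionalBias(const paddle::optional<DenseTensor>& linear1_bias) {
  const DenseTensor* bias_ptr = linear1_bias.get_ptr();  // nullptr if absent
  if (bias_ptr != nullptr) {
    // bias provided: take the bias-fused path
  } else {
    // bias omitted: plain matmul path
  }
}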
@@ -30,6 +30,8 @@
 #include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
 #include "paddle/phi/kernels/layer_norm_kernel.h"
 
+PHI_DECLARE_bool(use_fast_math);
+
 namespace phi {
 namespace fusion {
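PHI_DECLARE_bool makes the gflags-backed FLAGS_use_fast_math symbol visible in this translation unit, so the fused kernels can branch on it at runtime. A sketch of the typical use; the branch bodies are illustrative:

// Illustrative only: pick a faster approximate code path when the
// runtime flag is set, otherwise keep the precise one.
if (FLAGS_use_fast_math) {
  // fast approximate path (e.g. tanh-based GELU)
} else {
  // precise path (erf-based GELU)
}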
@@ -292,21 +294,22 @@ class FusedDropoutHelper {
                       T* d_bias,
                       const std::string& act_method) {
     if (act_method == "gelu") {
-      phi::funcs::GeluGradFunctor<T> gelu_grad;
-      phi::fusion::
-          LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::GeluGradFunctor<T>>(
-              gelu_grad,
-              dout,
-              mask,
-              src,
-              bias,
-              dropout_param_.dropout_prob,
-              dropout_param_.is_upscale_in_train,
-              rows_,
-              cols_,
-              d_src,
-              d_bias,
-              ctx);
+      phi::fusion::GeluGradFunctor<T> gelu_grad;
+      phi::fusion::LaunchDropoutActBiasGrad<T,
+                                            MaskType,
+                                            phi::fusion::GeluGradFunctor<T>>(
+          gelu_grad,
+          dout,
+          mask,
+          src,
+          bias,
+          dropout_param_.dropout_prob,
+          dropout_param_.is_upscale_in_train,
+          rows_,
+          cols_,
+          d_src,
+          d_bias,
+          ctx);
     } else if (act_method == "relu") {
       phi::funcs::ReluGradFunctor<T> relu_grad;
       phi::fusion::
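The substantive change in this hunk is the functor's namespace: GeluGradFunctor now lives in phi::fusion next to the fused kernels rather than in phi::funcs, and the launch call is merely reindented. For reference, the erf-based GELU gradient such a functor computes is d/dx[x * Phi(x)] = Phi(x) + x * phi(x), with Phi/phi the standard normal CDF/PDF. A hedged sketch; the struct name and interface are illustrative, not the actual phi::fusion type:

template <typename T>
struct GeluGradSketch {
  // Returns dout * (Phi(x) + x * phi(x)).
  inline HOSTDEVICE T operator()(T x, T dout) const {
    const float xf = static_cast<float>(x);
    const float cdf = 0.5f * (1.0f + erff(xf * 0.70710678f));  // Phi(x)
    const float pdf = 0.39894228f * expf(-0.5f * xf * xf);     // phi(x)
    return static_cast<T>(static_cast<float>(dout) * (cdf + xf * pdf));
  }
};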
@@ -366,13 +366,13 @@ void FusedFeedForwardGradKernel(
     const DenseTensor& out_grad,
     const DenseTensor& x,
     const DenseTensor& linear1_weight,
-    const DenseTensor& linear1_bias,
+    const paddle::optional<DenseTensor>& linear1_bias,
     const DenseTensor& linear2_weight,
     const DenseTensor& dropout1_mask,
     const DenseTensor& dropout2_mask,
     const DenseTensor& linear1_out,
     const DenseTensor& dropout1_out,
-    const DenseTensor& dropout2_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
     const paddle::optional<DenseTensor>& ln1_scale,
     const paddle::optional<DenseTensor>& ln1_bias,
     const paddle::optional<DenseTensor>& ln1_out,
@@ -417,7 +417,7 @@ void FusedFeedForwardGradKernel(
   auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
   auto* dropout1_out_ptr = &dropout1_out;
-  auto* dropout2_out_ptr = &dropout2_out;
+  auto* dropout2_out_ptr = dropout2_out.get_ptr();
   auto* linear1_weight_ptr = &linear1_weight;
   auto* linear2_weight_ptr = &linear2_weight;
@@ -1162,6 +1162,8 @@ set(STATIC_BUILD_TESTS
     test_fetch_lod_tensor_array
     test_fused_attention_op
     test_fused_attention_op_api
+    test_fused_feedforward_op
+    test_fused_feedforward_pass
     test_imperative_optimizer
     test_lamb_op
     test_layer_norm_op
@@ -1191,6 +1193,8 @@ set(STATIC_BUILD_TESTS
 if(NOT WITH_GPU)
   list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
   list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_op)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_pass)
 endif()
 
 foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})