Unverified commit 25b4ba7f, authored by Sonder, committed by GitHub

Move fused feedforward (#53166)

* move fused_feedforward Compute function to phi

* add register info

* remove maxfunctor

* move fused feedforward to phi

* remove sig file

* remove fluid include

* add include

* add include

* add sig file

* add output register info

* fix sig file

* Update fused_feedforward_sig.cc

* fix grad kernel

* update output register info

* fix

* open fused_feedforward static build

* add optional and fix code style

* fix output info for fused attention

* add optional param

* merge
Parent 18e9dcdc
@@ -813,6 +813,8 @@ PD_REGISTER_KERNEL(fused_attention,
                    phi::dtype::float16,
                    double,
                    float) {
+  kernel->OutputAt(9).SetDataType(phi::DataType::UINT8);
+  kernel->OutputAt(14).SetDataType(phi::DataType::UINT8);
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
     kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
     kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
...
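For context, the trailing block of PD_REGISTER_KERNEL lets a kernel pin the data type of individual outputs when it does not follow the kernel's registered dtype. The two added lines pin outputs 9 and 14 of fused_attention (the dropout-mask outputs, which are uint8) so the static-build pass allocates them with the right type. A minimal sketch of the same pattern, with a hypothetical kernel name used only for illustration:

PD_REGISTER_KERNEL(my_fused_op,                  // hypothetical name, not part of this PR
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::MyFusedOpKernel,  // hypothetical kernel function
                   float,
                   phi::dtype::float16) {
  // Pin outputs whose dtype never follows the kernel dtype,
  // mirroring the fused_attention change above.
  kernel->OutputAt(2).SetDataType(phi::DataType::UINT8);  // e.g. a dropout mask
}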
@@ -25,11 +25,6 @@ struct MulGradFunctor {
   inline HOSTDEVICE T Dy(T x, T y) { return x; }
 };
 
-template <typename T>
-struct MaxFunctor {
-  inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? b : a; }
-};
-
 template <typename T>
 struct AddGradFunctor {
   inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
...
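The deleted MaxFunctor computed a plain elementwise maximum; per the commit notes ("remove maxfunctor") it is presumably no longer referenced after the move. Its behavior is equivalent to this host-side sketch (the original was marked HOSTDEVICE so it could also run in device code):

#include <algorithm>

// Same result as the removed functor: returns the larger of a and b.
template <typename T>
struct MaxFunctorSketch {
  T operator()(T a, T b) const { return std::max(a, b); }
};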
@@ -24,13 +24,13 @@ void FusedFeedForwardGradKernel(
     const DenseTensor& out_grad,
     const DenseTensor& x,
     const DenseTensor& linear1_weight,
-    const DenseTensor& linear1_bias,
+    const paddle::optional<DenseTensor>& linear1_bias,
     const DenseTensor& linear2_weight,
     const DenseTensor& dropout1_mask,
     const DenseTensor& dropout2_mask,
     const DenseTensor& linear1_out,
     const DenseTensor& dropout1_out,
-    const DenseTensor& dropout2_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
     const paddle::optional<DenseTensor>& ln1_scale,
     const paddle::optional<DenseTensor>& ln1_bias,
     const paddle::optional<DenseTensor>& ln1_out,
...
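This signature change makes linear1_bias and dropout2_out optional inputs (matching the "add optional param" note in the commit message): callers may omit them, so the kernel receives a paddle::optional<DenseTensor> and must check for presence before use. A minimal sketch of how such an input is typically consumed, using get_ptr() as the GPU kernel hunk further below does for dropout2_out:

// Sketch only; the variable name follows the signature above.
const phi::DenseTensor* bias_ptr = linear1_bias.get_ptr();  // nullptr when no bias was provided
if (bias_ptr != nullptr) {
  // add *bias_ptr to the linear1 output
}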
@@ -30,6 +30,8 @@
 #include "paddle/phi/kernels/fusion/gpu/fused_residual_dropout_bias.h"
 #include "paddle/phi/kernels/layer_norm_kernel.h"
 
+PHI_DECLARE_bool(use_fast_math);
+
 namespace phi {
 namespace fusion {
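PHI_DECLARE_bool makes the use_fast_math flag visible in this translation unit; the flag itself is assumed to be defined elsewhere with the matching define macro. A hedged sketch of how such a flag is typically read in kernel code (the accessor name FLAGS_use_fast_math is the usual convention, not verified from this diff):

PHI_DECLARE_bool(use_fast_math);

void DispatchActivation() {  // hypothetical helper, for illustration only
  if (FLAGS_use_fast_math) {
    // fast approximate path (e.g. a tanh-based GELU)
  } else {
    // reference path
  }
}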
@@ -292,21 +294,22 @@ class FusedDropoutHelper {
                        T* d_bias,
                        const std::string& act_method) {
     if (act_method == "gelu") {
-      phi::funcs::GeluGradFunctor<T> gelu_grad;
-      phi::fusion::
-          LaunchDropoutActBiasGrad<T, MaskType, phi::funcs::GeluGradFunctor<T>>(
-              gelu_grad,
-              dout,
-              mask,
-              src,
-              bias,
-              dropout_param_.dropout_prob,
-              dropout_param_.is_upscale_in_train,
-              rows_,
-              cols_,
-              d_src,
-              d_bias,
-              ctx);
+      phi::fusion::GeluGradFunctor<T> gelu_grad;
+      phi::fusion::LaunchDropoutActBiasGrad<T,
+                                            MaskType,
+                                            phi::fusion::GeluGradFunctor<T>>(
+          gelu_grad,
+          dout,
+          mask,
+          src,
+          bias,
+          dropout_param_.dropout_prob,
+          dropout_param_.is_upscale_in_train,
+          rows_,
+          cols_,
+          d_src,
+          d_bias,
+          ctx);
     } else if (act_method == "relu") {
       phi::funcs::ReluGradFunctor<T> relu_grad;
       phi::fusion::
...
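The only functional change in this hunk is that GeluGradFunctor is now taken from phi::fusion instead of phi::funcs; the rest is a clang-format reflow of the template and call arguments. As background on what such a functor computes, here is the standard tanh-approximation GELU gradient; this is a generic sketch, and the actual phi::fusion::GeluGradFunctor may use the erf-based form or a fast-math variant instead:

#include <cmath>

// d/dx of gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template <typename T>
T GeluGradReference(T x) {
  const T kAlpha = static_cast<T>(0.7978845608028654);  // sqrt(2 / pi)
  const T kBeta = static_cast<T>(0.044715);
  const T u = kAlpha * (x + kBeta * x * x * x);
  const T tanh_u = std::tanh(u);
  const T du_dx = kAlpha * (static_cast<T>(1) + static_cast<T>(3) * kBeta * x * x);
  return static_cast<T>(0.5) * (static_cast<T>(1) + tanh_u) +
         static_cast<T>(0.5) * x * (static_cast<T>(1) - tanh_u * tanh_u) * du_dx;
}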
@@ -366,13 +366,13 @@ void FusedFeedForwardGradKernel(
     const DenseTensor& out_grad,
     const DenseTensor& x,
     const DenseTensor& linear1_weight,
-    const DenseTensor& linear1_bias,
+    const paddle::optional<DenseTensor>& linear1_bias,
     const DenseTensor& linear2_weight,
     const DenseTensor& dropout1_mask,
     const DenseTensor& dropout2_mask,
     const DenseTensor& linear1_out,
     const DenseTensor& dropout1_out,
-    const DenseTensor& dropout2_out,
+    const paddle::optional<DenseTensor>& dropout2_out,
     const paddle::optional<DenseTensor>& ln1_scale,
     const paddle::optional<DenseTensor>& ln1_bias,
     const paddle::optional<DenseTensor>& ln1_out,
@@ -417,7 +417,7 @@ void FusedFeedForwardGradKernel(
   auto* ln1_out_ptr = pre_layer_norm ? ln1_out.get_ptr() : nullptr;
   auto* dropout1_out_ptr = &dropout1_out;
-  auto* dropout2_out_ptr = &dropout2_out;
+  auto* dropout2_out_ptr = dropout2_out.get_ptr();
   auto* linear1_weight_ptr = &linear1_weight;
   auto* linear2_weight_ptr = &linear2_weight;
...
@@ -1162,6 +1162,8 @@ set(STATIC_BUILD_TESTS
     test_fetch_lod_tensor_array
     test_fused_attention_op
     test_fused_attention_op_api
+    test_fused_feedforward_op
+    test_fused_feedforward_pass
     test_imperative_optimizer
     test_lamb_op
     test_layer_norm_op
@@ -1191,6 +1193,8 @@ set(STATIC_BUILD_TESTS
 if(NOT WITH_GPU)
   list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)
   list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op_api)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_op)
+  list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_feedforward_op_pass)
 endif()
 
 foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
...