未验证 提交 91873469 编写于 作者: Y Yuang Liu 提交者: GitHub

Optim fused linear grad add (#55927)

上级 230c6ce1
...@@ -107,7 +107,7 @@ ...@@ -107,7 +107,7 @@
support_dygraph_mode : true support_dygraph_mode : true
- op : fused_linear_param_grad_add - op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true) args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true, bool has_bias = true)
output : Tensor(dweight_out), Tensor(dbias_out) output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta: infer_meta:
func : FusedLinearParamGradAddInferMeta func : FusedLinearParamGradAddInferMeta
......
...@@ -1470,6 +1470,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x, ...@@ -1470,6 +1470,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dweight, const MetaTensor& dweight,
const MetaTensor& dbias, const MetaTensor& dbias,
bool multi_precision, bool multi_precision,
bool has_bias,
MetaTensor* dweight_out, MetaTensor* dweight_out,
MetaTensor* dbias_out) { MetaTensor* dbias_out) {
const auto dtype = dout.dtype(); const auto dtype = dout.dtype();
...@@ -1513,7 +1514,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x, ...@@ -1513,7 +1514,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
? DataType::FLOAT32 ? DataType::FLOAT32
: dtype; : dtype;
if (dbias_out) { if (has_bias && dbias_out) {
dbias_out->set_dims({weight_dims[1]}); dbias_out->set_dims({weight_dims[1]});
dbias_out->set_dtype(multi_precision ? mp_dtype : dtype); dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
} }
......
...@@ -299,6 +299,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x, ...@@ -299,6 +299,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dweight, const MetaTensor& dweight,
const MetaTensor& dbias, const MetaTensor& dbias,
bool multi_precision, bool multi_precision,
bool has_bias,
MetaTensor* dweight_out, MetaTensor* dweight_out,
MetaTensor* dbias_out); MetaTensor* dbias_out);
......
...@@ -40,6 +40,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx, ...@@ -40,6 +40,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
int64_t K, int64_t K,
int64_t N, int64_t N,
bool use_addto, bool use_addto,
bool has_bias,
DenseTensor *dweight_out, DenseTensor *dweight_out,
DenseTensor *dbias_out) { DenseTensor *dbias_out) {
constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value; constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value;
...@@ -65,7 +66,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx, ...@@ -65,7 +66,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
use_addto); use_addto);
} }
if (dbias_out == nullptr) return; if (!has_bias) return;
if (!fuse_bias_grad) { if (!fuse_bias_grad) {
auto dout_copy = dout; auto dout_copy = dout;
...@@ -126,6 +127,7 @@ void FusedLinearParamGradAdd(const Context &ctx, ...@@ -126,6 +127,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
const paddle::optional<DenseTensor> &dweight, const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias, const paddle::optional<DenseTensor> &dbias,
bool multi_precision, bool multi_precision,
bool has_bias,
DenseTensor *dweight_out, DenseTensor *dweight_out,
DenseTensor *dbias_out) { DenseTensor *dbias_out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type; using MT = typename phi::dtype::MPTypeTrait<T>::Type;
...@@ -159,7 +161,7 @@ void FusedLinearParamGradAdd(const Context &ctx, ...@@ -159,7 +161,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
multi_precision = false; multi_precision = false;
} }
if (dbias_out) { if (has_bias && dbias_out) {
ctx.template Alloc<T>(dbias_out); ctx.template Alloc<T>(dbias_out);
} }
...@@ -176,6 +178,7 @@ void FusedLinearParamGradAdd(const Context &ctx, ...@@ -176,6 +178,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
PrintMeta<kLogLevel>(dweight_out, "dweight_out"); PrintMeta<kLogLevel>(dweight_out, "dweight_out");
PrintMeta<kLogLevel>(dbias_out, "dbias_out"); PrintMeta<kLogLevel>(dbias_out, "dbias_out");
VLOG(kLogLevel) << "multi_precision = " << multi_precision; VLOG(kLogLevel) << "multi_precision = " << multi_precision;
VLOG(kLogLevel) << "has_bias = " << has_bias;
VLOG(kLogLevel) << "use_addto = " << use_addto; VLOG(kLogLevel) << "use_addto = " << use_addto;
VLOG(kLogLevel) << "M = " << M; VLOG(kLogLevel) << "M = " << M;
VLOG(kLogLevel) << "N = " << N; VLOG(kLogLevel) << "N = " << N;
...@@ -183,11 +186,29 @@ void FusedLinearParamGradAdd(const Context &ctx, ...@@ -183,11 +186,29 @@ void FusedLinearParamGradAdd(const Context &ctx,
} }
if (multi_precision) { if (multi_precision) {
FusedLinearParamGradAddImpl<T, MT, Context>( FusedLinearParamGradAddImpl<T, MT, Context>(ctx,
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out); x,
dout,
dbias,
M,
K,
N,
use_addto,
has_bias,
dweight_out,
dbias_out);
} else { } else {
FusedLinearParamGradAddImpl<T, T, Context>( FusedLinearParamGradAddImpl<T, T, Context>(ctx,
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out); x,
dout,
dbias,
M,
K,
N,
use_addto,
has_bias,
dweight_out,
dbias_out);
} }
} }
...@@ -199,6 +220,7 @@ void FusedLinearParamGradAdd(const Context &ctx, ...@@ -199,6 +220,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
const paddle::optional<DenseTensor> &dweight, const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias, const paddle::optional<DenseTensor> &dbias,
bool multi_precision, bool multi_precision,
bool has_bias,
DenseTensor *dweight_out, DenseTensor *dweight_out,
DenseTensor *dbias_out) { DenseTensor *dbias_out) {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册