Unverified commit 91873469, authored by Yuang Liu, committed by GitHub

Optim fused linear grad add (#55927)

Parent 230c6ce1
@@ -107,7 +107,7 @@
   support_dygraph_mode : true
 - op : fused_linear_param_grad_add
-  args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
+  args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true, bool has_bias = true)
   output : Tensor(dweight_out), Tensor(dbias_out)
   infer_meta:
     func : FusedLinearParamGradAddInferMeta
......
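For orientation, here is a toy CPU sketch (not the Paddle kernel; the function name and driver are illustrative) of what this op computes and what the new has_bias flag gates: the weight gradient dweight += x^T * dout is always accumulated, while the bias-gradient pass (dbias = column-wise sum of dout) can now be skipped when the linear layer has no bias.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy reference for fused_linear_param_grad_add semantics (assumption-level
// sketch, not Paddle code). Shapes as in the diff: x is [M, K], dout is
// [M, N], dweight accumulates [K, N], dbias is [N] (weight_dims[1]).
void fused_linear_param_grad_add_ref(const std::vector<float>& x,
                                     const std::vector<float>& dout,
                                     int64_t M, int64_t K, int64_t N,
                                     bool has_bias,
                                     std::vector<float>* dweight,
                                     std::vector<float>* dbias) {
  // dweight += x^T * dout (the "grad add" part: accumulate, do not overwrite).
  for (int64_t k = 0; k < K; ++k) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t m = 0; m < M; ++m) {
        acc += x[m * K + k] * dout[m * N + n];
      }
      (*dweight)[k * N + n] += acc;
    }
  }
  // Mirrors the kernel's early exit: with has_bias == false the whole
  // bias-gradient pass is skipped.
  if (!has_bias) return;
  // dbias = sum of dout over the M (batch) dimension.
  for (int64_t n = 0; n < N; ++n) {
    float acc = 0.0f;
    for (int64_t m = 0; m < M; ++m) acc += dout[m * N + n];
    (*dbias)[n] = acc;
  }
}

int main() {
  const int64_t M = 2, K = 3, N = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6};   // [2, 3]
  std::vector<float> dout = {1, 0, 0, 1};      // [2, 2]
  std::vector<float> dweight(K * N, 0.0f), dbias(N, 0.0f);
  fused_linear_param_grad_add_ref(x, dout, M, K, N, /*has_bias=*/true,
                                  &dweight, &dbias);
  std::printf("dbias = [%g, %g]\n", dbias[0], dbias[1]);  // -> [1, 1]
  return 0;
}
```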
@@ -1470,6 +1470,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                       const MetaTensor& dweight,
                                       const MetaTensor& dbias,
                                       bool multi_precision,
+                                      bool has_bias,
                                       MetaTensor* dweight_out,
                                       MetaTensor* dbias_out) {
   const auto dtype = dout.dtype();
@@ -1513,7 +1514,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
           ? DataType::FLOAT32
           : dtype;
-  if (dbias_out) {
+  if (has_bias && dbias_out) {
     dbias_out->set_dims({weight_dims[1]});
     dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
   }
......
@@ -299,6 +299,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                       const MetaTensor& dweight,
                                       const MetaTensor& dbias,
                                       bool multi_precision,
+                                      bool has_bias,
                                       MetaTensor* dweight_out,
                                       MetaTensor* dbias_out);
......
@@ -40,6 +40,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
                                  int64_t K,
                                  int64_t N,
                                  bool use_addto,
+                                 bool has_bias,
                                  DenseTensor *dweight_out,
                                  DenseTensor *dbias_out) {
   constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value;
@@ -65,7 +66,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
         use_addto);
   }
-  if (dbias_out == nullptr) return;
+  if (!has_bias) return;
   if (!fuse_bias_grad) {
     auto dout_copy = dout;
@@ -126,6 +127,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
                              const paddle::optional<DenseTensor> &dweight,
                              const paddle::optional<DenseTensor> &dbias,
                              bool multi_precision,
+                             bool has_bias,
                              DenseTensor *dweight_out,
                              DenseTensor *dbias_out) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -159,7 +161,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
     multi_precision = false;
   }
-  if (dbias_out) {
+  if (has_bias && dbias_out) {
     ctx.template Alloc<T>(dbias_out);
   }
@@ -176,6 +178,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
     PrintMeta<kLogLevel>(dweight_out, "dweight_out");
     PrintMeta<kLogLevel>(dbias_out, "dbias_out");
     VLOG(kLogLevel) << "multi_precision = " << multi_precision;
+    VLOG(kLogLevel) << "has_bias = " << has_bias;
     VLOG(kLogLevel) << "use_addto = " << use_addto;
     VLOG(kLogLevel) << "M = " << M;
     VLOG(kLogLevel) << "N = " << N;
@@ -183,11 +186,29 @@ void FusedLinearParamGradAdd(const Context &ctx,
   }
   if (multi_precision) {
-    FusedLinearParamGradAddImpl<T, MT, Context>(
-        ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
+    FusedLinearParamGradAddImpl<T, MT, Context>(ctx,
+                                                x,
+                                                dout,
+                                                dbias,
+                                                M,
+                                                K,
+                                                N,
+                                                use_addto,
+                                                has_bias,
+                                                dweight_out,
+                                                dbias_out);
   } else {
-    FusedLinearParamGradAddImpl<T, T, Context>(
-        ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
+    FusedLinearParamGradAddImpl<T, T, Context>(ctx,
+                                               x,
+                                               dout,
+                                               dbias,
+                                               M,
+                                               K,
+                                               N,
+                                               use_addto,
+                                               has_bias,
+                                               dweight_out,
+                                               dbias_out);
   }
 }
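The dispatch above picks a master-precision type MT at compile time (via phi::dtype::MPTypeTrait) and routes to either the <T, MT> or the <T, T> instantiation depending on multi_precision, now also forwarding has_bias. Below is a minimal, self-contained sketch of that dispatch pattern, with float standing in for the low-precision gradient type and double for the fp32-like master type (illustrative types only, not the Paddle kernel).

```cpp
#include <cstdio>
#include <vector>

// Sketch of the two-instantiation dispatch used above: one templated impl,
// instantiated either with a separate master type MT or with MT == T,
// selected by a runtime flag.
template <typename T, typename MT>
void GradAddImpl(const std::vector<T>& grad, std::vector<MT>* accum) {
  // Accumulate incoming gradients into the (possibly wider) buffer,
  // analogous to the use_addto accumulation in the fused kernel.
  for (size_t i = 0; i < grad.size(); ++i) {
    (*accum)[i] += static_cast<MT>(grad[i]);
  }
}

void GradAdd(const std::vector<float>& grad, bool multi_precision,
             std::vector<float>* accum_lp, std::vector<double>* accum_mp) {
  if (multi_precision) {
    GradAddImpl<float, double>(grad, accum_mp);  // MT != T: master-precision path
  } else {
    GradAddImpl<float, float>(grad, accum_lp);   // MT == T: same-precision path
  }
}

int main() {
  std::vector<float> grad = {0.5f, 1.5f};
  std::vector<float> lp(2, 0.0f);
  std::vector<double> mp(2, 0.0);
  GradAdd(grad, /*multi_precision=*/true, &lp, &mp);
  std::printf("master-precision accum: %g %g\n", mp[0], mp[1]);
  return 0;
}
```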
@@ -199,6 +220,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
                              const paddle::optional<DenseTensor> &dweight,
                              const paddle::optional<DenseTensor> &dbias,
                              bool multi_precision,
+                             bool has_bias,
                              DenseTensor *dweight_out,
                              DenseTensor *dbias_out) {
   PADDLE_THROW(phi::errors::Unimplemented(
......