Unverified commit 77da9106, authored by Yuang Liu and committed by GitHub

fused linear grad add bug fix and perf optim (#56094)

* skip CopyOrAdd when tmp grad is None (#55679)

* Optim fused linear grad add (#55927)
Parent 9b317b2d
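The first change (#55679) makes the gradient accumulation node tolerate an absent temporary gradient: CopyOrAddTensor now runs only when the incoming grad_out tensor is defined and initialized, as the first hunk below shows. What follows is a minimal, self-contained sketch of that guard with simplified stand-in types and helpers (Tensor, CopyOrAdd, and Accumulate are illustrative names here, not the Paddle API):

// Minimal sketch (assumed names, not the Paddle implementation): accumulation
// only copies/adds the incoming gradient when it actually holds valid data.
#include <iostream>

struct Tensor {                 // simplified stand-in for a framework tensor
  bool defined = false;
  bool initialized = false;
  float value = 0.0f;
};

// Hypothetical stand-in for CopyOrAddTensor: copy on first write, add afterwards.
void CopyOrAdd(Tensor* dst, const Tensor& src, bool dst_is_empty) {
  if (dst_is_empty) {
    *dst = src;
  } else {
    dst->value += src.value;
  }
}

void Accumulate(Tensor* stored_grad, const Tensor& grad_out, bool* is_fake_empty) {
  // The fix: an undefined/uninitialized grad_out (a "None" temporary grad)
  // is skipped instead of being pushed through the copy-or-add path.
  if (grad_out.defined && grad_out.initialized) {
    CopyOrAdd(stored_grad, grad_out, *is_fake_empty);
  }
  *is_fake_empty = false;
}

int main() {
  Tensor stored, g1{true, true, 2.0f}, none;  // "none" models an absent temp grad
  bool empty = true;
  Accumulate(&stored, g1, &empty);    // copies 2.0
  Accumulate(&stored, none, &empty);  // skipped; stored stays 2.0
  Accumulate(&stored, g1, &empty);    // adds; stored becomes 4.0
  std::cout << stored.value << std::endl;  // prints 4
  return 0;
}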
@@ -124,7 +124,10 @@ GradNodeAccumulation::operator()(
   if (!weak_grad_.expired() && !is_new_grad) {
     auto grad = weak_grad_.lock();
-    CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
+    if (grad_out.defined() && grad_out.initialized()) {
+      CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
+    }
+    // else { do nothing since there is no valid value in grad out tensor }
     is_fake_empty_ = false;
   }
...
@@ -45,7 +45,7 @@
   support_dygraph_mode : true

 - op : fused_linear_param_grad_add
-  args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
+  args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true, bool has_bias = true)
   output : Tensor(dweight_out), Tensor(dbias_out)
   infer_meta:
     func : FusedLinearParamGradAddInferMeta
...
@@ -1259,6 +1259,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                       const MetaTensor& dweight,
                                       const MetaTensor& dbias,
                                       bool multi_precision,
+                                      bool has_bias,
                                       MetaTensor* dweight_out,
                                       MetaTensor* dbias_out) {
   const auto dtype = dout.dtype();
@@ -1302,7 +1303,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                    ? DataType::FLOAT32
                    : dtype;
-  if (dbias_out) {
+  if (has_bias && dbias_out) {
     dbias_out->set_dims({weight_dims[1]});
     dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
   }
...
@@ -265,6 +265,7 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
                                       const MetaTensor& dweight,
                                       const MetaTensor& dbias,
                                       bool multi_precision,
+                                      bool has_bias,
                                       MetaTensor* dweight_out,
                                       MetaTensor* dbias_out);
...
@@ -40,6 +40,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
                                  int64_t K,
                                  int64_t N,
                                  bool use_addto,
+                                 bool has_bias,
                                  DenseTensor *dweight_out,
                                  DenseTensor *dbias_out) {
   constexpr bool kIsMultiPrecision = !std::is_same<T, MT>::value;
@@ -65,7 +66,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
                  use_addto);
   }

-  if (dbias_out == nullptr) return;
+  if (!has_bias) return;

   if (!fuse_bias_grad) {
     auto dout_copy = dout;
@@ -126,6 +127,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
                              const paddle::optional<DenseTensor> &dweight,
                              const paddle::optional<DenseTensor> &dbias,
                              bool multi_precision,
+                             bool has_bias,
                              DenseTensor *dweight_out,
                              DenseTensor *dbias_out) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -159,7 +161,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
     multi_precision = false;
   }

-  if (dbias_out) {
+  if (has_bias && dbias_out) {
     ctx.template Alloc<T>(dbias_out);
   }
@@ -176,6 +178,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
     PrintMeta<kLogLevel>(dweight_out, "dweight_out");
     PrintMeta<kLogLevel>(dbias_out, "dbias_out");
     VLOG(kLogLevel) << "multi_precision = " << multi_precision;
+    VLOG(kLogLevel) << "has_bias = " << has_bias;
     VLOG(kLogLevel) << "use_addto = " << use_addto;
     VLOG(kLogLevel) << "M = " << M;
     VLOG(kLogLevel) << "N = " << N;
@@ -183,11 +186,29 @@ void FusedLinearParamGradAdd(const Context &ctx,
   }

   if (multi_precision) {
-    FusedLinearParamGradAddImpl<T, MT, Context>(
-        ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
+    FusedLinearParamGradAddImpl<T, MT, Context>(ctx,
+                                                x,
+                                                dout,
+                                                dbias,
+                                                M,
+                                                K,
+                                                N,
+                                                use_addto,
+                                                has_bias,
+                                                dweight_out,
+                                                dbias_out);
   } else {
-    FusedLinearParamGradAddImpl<T, T, Context>(
-        ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
+    FusedLinearParamGradAddImpl<T, T, Context>(ctx,
+                                               x,
+                                               dout,
+                                               dbias,
+                                               M,
+                                               K,
+                                               N,
+                                               use_addto,
+                                               has_bias,
+                                               dweight_out,
+                                               dbias_out);
   }
 }

@@ -199,6 +220,7 @@ void FusedLinearParamGradAdd(const Context &ctx,
                              const paddle::optional<DenseTensor> &dweight,
                              const paddle::optional<DenseTensor> &dbias,
                              bool multi_precision,
+                             bool has_bias,
                              DenseTensor *dweight_out,
                              DenseTensor *dbias_out) {
   PADDLE_THROW(phi::errors::Unimplemented(
...
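The hunks above thread a new has_bias flag through the op definition, the infer-meta functions, and the kernel, so skipping the bias gradient becomes an explicit caller request rather than something inferred from whether dbias_out happens to be non-null; a fused linear layer without a bias can then avoid the bias reduction entirely. Below is a rough control-flow sketch of that gating with placeholder types and no real math (ParamGradAdd, MatmulGrad, and ReduceForBias are illustrative names, not Paddle's):

// Rough sketch of the new gating (placeholder types and helpers, not Paddle's):
// the weight gradient is always produced; the bias gradient is produced only
// when the caller asks for it via has_bias.
#include <cstdio>
#include <vector>

using Tensor = std::vector<float>;

// Hypothetical helpers standing in for the GEMM and the column reduction.
void MatmulGrad(const Tensor& x, const Tensor& dout, Tensor* dweight, bool use_addto) {
  // dweight (+)= x^T * dout  (omitted in this sketch)
}
void ReduceForBias(const Tensor& dout, Tensor* dbias) {
  // dbias = column-wise sum of dout  (omitted in this sketch)
}

void ParamGradAdd(const Tensor& x, const Tensor& dout, bool use_addto,
                  bool has_bias, Tensor* dweight_out, Tensor* dbias_out) {
  MatmulGrad(x, dout, dweight_out, use_addto);
  // Before the change this early-out keyed on dbias_out being null;
  // now the caller states explicitly whether the bias work is wanted.
  if (!has_bias) return;
  ReduceForBias(dout, dbias_out);
}

int main() {
  Tensor x(8), dout(8), dweight(4), dbias(2);
  ParamGradAdd(x, dout, /*use_addto=*/true, /*has_bias=*/false, &dweight, &dbias);
  std::puts("bias-grad reduction skipped when has_bias is false");
  return 0;
}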