提交 f1c973b0 编写于 作者: Q Qiao Longfei

Adam op should not create a temporary variable in Compute

上级 dc8eca82
...@@ -423,6 +423,7 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -423,6 +423,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
} }
} }
framework::SelectedRows cpu_grad_merge;
const framework::SelectedRows* grad_merge_ptr; const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) { if (is_strict_sorted) {
grad_merge_ptr = &grad; grad_merge_ptr = &grad;
...@@ -430,12 +431,16 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -430,12 +431,16 @@ class AdamOpKernel : public framework::OpKernel<T> {
// merge duplicated rows if any. // merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor // The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func; scatter::MergeAdd<DeviceContext, T> merge_func;
if (platform::is_cpu_place(ctx.GetPlace())) {
grad_merge_ptr = &cpu_grad_merge;
} else {
// FIXME(qiao): the GPU path also needs to be fixed
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope()) auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var() .Var()
->GetMutable<framework::SelectedRows>(); ->GetMutable<framework::SelectedRows>();
}
merge_func(ctx.template device_context<DeviceContext>(), grad, merge_func(ctx.template device_context<DeviceContext>(), grad,
grad_merge_var, true); grad_merge_ptr, true);
grad_merge_ptr = grad_merge_var;
} }
auto& grad_merge = *grad_merge_ptr; auto& grad_merge = *grad_merge_ptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册