未验证 提交 95c343d3 编写于 作者: H huangxu96 提交者: GitHub

Fix a bug which might cause an OOM problem (#40226)

* Add wait after Copy
* Fix a wrongly placed delete
上级 60b86b2f
...@@ -95,6 +95,7 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx, ...@@ -95,6 +95,7 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
norm, norm,
sizeof(T), sizeof(T),
dev_ctx.stream()); dev_ctx.stream());
dev_ctx.Wait();
auto eps = static_cast<T>(1e-5); auto eps = static_cast<T>(1e-5);
*norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps; *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
...@@ -102,6 +103,7 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx, ...@@ -102,6 +103,7 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
std::vector<DenseTensor *> div_outs = {in_grad}; std::vector<DenseTensor *> div_outs = {in_grad};
auto div_functor = DivFunctor<T>(*norm_cpu_ptr); auto div_functor = DivFunctor<T>(*norm_cpu_ptr);
phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor); phi::funcs::ElementwiseKernel<T>(dev_ctx, div_ins, &div_outs, div_functor);
delete norm_tensor; delete norm_tensor;
} }
delete counts_tensor; delete counts_tensor;
......
...@@ -95,6 +95,7 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx, ...@@ -95,6 +95,7 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
norm, norm,
sizeof(T), sizeof(T),
dev_ctx.stream()); dev_ctx.stream());
dev_ctx.Wait();
auto eps = static_cast<T>(1e-5); auto eps = static_cast<T>(1e-5);
*norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps; *norm_cpu_ptr = *norm_cpu_ptr > eps ? *norm_cpu_ptr : eps;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册