未验证 提交 2a7bc64c 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #6455 from chengduoZH/refine/conv_zero

Refine conv
...@@ -261,8 +261,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -261,8 +261,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, input_grad, static_cast<T>(0));
// if is_expand is false, the operation of set_zero is unnecessary,
// because math::matmul will reset input_grad.
if (is_expand) {
set_zero(dev_ctx, input_grad, static_cast<T>(0));
}
math::Col2VolFunctor<DeviceContext, T> col2vol; math::Col2VolFunctor<DeviceContext, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
......
...@@ -225,7 +225,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -225,7 +225,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, input_grad, static_cast<T>(0));
} }
if (filter_grad) { // filter size (m, c, k_h, k_w) if (filter_grad) { // filter size (m, c, k_h, k_w)
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册