未验证 提交 621b6385 编写于 作者: Z Zhang Ting 提交者: GitHub

improve performance of instance_norm, test=develop (#25005)

上级 971ebb26
......@@ -311,8 +311,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
auto dy_arr = dy_e.reshape(shape);
auto x_arr = x_e.reshape(shape);
auto tmp =
(x_arr - mean_arr.broadcast(bcast)) * inv_var_arr.broadcast(bcast);
auto tmp = (x_arr - mean_arr.eval().broadcast(bcast)) *
inv_var_arr.eval().broadcast(bcast);
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
// math: d_bias = np.sum(d_y, axis=(n,h,w))
......@@ -333,7 +333,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
(tmp * dy_arr).sum(mean_rdims).reshape(param_shape).sum(rdims);
}
auto dy_mean = dy_arr.mean(mean_rdims).reshape(NxC_shape).broadcast(bcast);
auto dy_mean =
dy_arr.mean(mean_rdims).reshape(NxC_shape).eval().broadcast(bcast);
Eigen::DSizes<int, 2> bcast_param(N, sample_size);
set_constant(dev_ctx, d_x, static_cast<T>(0));
......@@ -351,6 +352,7 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
(dy_arr * tmp)
.mean(mean_rdims)
.reshape(NxC_shape)
.eval()
.broadcast(bcast));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册