未验证 提交 621b6385 编写于 作者: Z Zhang Ting 提交者: GitHub

improve performance of instance_norm, test=develop (#25005)

上级 971ebb26
...@@ -311,8 +311,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T> ...@@ -311,8 +311,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
auto dy_arr = dy_e.reshape(shape); auto dy_arr = dy_e.reshape(shape);
auto x_arr = x_e.reshape(shape); auto x_arr = x_e.reshape(shape);
auto tmp = auto tmp = (x_arr - mean_arr.eval().broadcast(bcast)) *
(x_arr - mean_arr.broadcast(bcast)) * inv_var_arr.broadcast(bcast); inv_var_arr.eval().broadcast(bcast);
math::SetConstant<platform::CPUDeviceContext, T> set_constant; math::SetConstant<platform::CPUDeviceContext, T> set_constant;
// math: d_bias = np.sum(d_y, axis=(n,h,w)) // math: d_bias = np.sum(d_y, axis=(n,h,w))
...@@ -333,7 +333,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T> ...@@ -333,7 +333,8 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
(tmp * dy_arr).sum(mean_rdims).reshape(param_shape).sum(rdims); (tmp * dy_arr).sum(mean_rdims).reshape(param_shape).sum(rdims);
} }
auto dy_mean = dy_arr.mean(mean_rdims).reshape(NxC_shape).broadcast(bcast); auto dy_mean =
dy_arr.mean(mean_rdims).reshape(NxC_shape).eval().broadcast(bcast);
Eigen::DSizes<int, 2> bcast_param(N, sample_size); Eigen::DSizes<int, 2> bcast_param(N, sample_size);
set_constant(dev_ctx, d_x, static_cast<T>(0)); set_constant(dev_ctx, d_x, static_cast<T>(0));
...@@ -351,6 +352,7 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T> ...@@ -351,6 +352,7 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
(dy_arr * tmp) (dy_arr * tmp)
.mean(mean_rdims) .mean(mean_rdims)
.reshape(NxC_shape) .reshape(NxC_shape)
.eval()
.broadcast(bcast)); .broadcast(bcast));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册