You need to sign in or sign up before continuing.
未验证 提交 16999ae4 编写于 作者: Z Zhang Ting 提交者: GitHub

use IndexList to improve performance of instance_norm op (#25132)

* use IndexList to improve performance, test=develop

* remove EIGEN_HAS_INDEX_LIST, test=develop

* use IndexList only when EIGEN_HAS_INDEX_LIST is true
上级 36bb056e
......@@ -181,10 +181,22 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto *place = dev_ctx.eigen_device();
Eigen::DSizes<int, 2> shape(NxC, sample_size);
// Once eigen on Windows is updated, the if branch can be removed.
#ifndef EIGEN_HAS_INDEX_LIST
Eigen::DSizes<int, 2> bcast(1, sample_size);
Eigen::DSizes<int, 2> C_shape(C, 1);
Eigen::DSizes<int, 2> NxC_shape(NxC, 1);
Eigen::DSizes<int, 2> shape(NxC, sample_size);
Eigen::DSizes<int, 1> rdims(1);
#else
Eigen::IndexList<Eigen::type2index<1>, int> bcast;
bcast.set(1, sample_size);
Eigen::IndexList<int, Eigen::type2index<1>> C_shape;
C_shape.set(0, C);
Eigen::IndexList<int, Eigen::type2index<1>> NxC_shape;
NxC_shape.set(0, NxC);
Eigen::IndexList<Eigen::type2index<1>> rdims;
#endif
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
......@@ -201,8 +213,6 @@ class InstanceNormKernel<platform::CPUDeviceContext, T>
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto x_arr = x_e.reshape(shape);
Eigen::DSizes<int, 1> rdims(1);
saved_mean_e.device(*place) = x_arr.mean(rdims);
auto saved_variance_arr =
(x_arr - saved_mean_e.broadcast(bcast)).square().mean(rdims) + epsilon;
......@@ -316,14 +326,25 @@ class InstanceNormGradKernel<platform::CPUDeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto *place = dev_ctx.eigen_device();
Eigen::DSizes<int, 2> rshape(NxC, sample_size);
Eigen::DSizes<int, 2> param_shape(N, C);
Eigen::DSizes<int, 2> shape(NxC, sample_size);
#ifndef EIGEN_HAS_INDEX_LIST
Eigen::DSizes<int, 1> rdims(0);
Eigen::DSizes<int, 1> mean_rdims(1);
Eigen::DSizes<int, 2> rshape(NxC, sample_size);
Eigen::DSizes<int, 2> bcast(1, sample_size);
Eigen::DSizes<int, 2> C_shape(C, 1);
Eigen::DSizes<int, 2> NxC_shape(NxC, 1);
Eigen::DSizes<int, 2> param_shape(N, C);
Eigen::DSizes<int, 2> shape(NxC, sample_size);
#else
Eigen::IndexList<Eigen::type2index<0>> rdims;
Eigen::IndexList<Eigen::type2index<1>> mean_rdims;
Eigen::IndexList<Eigen::type2index<1>, int> bcast;
bcast.set(1, sample_size);
Eigen::IndexList<int, Eigen::type2index<1>> C_shape;
C_shape.set(0, C);
Eigen::IndexList<int, Eigen::type2index<1>> NxC_shape;
NxC_shape.set(0, NxC);
#endif
math::SetConstant<platform::CPUDeviceContext, T> set_constant;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册