From 621b63855038db3e290866e5f152b79967dafbd6 Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Wed, 10 Jun 2020 20:45:07 +0800 Subject: [PATCH] improve performance of instance_norm, test=develop (#25005) --- paddle/fluid/operators/instance_norm_op.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index a915c018ab9..d2b59a239a2 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -311,8 +311,8 @@ class InstanceNormGradKernel auto dy_arr = dy_e.reshape(shape); auto x_arr = x_e.reshape(shape); - auto tmp = - (x_arr - mean_arr.broadcast(bcast)) * inv_var_arr.broadcast(bcast); + auto tmp = (x_arr - mean_arr.eval().broadcast(bcast)) * + inv_var_arr.eval().broadcast(bcast); math::SetConstant set_constant; // math: d_bias = np.sum(d_y, axis=(n,h,w)) @@ -333,7 +333,8 @@ class InstanceNormGradKernel (tmp * dy_arr).sum(mean_rdims).reshape(param_shape).sum(rdims); } - auto dy_mean = dy_arr.mean(mean_rdims).reshape(NxC_shape).broadcast(bcast); + auto dy_mean = + dy_arr.mean(mean_rdims).reshape(NxC_shape).eval().broadcast(bcast); Eigen::DSizes bcast_param(N, sample_size); set_constant(dev_ctx, d_x, static_cast(0)); @@ -351,6 +352,7 @@ class InstanceNormGradKernel (dy_arr * tmp) .mean(mean_rdims) .reshape(NxC_shape) + .eval() .broadcast(bcast)); } }; -- GitLab