未验证 提交 f62a9291 编写于 作者: C ceci3 提交者: GitHub

fix instance norm (#21042)

* fix instance norm

* update unitest,test=develop
上级 7041eb21
......@@ -328,7 +328,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
epsilon, saved_mean_data, saved_var_data));
} else {
if (d_x) {
GradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
d_x->data<T>());
......
......@@ -79,7 +79,7 @@ class TestInstanceNormOpTraining(unittest.TestCase):
self.init_test_case()
def init_test_case(self):
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
self.no_grad_set = set()
self.fetch_list = [
'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD',
......@@ -181,12 +181,19 @@ class TestInstanceNormOpTraining(unittest.TestCase):
"instance_norm"):
places.append(core.CUDAPlace(0))
for place in places:
test_with_place(place, [2, 3, 4, 5])
test_with_place(place, self.shape)
class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining):
def init_test_case(self):
self.use_global_stats = False
self.shape = [2, 3, 4, 5]
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
class TestInstanceNormOpTrainingCase2(TestInstanceNormOpTraining):
def init_test_case(self):
self.shape = [20, 50, 4, 5]
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册