未验证 提交 f62a9291 编写于 作者: C ceci3 提交者: GitHub

fix instance norm (#21042)

* fix instance norm

* update unitest,test=develop
上级 7041eb21
...@@ -328,7 +328,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T> ...@@ -328,7 +328,7 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
epsilon, saved_mean_data, saved_var_data)); epsilon, saved_mean_data, saved_var_data));
} else { } else {
if (d_x) { if (d_x) {
GradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>( GradComputeDX<T, block><<<NxC, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), scale->data<BatchNormParamType<T>>(), d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D, saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
d_x->data<T>()); d_x->data<T>());
......
...@@ -79,7 +79,7 @@ class TestInstanceNormOpTraining(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestInstanceNormOpTraining(unittest.TestCase):
self.init_test_case() self.init_test_case()
def init_test_case(self): def init_test_case(self):
self.use_global_stats = False self.shape = [2, 3, 4, 5]
self.no_grad_set = set() self.no_grad_set = set()
self.fetch_list = [ self.fetch_list = [
'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD', 'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD',
...@@ -181,12 +181,19 @@ class TestInstanceNormOpTraining(unittest.TestCase): ...@@ -181,12 +181,19 @@ class TestInstanceNormOpTraining(unittest.TestCase):
"instance_norm"): "instance_norm"):
places.append(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
for place in places: for place in places:
test_with_place(place, [2, 3, 4, 5]) test_with_place(place, self.shape)
class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining): class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining):
def init_test_case(self): def init_test_case(self):
self.use_global_stats = False self.shape = [2, 3, 4, 5]
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
class TestInstanceNormOpTrainingCase2(TestInstanceNormOpTraining):
def init_test_case(self):
self.shape = [20, 50, 4, 5]
self.no_grad_set = set(['scale@GRAD', 'bias@GRAD']) self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD'] self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册