未验证 提交 5bfb0ecc 编写于 作者: Y yaoxuefeng 提交者: GitHub

fix datanorm op add attr bug test=develop (#25000) (#26580)

上级 31ad53f8
......@@ -4686,6 +4686,15 @@ def data_norm(input,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
}
attrs = {
"epsilon": epsilon,
"sync_stats": sync_stats,
"summary_decay_rate": summary_decay_rate,
}
if slot_dim > 0:
attrs["slot_dim"] = slot_dim
if enable_scale_and_shift:
attrs["enable_scale_and_shift"] = enable_scale_and_shift
if enable_scale_and_shift:
inputs["scale_w"] = scale_w
inputs["bias"] = bias
......@@ -4700,13 +4709,7 @@ def data_norm(input,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum
},
attrs={
"epsilon": epsilon,
"slot_dim": slot_dim,
"sync_stats": sync_stats,
"summary_decay_rate": summary_decay_rate,
"enable_scale_and_shift": enable_scale_and_shift
})
attrs=attrs)
return helper.append_activation(data_norm_out)
......
......@@ -271,7 +271,7 @@ class TestDataNormOpWithEnableScaleAndShift(OpTest):
self.use_mkldnn = False
epsilon = 0.00001
slot_dim = -1
enable_scale_and_shitf = True
enable_scale_and_shift = True
x_shape = [2, 50]
scale_shape = [50]
tp = np.float32
......@@ -319,6 +319,63 @@ class TestDataNormOpWithEnableScaleAndShift(OpTest):
self.check_grad(['X'], 'Y', no_grad_set=set([]))
class TestDataNormOpWithoutEnableScaleAndShift(OpTest):
"""
test class for data norm op
test forward and backward
"""
def setUp(self):
"""
init data norm op test env
"""
self.op_type = 'data_norm'
self.use_mkldnn = False
epsilon = 0.00001
slot_dim = -1
enable_scale_and_shift = True
x_shape = [2, 50]
scale_shape = [50]
tp = np.float32
x_val = np.random.uniform(-1, 1, x_shape).astype(tp)
batch_size = np.ones(scale_shape).astype(tp)
batch_size *= 1e4
batch_sum = np.zeros(scale_shape).astype(tp)
batch_square_sum = np.ones(scale_shape).astype(tp)
batch_square_sum *= 1e4
scale_w = np.ones(scale_shape).astype(tp)
bias = np.zeros(scale_shape).astype(tp)
y = np.array(x_val)
mean = np.zeros(x_shape).astype(tp)
scale = np.ones(x_shape).astype(tp)
self.inputs = {
"X": x_val,
"BatchSize": batch_size,
"BatchSum": batch_sum,
"BatchSquareSum": batch_square_sum,
"scale_w": scale_w,
"bias": bias
}
self.outputs = {"Y": y, "Means": mean, "Scales": scale}
self.attrs = {"epsilon": epsilon, "use_mkldnn": self.use_mkldnn}
def test_check_output(self):
"""
test check forward, check output
"""
self.check_output()
def test_check_grad(self):
"""
test check backward, check grad
"""
self.check_grad(['X'], 'Y', no_grad_set=set([]))
class TestDataNormOpWithEnableScaleAndShift_1(OpTest):
"""
test class for data norm op
......@@ -333,7 +390,7 @@ class TestDataNormOpWithEnableScaleAndShift_1(OpTest):
self.use_mkldnn = False
epsilon = 0.00001
slot_dim = 1
enable_scale_and_shitf = True
enable_scale_and_shift = True
x_shape = [2, 50]
scale_shape = [50]
tp = np.float32
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册