提交 6d73091e 编写于 作者: W wangruting

modify rules

上级 82713eb6
...@@ -33,8 +33,8 @@ def softmax_composite(x, axis): ...@@ -33,8 +33,8 @@ def softmax_composite(x, axis):
max_temp = max(x, axis, keepdim=True) max_temp = max(x, axis, keepdim=True)
max_temp.stop_gradient = True max_temp.stop_gradient = True
molecular = exp(x - max_temp) molecular = exp(x - max_temp)
sqrt_var = sum(molecular, axis=axis, keepdim=True) denominator = sum(molecular, axis=axis, keepdim=True)
res = divide(molecular, sqrt_var) res = divide(molecular, denominator)
return res return res
...@@ -105,7 +105,7 @@ def composite_batchnorm( ...@@ -105,7 +105,7 @@ def composite_batchnorm(
@REGISTER_COMPOSITE('layer_norm') @REGISTER_COMPOSITE('layer_norm')
def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis): def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
axis = np.arange(begin_norm_axis, len(x.shape)) axis = tuple(range(begin_norm_axis, len(x.shape)))
mean_ = mean(x, axis=axis, keepdim=True) mean_ = mean(x, axis=axis, keepdim=True)
difference = x - mean_ difference = x - mean_
var_tmp1 = pow(difference, 2.0) var_tmp1 = pow(difference, 2.0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册