提交 4ce39796 编写于 作者: C chengduoZH

fix unit test and c++ code

上级 ae0ea541
......@@ -233,39 +233,37 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product = [](T ele) { return ele * ele; };
auto neg_inv_std = [](T ele) { return -std::sqrt(1 / ele); };
auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto scale_func = [scale_data](T ele) { return ele * scale_data; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
auto inv_std_scale_func = [scale_data](T ele) {
return std::sqrt(1 / ele) * scale_data;
};
auto neg_inv_std_scale_func = [scale_data](T ele) {
return -std::sqrt(1 / ele) * scale_data;
};
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_scale_func)
.replicate(1, right)
.cwiseProduct(d_y_map);
// dy_dmean_dx
auto dmean_end = var_map.unaryExpr(neg_inv_std_scale_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dx_mean = (T(1.0) / right) * dmean_end.replicate(1, right);
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_scale_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx
auto dvar_end_0 = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(neg_inv_std)
.unaryExpr(triple_product)
.cwiseProduct(dvar_end_0);
auto dx_var = (T(1.0) / right) *
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var = (T(-1.0) / right) *
(x_map - mean_map.replicate(1, right))
.cwiseProduct(dvar_end.replicate(1, right));
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
.cwiseProduct(dvar_end)
.unaryExpr(scale_func);
d_x_map = dx_end + dx_mean + dx_var;
}
......
......@@ -52,18 +52,19 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, epsilon):
D = reduce(mul, x_shape, 1) / N
grad_y.shape = [N, D]
x.shape = [N, D]
grad_offset = np.sum(grad_y)
mean.shape = [N, 1]
var.shape = [N, 1]
grad_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y)
d_scale = np.sum(grad_y).reshape([1, ])
d_bias = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y).reshape([1, ])
dx_end = np.sqrt(1.0 / var) * grad_y
d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y, axis=1).reshape([N, 1])
d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
[N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
d_mean = 1.0 / D * (d_mean_0 + d_mean_1)
# d_mean_1 = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape(
# [N, 1]) * (-1.0 / D * np.sqrt(1.0 / var) *
# np.sum(x - mean, axis=1).reshape([N, 1])).reshape([N, 1])
d_mean = 1.0 / D * (d_mean_0)
d_std = np.sum(-1.0 / var * (x - mean) * grad_y, axis=1).reshape([N, 1]) * (
1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
......@@ -73,7 +74,7 @@ def _reference_layer_norm_grad(x, grad_y, scale, mean, var, epsilon):
grad_y.shape = x_shape
x.shape = x_shape
return grad_x, grad_scale, grad_offset
return grad_x, d_bias, d_scale
def create_or_get_tensor(scope, var_name, var, place):
......@@ -144,7 +145,7 @@ class TestLayerNormdOp(OpTest):
epsilon = 0.00001
x_shape = shape
scale_shape = [1]
np.random.random(123)
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
......@@ -154,7 +155,6 @@ class TestLayerNormdOp(OpTest):
x_val, scale_val, bias_val, epsilon)
# for gradient test
# y_grad = np.ones(x_shape).astype(np.float32) * 0.00277778
y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_layer_norm_grad(
......@@ -229,7 +229,6 @@ class TestLayerNormdOp(OpTest):
for place in places:
test_with_place(place, [2, 3, 4, 5])
test_with_place(place, [2, 3])
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册