提交 0f47703d 编写于 作者: C chengduoZH

add begin_norm_axis

上级 4ce39796
......@@ -42,10 +42,17 @@ class LayerNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], 1);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], 1);
auto x_dim = ctx->GetInputDim("X");
auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
"'begin_norm_axis' must be less than the rank of X");
auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->SetOutputDim("Mean", {ctx->GetInputDim("X")[0]});
ctx->SetOutputDim("Variance", {ctx->GetInputDim("X")[0]});
ctx->SetOutputDim("Mean", {left});
ctx->SetOutputDim("Variance", {left});
ctx->ShareLoD("X", "Y");
}
......@@ -72,10 +79,14 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001.");
});
AddAttr<std::vector<int>>("axis",
"(vector<int> default:{1, 1, 1}), the "
"axis to normalize.")
.SetDefault({1, 2, 3}); // todo(zcd) : who to set axis
AddAttr<int>("begin_norm_axis",
"(int default:1), the "
"axis of `begin_norm_axis ... Rank(X) - 1` will be normalized")
.SetDefault(1)
.AddCustomChecker([](const int &begin_norm_axis) {
PADDLE_ENFORCE_GT(begin_norm_axis, 0,
"'begin_norm_axis' should be greater than zero.");
});
AddComment(R"DOC(
Layer Normalization.
......@@ -97,9 +108,7 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
const int N = x_dims[0];
const int sample_size = x->numel() / N;
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto scale_data = scale->data<T>()[0];
auto bias_data = bias->data<T>()[0];
......@@ -111,7 +120,9 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
int left = N, right = sample_size;
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
......@@ -131,7 +142,8 @@ class LayerNormKernel<platform::CPUDeviceContext, T>
return std::sqrt(1 / ele) * scale_data;
};
auto sub_bias = [bias_data](T ele) { return bias_data - ele; };
// TODO(zcd): Some thinking about output_map, is it appropriate that
// `output_map` and `input_map` point to the same memory.
output_map = (var_map.unaryExpr(scale_inv_std).replicate(1, right))
.cwiseProduct(input_map) +
var_map.unaryExpr(scale_inv_std)
......@@ -198,13 +210,14 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
const auto *var = ctx.Input<Tensor>("Variance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto scale_data = scale->data<T>()[0];
const auto &x_dims = x->dims();
const int N = x_dims[0];
const int sample_size = x->numel() / N;
int left = N, right = sample_size;
auto scale_data = scale->data<T>()[0];
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]),
right = static_cast<int>(matrix_dim[1]);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -223,11 +236,13 @@ class LayerNormGradKernel<platform::CPUDeviceContext, T>
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
auto inv_std = [](T ele) { return std::sqrt(1 / ele); };
// There are two equation to compute d_scale. One uses "Y" and the other
// does not use "Y"
d_scale->data<T>()[0] =
((x_map - mean_map.replicate(1, right))
.cwiseProduct(var_map.unaryExpr(inv_std).replicate(1, right))
.cwiseProduct(d_y_map))
.sum(); // also can use `y` to get d_scale_map
.sum();
}
if (d_x) {
......
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
......@@ -33,23 +32,24 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op
def _reference_layer_norm_naive(x, scale, beta, epsilon):
def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
old_shape = x.shape
N = x.shape[0]
D = reduce(mul, old_shape, 1) / N
N = reduce(mul, old_shape[0:begin_norm_axis], 1)
D = reduce(mul, old_shape[begin_norm_axis:len(old_shape)], 1)
x.shape = [N, D]
mean = np.mean(x, axis=1)
var = np.var(x, axis=1) + epsilon
output = scale * np.divide((x - mean.reshape([N, 1])),
(np.sqrt(var)).reshape([N, 1])) + beta
output.shape = old_shape
x.shape = old_shape
return output, mean, var
def _reference_layer_norm_grad(x, grad_y, scale, mean, var, epsilon):
def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
x_shape = x.shape
N = x_shape[0]
D = reduce(mul, x_shape, 1) / N
N = reduce(mul, x_shape[0:begin_norm_axis], 1)
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
grad_y.shape = [N, D]
x.shape = [N, D]
mean.shape = [N, 1]
......@@ -140,7 +140,9 @@ class TestLayerNormdOp(OpTest):
self.assertLessEqual(max_diff, max_relative_error, err_msg())
def test_forward_backward(self):
def test_with_place(place, shape):
def test_with_place(place, shape, begin_norm_axis=1):
assert begin_norm_axis > 0 and begin_norm_axis < len(
shape), 'begin_norm_axis must be between 0 and len(shape)-1.'
# attr
epsilon = 0.00001
x_shape = shape
......@@ -152,13 +154,13 @@ class TestLayerNormdOp(OpTest):
# run forward
y_out, saved_mean, var_ref = _reference_layer_norm_naive(
x_val, scale_val, bias_val, epsilon)
x_val, scale_val, bias_val, epsilon, begin_norm_axis)
# for gradient test
y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_layer_norm_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon)
x_val, y_grad, scale_val, saved_mean, var_ref, begin_norm_axis)
scope = core.Scope()
......@@ -185,7 +187,8 @@ class TestLayerNormdOp(OpTest):
Mean="Mean",
Variance="Variance",
# attrs
epsilon=epsilon)
epsilon=epsilon,
begin_norm_axis=begin_norm_axis)
layer_norm_op.run(scope, place)
......@@ -228,7 +231,8 @@ class TestLayerNormdOp(OpTest):
places.append(core.CUDAPlace(0))
for place in places:
test_with_place(place, [2, 3, 4, 5])
test_with_place(place, [2, 3, 4, 5], begin_norm_axis=1)
test_with_place(place, [2, 3, 4, 5], begin_norm_axis=3)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册