# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import unittest import numpy as np from op_test import OpTest def l2_norm(x, axis, epsilon): x2 = x**2 s = np.sum(x2, axis=axis, keepdims=True) r = np.sqrt(s + epsilon) y = x / np.broadcast_to(r, x.shape) return y, r class TestNormOp(OpTest): def setUp(self): self.op_type = "norm" self.init_test_case() x = np.random.random(self.shape).astype("float64") y, norm = l2_norm(x, self.axis, self.epsilon) self.inputs = {'X': x} self.attrs = {'epsilon': self.epsilon, 'axis': self.axis} self.outputs = {'Out': y, 'Norm': norm} def test_check_output(self): self.check_output() def test_check_grad(self): self.check_grad(['X'], 'Out') def init_test_case(self): self.shape = [2, 3, 4, 5] self.axis = 1 self.epsilon = 1e-8 class TestNormOp2(TestNormOp): def init_test_case(self): self.shape = [5, 3, 9, 7] self.axis = 0 self.epsilon = 1e-8 class TestNormOp3(TestNormOp): def init_test_case(self): self.shape = [5, 3, 2, 7] self.axis = -1 self.epsilon = 1e-8 class TestNormOp4(TestNormOp): def init_test_case(self): self.shape = [128, 1024, 14, 14] self.axis = 2 self.epsilon = 1e-8 def test_check_grad(self): # since the gradient check is very slow in large shape, so skip check_grad pass class TestNormOp5(TestNormOp): def init_test_case(self): self.shape = [2048, 2048] self.axis = 1 self.epsilon = 1e-8 def test_check_grad(self): # since the gradient check is very slow in large shape, so skip check_grad pass if __name__ == '__main__': unittest.main()