group_norm.cpp 2.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

TEST_F(NAIVE, GROUPNORM_FORWARD) {
    Checker<GroupNorm> checker(handle(), true);

    GroupNorm::Param param;
    param.affine = true;
    param.group = 3;
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {2, 3, 2, 1}, dtype::Float32(),
                            {3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587,
                             0.0711, -0.1169, 0.2509, -0.2393, 0.0876}),  // input
                    TensorValue({3}, dtype::Float32(), {1., 1., 1.}),     // hx
                    TensorValue({3}, dtype::Float32(), {0., 0., 0.}),     // cx
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {2, 3, 2, 1}, dtype::Float32(),
                            {1., -1., -1., 1., -1., 1., -1., 1., -0.9999, 0.9999,
                             -0.9998, 0.9998}),  // output
                    TensorValue(
                            {2, 3}, dtype::Float32(),
                            {1.7135, -0.1645, -0.0107, -0.9938, 0.067,
                             -0.0758}),  // mean
                    TensorValue(
                            {2, 3}, dtype::Float32(),
                            {2.5742, 0.1772, 1.6358, 1.1340, 0.0338, 0.0267}),  // var
            });
    checker.set_param(param).exect(
            Testcase{
                    TensorValue(
                            {1, 3, 1, 2}, dtype::Float32(),
                            {-2.4348, -1.7948, 0.5223, 0.0932, -0.2955,
                             -0.0492}),                                // input
                    TensorValue({3}, dtype::Float32(), {1., 1., 1.}),  // hx
                    TensorValue({3}, dtype::Float32(), {0., 0., 0.}),  // cx
                    {},
                    {},
                    {}},
            Testcase{
                    {},
                    {},
                    {},
                    TensorValue(
                            {1, 3, 1, 2}, dtype::Float32(),
                            {-0.9999, 0.9999, 0.9999, -0.9999, -0.9997,
                             0.9997}),  // output
                    TensorValue(
                            {1, 3}, dtype::Float32(),
                            {-2.1148, 0.3077, -0.1724}),  // mean
                    TensorValue(
                            {1, 3}, dtype::Float32(), {0.1023, 0.0460, 0.0151}),  // var
            });
}

}  // namespace test
}  // namespace megdnn