# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward

paddle.enable_static()


def make_program_dp2_axis_None():
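    # dp2 case: x is sharded along dim 0 over a 2-rank mesh and p_norm
    # reduces over all axes (axis=None).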
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
        )
        tmp_0 = paddle.norm(x, p=2)
    return main_program, start_program, tmp_0


def make_program_dp2_axis_0():
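    # Same dp2 sharding, but the norm is computed along axis 0 (the sharded axis).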
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
        )
        tmp_0 = paddle.norm(x, p=2, axis=0)
    return main_program, start_program, tmp_0


def make_program_dp2_axis_1():
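    # Same dp2 sharding; the norm is computed along axis 1 (a replicated axis).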
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
        )
        tmp_0 = paddle.norm(x, p=2, axis=1)
    return main_program, start_program, tmp_0


def make_program_serial():
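    # Serial case: a single-rank mesh with x fully replicated.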
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0], dim_names=["x"]), [None, None, None]
        )
        tmp_0 = paddle.norm(x, p=2)
    return main_program, start_program, tmp_0


def parallelizer(program_func, rank):
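    # Annotate the forward program with distributed attributes, append and
    # annotate the backward pass, then partition the program for the given rank.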
    from paddle.distributed.auto_parallel.static.completion import Completer
    from paddle.distributed.auto_parallel.static.dist_context import (
        DistributedContext,
    )
    from paddle.distributed.auto_parallel.static.partitioner import Partitioner

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    with program_guard(main_program, start_program):
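        # Append backward ops while recording them in the dist-op context so
        # they can be annotated below.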
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context
        )
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(
        main_program, start_program, []
    )

    return dist_main_prog, dist_context


class TestDistPNorm(unittest.TestCase):
    def prepare(self, func):
        self.dist_main_prog, self.dist_context = parallelizer(func, 0)
        self.ops = self.dist_main_prog.global_block().ops

    def test_dist_pnorm(self):
        pass


class TestDistPNormDP(TestDistPNorm):
    def test_dist_pnorm(self):
        self.prepare(make_program_dp2_axis_None)
        self.check_program()

    def check_program(self):
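        # Inspect every op's dist attr to verify the data-parallel lowering of p_norm.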
        op_types = []
        for op in self.ops:
            op_types.append(op.type)
            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
            if op.type == "p_norm":
                assert op_dist_attr.impl_type == "p_norm"
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    assert set(input_attr.dims_mapping) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    if len(output_attr.dims_mapping) == 0:
                        assert output_attr.dims_mapping == []
                    else:
                        assert set(output_attr.dims_mapping) == {-1}
            if op.type == "p_norm_grad":
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    if len(input_attr.dims_mapping) == 0:
                        assert input_attr.dims_mapping == []
                    else:
                        assert set(input_attr.dims_mapping) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    assert set(output_attr.dims_mapping) == {-1}
            if op.type == 'c_allgather':
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    assert input_attr.dims_mapping[0] == 0
                    assert set(input_attr.dims_mapping[1:]) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    assert set(output_attr.dims_mapping) == {-1}
            if op.type == 'slice':
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    assert set(input_attr.dims_mapping) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    assert output_attr.dims_mapping[0] == 0
                    assert set(output_attr.dims_mapping[1:]) == {-1}
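        # Expected lowering: allgather the sharded input, compute the norm on
        # the full tensor, then slice the gradient back to the local shard.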
        assert op_types == [
            "c_allgather",
            "p_norm",
            "fill_constant",
            "p_norm_grad",
            "slice",
        ]


class TestDistPNormDP1(TestDistPNormDP):
    def test_dist_pnorm(self):
        self.prepare(make_program_dp2_axis_0)
        self.check_program()


class TestDistPNormSerial(TestDistPNorm):
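    # On a single-rank mesh every op should fall back to the default
    # distributed implementation.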
    def test_dist_pnorm(self):
        self.prepare(make_program_serial)
        for op in self.ops:
            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
            assert op_dist_attr.impl_type == "default"


class TestDistPNormDPAxis1(TestDistPNorm):
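    # A norm along a non-sharded axis of the dp2-sharded input is also
    # expected to use the default implementation.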
    def test_dist_pnorm(self):
        self.prepare(make_program_dp2_axis_1)
        for op in self.ops:
            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
            assert op_dist_attr.impl_type == "default"


if __name__ == "__main__":
    unittest.main()