# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
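
# Unit tests for the auto-parallel distributed lowering of paddle.norm
# (the p_norm op) under data-parallel and serial process meshes.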

import unittest

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward

paddle.enable_static()


def make_program_dp2_axis_None():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
        )
        tmp_0 = paddle.norm(x, p=2)
    return main_program, start_program, tmp_0


def make_program_dp2_axis_0():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
        )
        tmp_0 = paddle.norm(x, p=2, axis=0)
    return main_program, start_program, tmp_0


def make_program_dp2_axis_1():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
        )
        tmp_0 = paddle.norm(x, p=2, axis=1)
    return main_program, start_program, tmp_0


def make_program_serial():
    main_program = paddle.fluid.Program()
    start_program = paddle.fluid.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
        x.stop_gradient = False
        auto.shard_tensor(
            x, auto.ProcessMesh([0], dim_names=["x"]), [None, None, None]
        )
        tmp_0 = paddle.norm(x, p=2)
    return main_program, start_program, tmp_0


def parallelizer(program_func, rank):
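    """Complete dist attrs, append backward, and partition the program for `rank`."""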
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.partitioner import Partitioner

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    with program_guard(main_program, start_program):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context
        )
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(
        main_program, start_program, []
    )

    return dist_main_prog, dist_context


class TestDistPNorm(unittest.TestCase):
    def prepare(self, func):
        self.dist_main_prog, self.dist_context = parallelizer(func, 0)
        self.ops = self.dist_main_prog.global_block().ops

    def test_dist_pnorm(self):
        pass


class TestDistPNormDP(TestDistPNorm):
    def test_dist_pnorm(self):
        self.prepare(make_program_dp2_axis_None)
        self.check_program()

    def check_program(self):
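        """Verify dist attrs and the op sequence of the partitioned DP program.

        Expected lowering: c_allgather gathers the dim-0-sharded input,
        p_norm / p_norm_grad run on the replicated tensor, and a final
        slice re-shards the input gradient along dim 0.
        """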
        op_types = []
        for op in self.ops:
            op_types.append(op.type)
            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
            if op.type == "p_norm":
                assert op_dist_attr.impl_type == "p_norm"
            if op.type in ["p_norm", "p_norm_grad"]:
                # p_norm and its grad operate on the gathered tensor, so all
                # dims of their inputs/outputs are replicated (-1).
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    assert set(input_attr.dims_mapping) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    assert set(output_attr.dims_mapping) == {-1}
            if op.type == 'c_allgather':
                # c_allgather takes the input sharded along dim 0 and yields
                # a fully replicated output.
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    assert input_attr.dims_mapping[0] == 0
                    assert set(input_attr.dims_mapping[1:]) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    assert set(output_attr.dims_mapping) == {-1}
            if op.type == 'slice':
                # slice maps the replicated input gradient back to the
                # dim-0-sharded layout of the original input.
                for input_attr in op_dist_attr.inputs_dist_attrs.values():
                    assert set(input_attr.dims_mapping) == {-1}
                for output_attr in op_dist_attr.outputs_dist_attrs.values():
                    assert output_attr.dims_mapping[0] == 0
                    assert set(output_attr.dims_mapping[1:]) == {-1}
        assert op_types == [
            "c_allgather",
            "p_norm",
            "fill_constant",
            "p_norm_grad",
            "slice",
        ]


class TestDistPNormDP1(TestDistPNormDP):
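    """Norm along the sharded axis (axis=0) yields the same gathered program."""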
    def test_dist_pnorm(self):
        self.prepare(make_program_dp2_axis_0)
        self.check_program()


class TestDistPNormSerial(TestDistPNorm):
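    """On a single-process mesh every op keeps the default distributed impl."""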
    def test_dist_pnorm(self):
        self.prepare(make_program_serial)
        for op in self.ops:
            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
            assert op_dist_attr.impl_type == "default"


class TestDistPNormDPAxis1(TestDistPNorm):
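    """Norm along a non-sharded axis (axis=1) falls back to the default impl."""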
    def test_dist_pnorm(self):
        self.prepare(make_program_dp2_axis_1)
        for op in self.ops:
            op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
            assert op_dist_attr.impl_type == "default"


if __name__ == "__main__":
    unittest.main()