# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This module exports no public names via ``from ... import *``.
__all__ = []


class ProgramDeps:
    """Track variable/op dependencies inside one block of a program.

    Starting from ``start_vars``, ``_build_deps`` walks the block's ops in
    order. For every variable reachable from the start vars it records which
    op indices consume it (``_var_to_use_op``) and which produce it
    (``_var_to_generate_op``). An op is recorded only if at least one of its
    inputs is already tracked. Ops whose type is in ``_SKIPPED_OP_TYPES``
    are ignored. Each ``conditional_block`` op gets its own child
    ``ProgramDeps`` for its sub-block.

    As ops are pruned through ``remove_op``, the maps are updated.
    Variables judged removable are collected in ``_should_removed_var``.
    ``should_remove_op`` reports whether every output of an op is in that
    set.
    """

    # Communication ops that ``_build_deps`` never records as consumers or
    # producers. NOTE(review): presumably skipped so that they do not extend
    # the dependency chain of the vars they touch -- confirm.
    _SKIPPED_OP_TYPES = frozenset(
        {
            "c_allreduce_sum",
            "c_sync_comm_stream",
            "c_calc_comm_stream",
        }
    )

    def __init__(self, block, start_vars, end_vars):
        """Build the dependency maps for ``block``.

        Args:
            block: the block whose ops are analysed. This class reads its
                ``ops`` and ``program`` attributes and calls ``has_var``,
                ``_remove_var`` and ``_remove_op`` on it.
            start_vars: names of the variables the dependency walk starts
                from.
            end_vars: names of the variables where the walk is meant to
                stop. NOTE(review): stored, but not read by this class.
        """
        self._block = block
        # vars where to start to build the deps
        self._start_vars = start_vars
        # vars where to stop to build the deps
        self._end_vars = end_vars
        # var name -> indices of ops (in this block) that consume the var
        self._var_to_use_op = {}
        # sub-block idx -> ProgramDeps built for that conditional sub-block
        self._sub_block_deps = {}
        # var name -> indices of ops (in this block) that produce the var
        self._var_to_generate_op = {}
        # vars judged removable after ops have been pruned
        self._should_removed_var = set()
        # set by the parent ProgramDeps when this instance describes a
        # sub-block
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
        """Return the ``ProgramDeps`` of sub-block ``idx``, or ``None``."""
        return self._sub_block_deps.get(idx)

    def get_var_deps(self, var_name):
        """Return the op indices consuming ``var_name``, or ``None``.

        ``None`` means the variable is not tracked. An empty list means it
        is tracked but currently has no recorded consumers.
        """
        return self._var_to_use_op.get(var_name)

    def _build_deps(self):
        """Populate the consumer/producer maps by one pass over the ops."""
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in self._SKIPPED_OP_TYPES:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            # Only ops that read at least one tracked var join the graph.
            if not any(name in self._var_to_use_op for name in input_vars):
                continue
            # Inputs are processed before outputs, so an op that reads and
            # writes the same var is only counted as a consumer if that var
            # was already tracked.
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                self._var_to_use_op.setdefault(output_name, [])
                self._var_to_generate_op.setdefault(output_name, []).append(
                    idx
                )
            if op.type == "conditional_block":
                # Build a child ProgramDeps for the sub-block. It starts
                # from this op's inputs and ends at its outputs.
                assert op.desc.has_attr("sub_block")
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx),
                    input_vars,
                    output_vars,
                )
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
        """Drop op ``op_idx`` as a consumer of ``var_name``.

        Re-evaluate whether ``var_name`` belongs in
        ``_should_removed_var``. Start vars are never marked removable.
        Untracked vars are ignored.

        Raises:
            ValueError: if ``var_name`` still has consumers but ``op_idx``
                is not one of them.
        """
        if var_name not in self._var_to_use_op:
            return
        use_ops = self._var_to_use_op[var_name]
        # update var -> dep_var_op
        if use_ops:
            if op_idx not in use_ops:
                raise ValueError(
                    "op_idx: {} is not in self._var_to_use_op[{}], "
                    "self._var_to_use_op[{}] is {}".format(
                        op_idx,
                        var_name,
                        var_name,
                        use_ops,
                    )
                )
            use_ops.remove(op_idx)
        # update _should_removed_var
        if var_name in self._start_vars:
            self._should_removed_var.discard(var_name)
        elif not use_ops:
            # no more consumers of this var
            self._should_removed_var.add(var_name)
        # NOTE(review): raises IndexError if every producer of ``var_name``
        # has already been cropped while consumers remain -- confirm callers
        # never prune in that order.
        elif self._var_to_generate_op[var_name][-1] >= use_ops[-1]:
            # The latest producer comes at or after the latest remaining
            # consumer. The original author describes this as a cycle in
            # the graph.
            self._should_removed_var.add(var_name)
        else:
            # the var is still needed by a later consumer
            self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
        """Drop op ``op_idx`` as a producer of ``var_name``.

        The block still declaring ``var_name`` is not enough to keep it. If
        no tracked producer is left (or the var was never tracked), the
        variable is also removed from the block.

        Raises:
            ValueError: if ``var_name`` is tracked but ``op_idx`` is not one
                of its producers.
        """
        generate_ops = self._var_to_generate_op.get(var_name)
        if generate_ops is not None:
            # Explicit check instead of ``assert`` so it survives ``-O``;
            # mirrors crop_input_var_from_op.
            if op_idx not in generate_ops:
                raise ValueError(
                    "op_idx: {} is not in self._var_to_generate_op[{}], "
                    "self._var_to_generate_op[{}] is {}".format(
                        op_idx,
                        var_name,
                        var_name,
                        generate_ops,
                    )
                )
            generate_ops.remove(op_idx)
        if self._block.has_var(var_name) and not generate_ops:
            self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx, reserved_vars=None):
        """Remove op ``op_idx`` from the block and update the dependency maps.

        Args:
            op_idx: index of the op in ``self._block.ops``.
            reserved_vars: optional container of var names. These vars are
                neither cropped from the maps nor removed from the block.

        NOTE(review): indices stored in the maps are not renumbered after
        ``_remove_op``. Callers presumably remove ops from the highest index
        to the lowest -- confirm.
        """
        reserved = () if reserved_vars is None else reserved_vars
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            if input_name in reserved:
                continue
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            if output_name in reserved:
                continue
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        """Return True if every output of op ``op_idx`` is marked removable.

        Ops with no outputs are always kept (False).
        """
        op = self._block.ops[op_idx]
        output_names = op.desc.output_arg_names()

        # NOTE: At present, the only ops found to have no outputs are
        # send_v2 and partial_send, which are used on all devices.
        if len(output_names) == 0:
            return False

        return all(name in self._should_removed_var for name in output_names)