#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy

import paddle
from paddle.fluid.framework import Variable

from .dist_attribute import (
    OperatorDistributedAttribute,
    append_op_input_suffix,
    append_op_output_suffix,
)
from .utils import (
    __no_shape_var_type__,
    convert_to_shard_spec,
    verify_shard_spec,
)


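# A DistributedOperator pairs a serial (single-process) operator with an
# OperatorDistributedAttribute. The dist_attr records a dims_mapping for every
# input and output tensor: -1 means that tensor dimension is replicated, while
# a value >= 0 is the index of the process_mesh dimension it is sharded along.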
class DistributedOperator:
    def __init__(self, serial_op, dist_attr=None):
        self._serial_op = serial_op
        self._serial_inputs = {}
        self._serial_outputs = {}
        self._dist_attr = None
        # Reuse the dist_attr setter to initialize _dist_attr
        self.dist_attr = dist_attr

    @property
    def serial_op(self):
        return self._serial_op

    @property
    def dist_attr(self):
        return self._dist_attr

    @dist_attr.setter
    def dist_attr(self, dist_attr):
        if self._dist_attr is None:
            self._dist_attr = OperatorDistributedAttribute()
        # Create new dist_attr related to current serial_op
        dist_attr = self._filter_dist_attr(dist_attr)
        # Append suffix to mark the inputs or outputs
        if isinstance(dist_attr, dict):
            # Copy the keys since we may add new ones
            for key in list(dist_attr.keys()):
                if isinstance(key, Variable):
                    if key.name in self._serial_op.input_arg_names:
                        dist_attr[append_op_input_suffix(key.name)] = True
                    if key.name in self._serial_op.output_arg_names:
                        dist_attr[append_op_output_suffix(key.name)] = True
        self._dist_attr.init(dist_attr)
        self._init_default_dist_attr()

    def get_serial_input(self, name):
        return self._serial_inputs.get(name, None)

    def get_serial_output(self, name):
        return self._serial_outputs.get(name, None)

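    # Fill in defaults for anything the user did not annotate: every tensor
    # dimension gets a dims_mapping entry of -1 (replicated), the impl falls
    # back to the "default" implementation with index 0, and recompute is off.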
    def _init_default_dist_attr(self):
        for tensor_name in self._serial_op.input_arg_names:
            if self._serial_op.type == "create_py_reader":
                tensor = None
            else:
                tensor = self._serial_op.block._var_recursive(tensor_name)
            self._serial_inputs[tensor_name] = tensor
            if tensor is None:
                tensor_shape = []
            else:
                if tensor.type in __no_shape_var_type__:
                    tensor_shape = []
                else:
                    tensor_shape = tensor.shape
            if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
                tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
                self._dist_attr.set_input_dims_mapping(
                    tensor_name, tensor_dims_mapping
                )
        for tensor_name in self._serial_op.output_arg_names:
            tensor = self._serial_op.block._var_recursive(tensor_name)
            if tensor.type in __no_shape_var_type__:
                tensor_shape = []
            else:
                tensor_shape = tensor.shape
            self._serial_outputs[tensor_name] = tensor
            if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
                tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
                self._dist_attr.set_output_dims_mapping(
                    tensor_name, tensor_dims_mapping
                )
        if self._dist_attr.op_type is None:
            self._dist_attr.op_type = self.serial_op.type
        if self._dist_attr.impl_type is None:
            self._dist_attr.impl_type = "default"
        if self._dist_attr.impl_idx is None:
            self._dist_attr.impl_idx = 0
        if self._dist_attr.is_recompute is None:
            self._dist_attr.is_recompute = False

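    # Keep only the entries of the given dist_attr that refer to tensors this
    # serial op actually uses as inputs or outputs; everything else is dropped.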
    def _filter_dist_attr(self, dist_attr):
        if dist_attr is None:
            return None
        new_dist_attr = None
        if isinstance(dist_attr, dict):
            new_dist_attr = {}
            for key, value in dist_attr.items():
                if isinstance(key, Variable):
                    if (
                        key.name in self._serial_op.input_arg_names
                        or key.name in self._serial_op.output_arg_names
                    ):
                        new_dist_attr[key] = value
                else:
                    new_dist_attr[key] = value
        elif isinstance(dist_attr, OperatorDistributedAttribute):
            new_dist_attr = copy.deepcopy(dist_attr)
            new_dist_attr._inputs_dist_attrs.clear()
            new_dist_attr._outputs_dist_attrs.clear()
            for tensor_name in self._serial_op.input_arg_names:
                tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
                if tensor_dist_attr:
                    new_dist_attr.set_input_dist_attr(
                        tensor_name, tensor_dist_attr
                    )
            for tensor_name in self._serial_op.output_arg_names:
                tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
                if tensor_dist_attr:
                    new_dist_attr.set_output_dist_attr(
                        tensor_name, tensor_dist_attr
                    )
        else:
            assert False, "Cannot recognize the {} parameter.".format(dist_attr)
        return new_dist_attr

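    # A dist_attr is considered valid when every input/output dims_mapping has
    # one entry per tensor dimension, each entry is -1 or a valid process_mesh
    # axis, no mesh axis is used twice for the same tensor, and the tensor's
    # process_mesh matches the op's process_mesh.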
    def validate_dist_attr(self):
        if "read" in self.serial_op.type or "while" == self.serial_op.type:
            return True
        for name in self.serial_op.input_arg_names:
            input_dist_attr = self.dist_attr.get_input_dist_attr(name)
            dims_mapping = input_dist_attr.dims_mapping
            if self.get_serial_input(name).type in __no_shape_var_type__:
                shape = []
            else:
                shape = self.get_serial_input(name).shape
            if len(shape) != len(dims_mapping):
                return False
            for i in range(len(dims_mapping)):
                if dims_mapping[i] < -1 or dims_mapping[i] >= len(
                    self.dist_attr.process_mesh.shape
                ):
                    return False
            for i in range(len(self.dist_attr.process_mesh.shape)):
                if dims_mapping.count(i) > 1:
                    return False
            if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
                return False

        for name in self.serial_op.output_arg_names:
            output_dist_attr = self.dist_attr.get_output_dist_attr(name)
            dims_mapping = output_dist_attr.dims_mapping
            if self.get_serial_output(name).type in __no_shape_var_type__:
                shape = []
            else:
                shape = self.get_serial_output(name).shape
            if len(shape) != len(dims_mapping):
                return False
            for i in range(len(dims_mapping)):
                if dims_mapping[i] < -1 or dims_mapping[i] >= len(
                    self.dist_attr.process_mesh.shape
                ):
                    return False
            for i in range(len(self.dist_attr.process_mesh.shape)):
                if dims_mapping.count(i) > 1:
                    return False
            if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
                return False
        return True

    def __str__(self):
        str = "{{op type: {}, op id: {}".format(
            self.serial_op.desc.type(), self.serial_op.desc.id()
        )

        # str += ", {}".format(self.dist_attr)
        # return str

        if self.dist_attr.is_annotated("process_mesh"):
            annotated_str = "annotated"
        else:
            annotated_str = "non-annotated"
        str += ", process_mesh ({}): {}".format(
            annotated_str, self.dist_attr.process_mesh
        )

        for arg_name in self.serial_op.desc.input_arg_names():
            dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
            if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
                annotated_str = "annotated"
            else:
                annotated_str = "non-annotated"
            if self.get_serial_input(arg_name) is not None:
                if self.get_serial_input(arg_name).is_parameter:
                    is_parameter_str = "parameter"
                else:
                    is_parameter_str = "non-parameter"
            else:
                is_parameter_str = "non-parameter"
            str += ", {}'s dims_mapping (input, {}, {}): {}".format(
                arg_name, annotated_str, is_parameter_str, dims_mapping
            )

        for arg_name in self.serial_op.desc.output_arg_names():
            dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
            if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
                annotated_str = "annotated"
            else:
                annotated_str = "non-annotated"
            if self.get_serial_output(arg_name) is not None:
                if self.get_serial_output(arg_name).is_parameter:
                    is_parameter_str = "parameter"
                else:
                    is_parameter_str = "non-parameter"
            else:
                is_parameter_str = "non-parameter"
            str += ", {}'s dims_mapping (output, {}, {}): {}".format(
                arg_name, annotated_str, is_parameter_str, dims_mapping
            )

        str += ", pipeline stage: {}".format(None)

        str += ", dist_impl idx: {} , dist_impl type {} }}".format(
            self.dist_attr._impl_idx, self.dist_attr._impl_type
        )

        return str

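    # The serial op and the cached serial input/output Variables are shared by
    # the copy on purpose; only the distributed attributes are deep-copied.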
    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if (
                k == "_serial_op"
                or k == "_serial_inputs"
                or k == "_serial_outputs"
            ):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result


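# DistributedOperatorHelper wraps a callable together with a process_mesh and
# per-tensor dims_mappings. Calling the helper runs the wrapped callable and
# annotates every op it adds to the current block (see __call__ below).
# A rough, hypothetical sketch of how it might be used (names, mesh, and
# mappings are illustrative only):
#
#     helper = DistributedOperatorHelper(
#         paddle.matmul, process_mesh, [[0, -1], [-1, 1]], [[0, 1]]
#     )
#     out = helper(x, y)  # annotates the newly inserted matmul op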
class DistributedOperatorHelper:
    def __init__(
        self, serial_op, process_mesh, in_dims_mappings, out_dims_mappings
    ):
        self._serial_op = serial_op
        self._process_mesh = process_mesh
        self._in_dims_mappings = in_dims_mappings
        self._out_dims_mappings = out_dims_mappings

    def __call__(self, *args, **kwargs):
        tensor_to_dims_mapping = {}
        index = 0
        if self._in_dims_mappings:
            assert len(args) + len(kwargs) == len(
                self._in_dims_mappings
            ), "The length of in_dims_mappings {} does not match the number of inputs {}.".format(
                len(self._in_dims_mappings), len(args) + len(kwargs)
            )
        for arg in args:
            if isinstance(arg, Variable) and self._in_dims_mappings:
                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
            index += 1
        for arg in kwargs.values():
            if isinstance(arg, Variable) and self._in_dims_mappings:
                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
            index += 1

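        # Remember how many ops the current block holds before and after the
        # call, so the ops appended by the wrapped callable can be identified.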
        default_prog = paddle.fluid.default_main_program()
        cur_block = default_prog.current_block()
        op_size = len(cur_block.ops)
        output = self._serial_op(*args, **kwargs)
        new_op_size = len(cur_block.ops)

        if isinstance(output, tuple) or isinstance(output, list):
            new_output = list(output)
        elif isinstance(output, Variable):
            new_output = [output]
        else:
            raise ValueError("Unrecognized output.")

        if self._out_dims_mappings:
            assert len(new_output) == len(
                self._out_dims_mappings
            ), "The length of out_dims_mappings {} does not match the number of outputs {}.".format(
                len(self._out_dims_mappings), len(new_output)
            )
        for i, item in enumerate(new_output):
            if isinstance(item, Variable) and self._out_dims_mappings:
                tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]

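        # For every op appended by the call, build a DistributedOperator, copy
        # the requested dims_mapping onto its input/output dist attributes
        # (verifying it as a shard spec first), annotate the process_mesh, and
        # register the result with the default distributed context.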
        from .dist_context import get_default_distributed_context

        default_dist_ctx = get_default_distributed_context()
        for idx in range(op_size, new_op_size):
            op = cur_block.ops[idx]
            dist_op = DistributedOperator(op)
            for name in dist_op.serial_op.input_arg_names:
                if name in tensor_to_dims_mapping.keys():
                    tensor = dist_op.get_serial_input(name)
                    tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
                        name
                    )
                    dims_mapping = tensor_to_dims_mapping[name]
                    if tensor is None:
                        tensor_shape = []
                    else:
                        if tensor.type in __no_shape_var_type__:
                            tensor_shape = []
                        else:
                            tensor_shape = tensor.shape
                    if dims_mapping is not None:
                        dims_mapping = tensor_to_dims_mapping[name]
                        shard_spec = convert_to_shard_spec(
                            dims_mapping, self._process_mesh
                        )
                        assert verify_shard_spec(
                            shard_spec, tensor_shape, self._process_mesh
                        ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
                            name, shard_spec, tensor_shape, self._process_mesh
                        )
                        tensor_dist_attr.dims_mapping = dims_mapping
                        tensor_dist_attr.mark_annotated("dims_mapping")
            for name in dist_op.serial_op.output_arg_names:
                if name in tensor_to_dims_mapping.keys():
                    tensor = dist_op.get_serial_output(name)
                    tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
                        name
                    )
                    dims_mapping = tensor_to_dims_mapping[name]
                    if tensor is None:
                        tensor_shape = []
                    else:
                        if tensor.type in __no_shape_var_type__:
                            tensor_shape = []
                        else:
                            tensor_shape = tensor.shape
                    if dims_mapping is not None:
                        dims_mapping = tensor_to_dims_mapping[name]
                        shard_spec = convert_to_shard_spec(
                            dims_mapping, self._process_mesh
                        )
                        assert verify_shard_spec(
                            shard_spec, tensor_shape, self._process_mesh
                        ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
                            name, shard_spec, tensor_shape, self._process_mesh
                        )
                        tensor_dist_attr.dims_mapping = dims_mapping
                        tensor_dist_attr.mark_annotated("dims_mapping")
            dist_op.dist_attr.process_mesh = self._process_mesh
            if self._process_mesh is not None:
                dist_op.dist_attr.mark_annotated("process_mesh")
            default_dist_ctx.add_dist_op_for_program(dist_op)

        return output