# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""control_ops"""

from ...common import dtype as mstype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register


class ControlDepend(Primitive):
    """
    Adds control dependency relation between source and destination operation.

    In many cases, we need to control the execution order of operations. ControlDepend is designed for this.
    ControlDepend will instruct the execution engine to run the operations in a specific order. ControlDepend
    tells the engine that the destination operations should depend on the source operation, which means the source
    operations should be executed before the destination.

    Note:
        This operation does not work in `PYNATIVE_MODE`.

    Args:
        depend_mode (int): Use 0 for a normal dependency relation. Use 1 to depend on the operations that use a
            Parameter as their input. Default: 0.

    Inputs:
        - **src** (Any) - The source input. It can be a tuple of operation outputs or a single operation output.
          The input data itself is not relevant; only the operation that generates the input data matters.
          If `depend_mode` is 1 and the source input is a Parameter, we will try to find the operations that
          use the parameter as input.
        - **dst** (Any) - The destination input. It can be a tuple of operation outputs or a single operation output.
          The input data itself is not relevant; only the operation that generates the input data matters.
          If `depend_mode` is 1 and the source input is a Parameter, we will try to find the operations that
          use the parameter as input.

    Outputs:
        Bool. This operation has no actual data output; it is only used to set up the order of the related
        operations.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.control_depend = P.ControlDepend()
        >>>         self.softmax = P.Softmax()
        >>>
        >>>     def construct(self, x, y):
        >>>         mul = x * y
        >>>         softmax = self.softmax(x)
        >>>         ret = self.control_depend(mul, softmax)
        >>>         return ret
        >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
    """

    @prim_attr_register
    def __init__(self, depend_mode=0):
        """Validate that `depend_mode` lies in the inclusive range [0, 1]."""
        validator.check_int_range("depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)

    def __call__(self, src, dst):
        """Direct call is a pass-through: `dst` is ignored and `src` is returned unchanged."""
        return src


class GeSwitch(PrimitiveWithInfer):
    """
    Adds control switch to data.

    Switch data flows into false or true branch depending on the condition. If the condition is true,
    the true branch will be activated, or vice versa.

    Inputs:
        - **data** (Union[Tensor, Number]) - The data to be used for switch control.
        - **pred** (Tensor) - It should be a scalar whose type is bool and shape is `()`. It is used as the condition
          for switch control.

    Outputs:
        tuple. Output is tuple(false_output, true_output). The elements in the tuple have the same shape as the input
        data. The false_output connects with the false_branch and the true_output connects with the true_branch.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.square = P.Square()
        >>>         self.add = P.TensorAdd()
        >>>         self.value = Tensor(np.full((1), 3), mindspore.float32)
        >>>         self.switch = P.GeSwitch()
        >>>         self.merge = P.Merge()
        >>>         self.less = P.Less()
        >>>
        >>>     def construct(self, x, y):
        >>>         cond = self.less(x, y)
        >>>         st1, sf1 = self.switch(x, cond)
        >>>         st2, sf2 = self.switch(y, cond)
        >>>         add_ret = self.add(st1, st2)
        >>>         st3, sf3 = self.switch(self.value, cond)
        >>>         sq_ret = self.square(sf3)
        >>>         ret = self.merge((add_ret, sq_ret))
        >>>         return ret[0]
        >>>
        >>> x = Tensor(10.0, dtype=mindspore.float32)
        >>> y = Tensor(5.0, dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
    """

    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, data, pred):
        # Direct (eager) invocation is not implemented for this primitive.
        raise NotImplementedError

    def infer_shape(self, data, pred):
        """Require `pred` to have rank 0; both outputs take the shape of `data`."""
        validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name)
        return (data, data)

    def infer_dtype(self, data_type, pred_type):
        """Check `data` is a tensor or number type and `pred` is a bool tensor; both outputs keep `data`'s dtype."""
        validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
        validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
        return (data_type, data_type)


class Merge(PrimitiveWithInfer):
    """
    Merges all input data to one.

    One and only one of the inputs should be selected as the output.

    Inputs:
        - **inputs** (Union[tuple, list]) - The data to be merged. All tuple elements should have the same data type.

    Outputs:
        tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape as the `inputs` element.

    Examples:
        >>> merge = P.Merge()
        >>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
        >>> result = merge((input_x, input_y))
    """

    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, *args):
        # Direct (eager) invocation is not implemented for this primitive.
        raise NotImplementedError

    def infer_shape(self, inputs):
        """`data` takes the shape of the first input; `output_index` has shape [1]."""
        return (inputs[0], [1])

    def infer_dtype(self, inputs):
        """Validate the input dtypes against bool/number types; `data` keeps the first input's dtype, index is int32."""
        # Label each input by position so a validation error names the offending element.
        args = {'inputs[%d]' % i: item for i, item in enumerate(inputs)}
        validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        return (inputs[0], mstype.int32)