diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index 30f1e25a34947a5c8f4eb4f376abc0f885b99b0c..ca161cfad0beeefb799a0a310a9f08ed9c5d153d 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -51,16 +51,18 @@ class ControlDepend(Primitive): >>> class Net(nn.Cell): >>> def __init__(self): >>> super(Net, self).__init__() - >>> self.global_step = mindspore.Parameter(initializer(0, [1]), name="global_step") - >>> self.rate = 0.2 >>> self.control_depend = P.ControlDepend() + >>> self.softmax = P.Softmax() >>> - >>> def construct(self, x): - >>> data = self.rate * self.global_step + x - >>> added_global_step = self.global_step + 1 - >>> self.global_step = added_global_step - >>> self.control_depend(data, added_global_step) - >>> return data + >>> def construct(self, x, y): + >>> mul = x * y + >>> softmax = self.softmax(x) + >>> ret = self.control_depend(mul, softmax) + >>> return ret + >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32) + >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32) + >>> net = Net() + >>> output = net(x, y) """ @prim_attr_register