提交 5e041966 编写于 作者: X Xiaoda Zhang

add a new vritualdataset cell for three inputs

上级 e6273ce3
......@@ -18,7 +18,7 @@ Wrap cells for networks.
Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, DataWrapper, \
ParameterUpdate, GetNextSingleOp
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
......@@ -33,5 +33,6 @@ __all__ = [
"DistributedGradReducer",
"ParameterUpdate",
"DynamicLossScaleUpdateCell",
"FixedLossScaleUpdateCell"
"FixedLossScaleUpdateCell",
"VirtualDatasetCellTriple"
]
......@@ -278,6 +278,36 @@ class _VirtualDatasetCell(Cell):
return self._backbone(data_, label_)
class VirtualDatasetCellTriple(Cell):
"""
Wrap the network with virtual dataset to convert data parallel layout to model parallel layout.
VirtualDatasetCellTriple is a virtual Primitive, it does not exist in the final executing graph. Inputs and outputs
of VirtualDatasetCellTriple are distributed in data parallel pattern, tensor redistribution Primitives is inserted
dynamically during the graph compile process.
Note:
Only used in semi auto parallel and auto parallel mode. There are three inputs, as contrary to two inputs in
_VirtualDatasetCell.
Args:
backbone (Cell): The target network to wrap.
Examples:
>>> net = Net()
>>> net = VirtualDatasetCellTriple(net)
"""
def __init__(self, backbone):
super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
self._backbone = backbone
self._virtual_dataset = _VirtualDataset()
def construct(self, a, b, c):
a_, b_, c_ = self._virtual_dataset(a, b, c)
return self._backbone(a_, b_, c_)
class WithEvalCell(Cell):
r"""
Cell that returns loss, output and label for evaluation.
......
......@@ -21,6 +21,7 @@ import mindspore as ms
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from mindspore import context
......@@ -73,6 +74,29 @@ def test_virtual_dataset_3_input():
net.set_auto_parallel()
_executor.compile(net, x, y, b)
def test_virtualdataset_cell_3_inputs():
class Net(nn.Cell):
def __init__(self, strategy0, strategy1, strategy2, strategy3):
super().__init__()
self.matmul1 = P.MatMul().set_strategy(strategy1)
self.matmul2 = P.MatMul().set_strategy(strategy2)
self.gelu = P.Gelu().set_strategy(strategy3)
def construct(self, x, y, b):
out = self.gelu(self.matmul1(x, y))
out = self.matmul2(out, b)
return out
net = GradWrap(VirtualDatasetCellTriple(NetWithLoss(Net(None, None, None, None))))
context.set_context(save_graphs=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
context.set_auto_parallel_context(device_num=8, global_rank=0)
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
net.set_auto_parallel()
_executor.compile(net, x, y, b)
if __name__ == '__main__':
test_virtual_dataset_3_input()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册