Unverified commit 4b3589fb, authored by zhaoyingli, committed by GitHub

2.4/fix engine build (#47462)

* update codestyle

* [AutoParallel] fix fp16 for subblock (#47189)

* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment

* [AutoParallel] fix engine _build and cost method (#47263)

* fix engine build method

* fix import

* update engine cost

* update raise error

* update cmakelist

* revert optimizer

* revert optimizer

* fix unittest

* fix unittest
Co-authored-by: caozhou <caozhou@radi.ac.cn>

Parent commit: f93e9a58
Diff 1 (computation op cost classes). The first hunk reflows AdamOpCost's super().__init__ call and adds a new ArgsortOpCost class:

@@ -20,9 +20,28 @@ class AdamOpCost(CompOpCost):
     OP_TYPE = "adam"

     def __init__(self, op=None, op_desc=None, cluster=None):
-        super(AdamOpCost, self).__init__(op=op,
-                                         op_desc=op_desc,
-                                         cluster=cluster)
+        super(AdamOpCost, self).__init__(
+            op=op, op_desc=op_desc, cluster=cluster
+        )
+
+    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
+    def calc_flops(self):
+        # NOTE: The actual formula will be filled in the future
+        return 0
+
+    def calc_time(self):
+        # NOTE: The actual formula will be filled in the future
+        return 0
+
+
+@register_op_cost
+class ArgsortOpCost(CompOpCost):
+    OP_TYPE = "argsort"
+
+    def __init__(self, op=None, op_desc=None, cluster=None):
+        super(ArgsortOpCost, self).__init__(
+            op=op, op_desc=op_desc, cluster=cluster
+        )

     # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
     def calc_flops(self):
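Note that every calc_flops / calc_time body in this file is still the placeholder "return 0". Purely as an illustration of what a filled-in override could look like, and not part of this commit (the class name, op type, and numbers below are hypothetical), a concrete subclass would be registered the same way:

@register_op_cost
class MyCustomOpCost(CompOpCost):  # hypothetical example, not in the diff
    OP_TYPE = "my_custom_op"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(MyCustomOpCost, self).__init__(
            op=op, op_desc=op_desc, cluster=cluster
        )

    def calc_flops(self):
        # assume the op touches a fixed 1024 x 1024 tensor once
        return 1024 * 1024

    def calc_time(self):
        # assume a device that sustains 1e9 FLOPs per millisecond
        return self.calc_flops() / 1e9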
The next fifteen hunks make the same formatting-only change: the keyword-aligned call

        super(<Class>OpCost, self).__init__(op=op,
                                            op_desc=op_desc,
                                            cluster=cluster)

is reflowed to

        super(<Class>OpCost, self).__init__(
            op=op, op_desc=op_desc, cluster=cluster
        )

with the surrounding OP_TYPE, calc_flops, and calc_time lines unchanged. The affected classes, in hunk order, are:
AssignOpCost ("assign"), AssignValueOpCost ("assign_value"), BeamSearchOpCost ("beam_search"),
BeamSearchDecodeOpCost ("beam_search_decode"), CastOpCost ("cast"), ConcatOpCost ("concat"),
DropoutOpCost ("dropout"), DropoutGradOpCost ("dropout_grad"), ElementwiseAddOpCost ("elementwise_add"),
ElementwiseAddGradOpCost ("elementwise_add_grad"), ElementwiseDivOpCost ("elementwise_div"),
ElementwiseDivGradOpCost ("elementwise_div_grad"), ElementwiseMulOpCost ("elementwise_mul"),
ElementwiseMulGradOpCost ("elementwise_mul_grad"), and ElementwiseSubOpCost ("elementwise_sub").
The following hunk reflows ElementwiseSubGradOpCost's super().__init__ call and adds a new EqualOpCost class:

@@ -324,9 +343,28 @@ class ElementwiseSubGradOpCost(CompOpCost):
     OP_TYPE = "elementwise_sub_grad"

     def __init__(self, op=None, op_desc=None, cluster=None):
-        super(ElementwiseSubGradOpCost, self).__init__(op=op,
-                                                       op_desc=op_desc,
-                                                       cluster=cluster)
+        super(ElementwiseSubGradOpCost, self).__init__(
+            op=op, op_desc=op_desc, cluster=cluster
+        )
+
+    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
+    def calc_flops(self):
+        # NOTE: The actual formula will be filled in the future
+        return 0
+
+    def calc_time(self):
+        # NOTE: The actual formula will be filled in the future
+        return 0
+
+
+@register_op_cost
+class EqualOpCost(CompOpCost):
+    OP_TYPE = "equal"
+
+    def __init__(self, op=None, op_desc=None, cluster=None):
+        super(EqualOpCost, self).__init__(
+            op=op, op_desc=op_desc, cluster=cluster
+        )

     # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
     def calc_flops(self):
The remaining hunks in this file repeat the same reflow of super().__init__ for:
EmbeddingOpCost ("c_embedding"), EmbeddingGradOpCost ("c_embedding_grad"), FillConstantOpCost ("fill_constant"),
FillConstantBatchSizeLikeOpCost ("fill_constant_batch_size_like"),
FusedSoftmaxMaskUpperTriangleOpCost ("fused_softmax_mask_upper_triangle"),
FusedSoftmaxMaskUpperTriangleGradOpCost ("fused_softmax_mask_upper_triangle_grad"),
GatherOpCost ("gather"), GeluOpCost ("gelu"), GeluGradOpCost ("gelu_grad"), GreaterEqualOpCost ("greater_equal"),
IncrementOpCost ("increment"), IsEmptyOpCost ("is_empty"), LayerNormOpCost ("layer_norm"),
LayerNormGradOpCost ("layer_norm_grad"), LessThanOpCost ("less_than"), LogicalNotOpCost ("logical_not"),
LogicalAndOpCost ("logical_and"), LodResetOpCost ("lod_reset"), LookupTableV2OpCost ("lookup_table_v2"),
LookupTableV2GradOpCost ("lookup_table_v2_grad"), MatmulOpCost ("matmul"), MatmulGradOpCost ("matmul_grad"),
MatmulV2OpCost ("matmul_v2"), MatmulV2GradOpCost ("matmul_v2_grad"), MemcpyOpCost ("memcpy"),
MulGradOpCost ("mul_grad"), OneHotOpCost ("one_hot"), ReadFromArrayOpCost ("read_from_array"),
ReduceSumOpCost ("reduce_sum"), ReduceSumGradOpCost ("reduce_sum_grad"), Reshape2OpCost ("reshape2"),
Reshape2GradOpCost ("reshape2_grad"), ReduceMeanOpCost ("reduce_mean"), ReduceMeanGradOpCost ("reduce_mean_grad"),
SamplingIdOpCost ("sampling_id"), ScaleOpCost ("scale"), SliceOpCost ("slice"), SoftmaxOpCost ("softmax"),
SoftmaxGradOpCost ("softmax_grad"), SoftmaxWithCrossEntropyOpCost ("softmax_with_cross_entropy"),
SoftmaxWithCrossEntropyGradOpCost ("softmax_with_cross_entropy_grad"), SplitOpCost ("split"),
Squeeze2OpCost ("squeeze2"), SquareOpCost ("square"), SquareGradOpCost ("square_grad"), TopKOpCost ("top_k"),
Transpose2OpCost ("transpose2"), Transpose2GradOpCost ("transpose2_grad"), Unsqueeze2OpCost ("unsqueeze2"),
and WriteToArrayOpCost ("write_to_array").

For the two FusedSoftmaxMaskUpperTriangle* classes the old call was wrapped after the class name instead of after each keyword argument:

        super(FusedSoftmaxMaskUpperTriangleOpCost,
              self).__init__(op=op, op_desc=op_desc, cluster=cluster)

and it is reflowed in the same way:

        super(FusedSoftmaxMaskUpperTriangleOpCost, self).__init__(
            op=op, op_desc=op_desc, cluster=cluster
        )
......
Diff 2 (the cost estimator). The hunks below reflow the existing logic into the new formatting style; the updated code of each hunk reads:

@@ -27,12 +27,9 @@ from ..dist_tensor import DistributedTensor

class CostEstimator:
    _sepical_op_type = ["fused_attention", "fused_feedforward"]

    def __init__(
        self, program, cluster, mode="modeling", rank=None, loop_count=10
    ):
        self._program = program
        self._cluster = cluster
        self._check_mode(mode)

@@ -41,7 +38,8 @@ class CostEstimator:
        self._loop_count = loop_count
        self._global_cost = Cost()
        self._local_cost_mapping = {}
        self._detailed_cost = (
            OrderedDict()
        )  # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
        self._bubble_time_mapping = {}
        self._ordered_ops = []
@@ -106,7 +104,8 @@ class CostEstimator:
    def _check_mode(self, mode):
        if mode not in ["modeling", "profiling"]:
            raise ValueError(
                "Just support modeling and profiling, but got {}".format(mode)
            )

    def _is_special_var_name(self, var_name):
        special_var_name = ["lod_tensor_blocking_queue_0"]

@@ -116,6 +115,7 @@ class CostEstimator:
    def _estimate_core(self, dist_context, resharder, block):
        from ..reshard import get_var_with_recursion

        ops = block.ops
        loop_count = None
        if block.desc.id != self.program.global_block().desc.id:
@@ -132,8 +132,9 @@ class CostEstimator:
            if int(op.attr('op_role')) == int(OpRole.Optimize):
                continue
            if op.type in [
                "create_py_reader",
                "create_double_buffer_reader",
                "read",
            ]:
                continue
@@ -172,14 +173,16 @@ class CostEstimator:
                        max_time = rank_cost.time

                for rank in group_ranks:
                    self.local_cost(rank).time = (
                        max_time + cost.time
                    )
                    if rank not in self._bubble_time_mapping:
                        self._bubble_time_mapping[rank] = 0
                    self._bubble_time_mapping[rank] += (
                        max_time - cost_time[rank]
                    )

            for rank in local_comp_cost:
                for comp_cost in local_comp_cost[rank]:
@@ -191,15 +194,19 @@ class CostEstimator:
            processes = op_dist_attr.process_mesh.processes
            container = get_distributed_operator_impl_container(
                op_dist_attr.impl_type
            )
            dist_impl = container.impls[op_dist_attr.impl_idx]

            dist_op_cost = dist_impl.calc_cost(
                op.attr('op_role'), dist_op, dist_context, self.cluster
            )
            detail["dist_op_cost"] = dist_op_cost
            if dist_op_cost is None:
                assert (
                    dist_op.serial_op.type in CostEstimator._sepical_op_type
                )
                continue
            for item in dist_op_cost:
                if isinstance(item, list):
@@ -217,12 +224,14 @@ class CostEstimator:
                        if max_time < rank_cost.time:
                            max_time = rank_cost.time
                    for rank in group_ranks:
                        self.local_cost(rank).time = (
                            max_time + comm_op_cost.time
                        )
                        if rank not in self._bubble_time_mapping:
                            self._bubble_time_mapping[rank] = 0
                        self._bubble_time_mapping[rank] += (
                            max_time - cost_time[rank]
                        )
                elif isinstance(item, dict):
                    # Op just one
                    for rank in processes:
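The per-group bookkeeping in the two hunks above models synchronization: every rank in a communication group is advanced to the slowest rank's finish time plus the communication cost, and the waiting gap is accumulated as bubble (idle) time. A standalone sketch of that accounting, with illustrative function and variable names:

def sync_group(local_time, bubble_time, group_ranks, comm_time):
    # Every rank in the group waits for the slowest one, then pays the
    # communication cost; the waiting gap is booked as bubble (idle) time.
    max_time = max(local_time[r] for r in group_ranks)
    for r in group_ranks:
        bubble_time[r] = bubble_time.get(r, 0) + (max_time - local_time[r])
        local_time[r] = max_time + comm_time
    return local_time, bubble_time

# e.g. ranks 0 and 1 finish at 3 ms and 5 ms, then run a 1 ms collective:
times, bubbles = sync_group({0: 3.0, 1: 5.0}, {}, [0, 1], 1.0)
# times == {0: 6.0, 1: 6.0}, bubbles == {0: 2.0, 1: 0.0}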
@@ -247,8 +256,11 @@ class CostEstimator:
            dtype_factor = 8
        elif dtype == paddle.float32 or dtype == paddle.int32:
            dtype_factor = 4
        elif (
            dtype == paddle.float16
            or dtype == paddle.bfloat16
            or dtype == paddle.int16
        ):
            dtype_factor = 2
        elif dtype == paddle.int8 or dtype == paddle.uint8:
            dtype_factor = 1
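For context, the dtype_factor chosen above is presumably multiplied by the element count of the local (per-rank) tensor to give a byte count. A minimal sketch of that arithmetic (the helper name mirrors _calculate_bytes but is written here only for illustration):

import numpy as np

def _calculate_bytes_sketch(sizes, dtype_factor):
    # total elements of the local tensor times bytes per element
    total_count = int(np.prod(sizes)) if sizes else 0
    return total_count * dtype_factor

# e.g. a [1024, 1024] float32 slice occupies 4 MiB
assert _calculate_bytes_sketch([1024, 1024], 4) == 4 * 1024 * 1024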
@@ -270,8 +282,9 @@ class CostEstimator:
        memories = {}
        self.max_memories = {}
        var_info = (
            {}
        )  # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
        for block in self.program.blocks:
            for op in block.ops:
@@ -280,18 +293,22 @@ class CostEstimator:
        for op_id, op in self._ordered_ops:
            if op.type in [
                "create_py_reader",
                "create_double_buffer_reader",
                "read",
            ]:
                continue
            dist_op = dist_context.get_dist_op_for_program(op)
            process_mesh = dist_op.dist_attr.process_mesh
            for var_name in op.input_arg_names:
                input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                    var_name
                )
                if var_name not in var_info:
                    var_info[var_name] = {}
                key = _convert_pm_and_dm_to_str(
                    process_mesh, input_dims_mapping
                )
                if key not in var_info[var_name]:
                    var_info[var_name][key] = {}
                # It is even partition now
@@ -300,21 +317,27 @@ class CostEstimator:
                    global_sizes = var.shape
                    dtype = var.dtype
                    sizes = DistributedTensor.get_local_sizes(
                        global_sizes,
                        input_dims_mapping,
                        process_mesh.topology,
                        process_mesh.processes,
                    )
                    var_info[var_name][key]["memory"] = self._calculate_bytes(
                        sizes, dtype
                    )
                if "position" not in var_info[var_name][key]:
                    var_info[var_name][key]["position"] = []
                var_info[var_name][key]["position"].append(op_id)

            for var_name in op.output_arg_names:
                output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
                    var_name
                )
                if var_name not in var_info:
                    var_info[var_name] = {}
                key = _convert_pm_and_dm_to_str(
                    process_mesh, output_dims_mapping
                )
                if key not in var_info[var_name]:
                    var_info[var_name][key] = {}
                if "memory" not in var_info[var_name][key]:
@@ -322,10 +345,14 @@ class CostEstimator:
                    global_sizes = var.shape
                    dtype = var.dtype
                    sizes = DistributedTensor.get_local_sizes(
                        global_sizes,
                        output_dims_mapping,
                        process_mesh.topology,
                        process_mesh.processes,
                    )
                    var_info[var_name][key]["memory"] = self._calculate_bytes(
                        sizes, dtype
                    )
                if "position" not in var_info[var_name][key]:
                    var_info[var_name][key]["position"] = []
                var_info[var_name][key]["position"].append(op_id)
@@ -333,7 +360,9 @@ class CostEstimator:
        has_used_vars = set()
        for op_id, op in self._ordered_ops:
            if op.type in [
                "create_py_reader",
                "create_double_buffer_reader",
                "read",
            ]:
                continue
            can_free_memories = {}
...@@ -342,9 +371,11 @@ class CostEstimator: ...@@ -342,9 +371,11 @@ class CostEstimator:
process_mesh = dist_op.dist_attr.process_mesh process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name) var_name
key = _convert_pm_and_dm_to_str(process_mesh, )
input_dims_mapping) key = _convert_pm_and_dm_to_str(
process_mesh, input_dims_mapping
)
has_used_var = var_name + key has_used_var = var_name + key
var = dist_op.get_serial_input(var_name) var = dist_op.get_serial_input(var_name)
# Not used # Not used
...@@ -364,13 +395,16 @@ class CostEstimator: ...@@ -364,13 +395,16 @@ class CostEstimator:
if process not in can_free_memories: if process not in can_free_memories:
can_free_memories[process] = 0 can_free_memories[process] = 0
can_free_memories[process] += var_info[ can_free_memories[process] += var_info[
var_name][key]["memory"] var_name
][key]["memory"]
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name) var_name
key = _convert_pm_and_dm_to_str(process_mesh, )
output_dims_mapping) key = _convert_pm_and_dm_to_str(
process_mesh, output_dims_mapping
)
has_used_var = var_name + key has_used_var = var_name + key
var = dist_op.get_serial_output(var_name) var = dist_op.get_serial_output(var_name)
# Not used # Not used
...@@ -390,7 +424,8 @@ class CostEstimator: ...@@ -390,7 +424,8 @@ class CostEstimator:
if process not in can_free_memories: if process not in can_free_memories:
can_free_memories[process] = 0 can_free_memories[process] = 0
can_free_memories[process] += var_info[ can_free_memories[process] += var_info[
var_name][key]["memory"] var_name
][key]["memory"]
# Calc peak memory # Calc peak memory
for process in memories: for process in memories:
...@@ -414,8 +449,12 @@ class CostEstimator: ...@@ -414,8 +449,12 @@ class CostEstimator:
def estimate(self, dist_context, resharder=None): def estimate(self, dist_context, resharder=None):
self.prepare() self.prepare()
from ..reshard import Resharder from ..reshard import Resharder
resharder = Resharder(self.program, None, self.rank, dist_context,
[]) if resharder is None else resharder resharder = (
Resharder(self.program, None, self.rank, dist_context, [])
if resharder is None
else resharder
)
block = self.program.global_block() block = self.program.global_block()
self._estimate_core(dist_context, resharder, block) self._estimate_core(dist_context, resharder, block)
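The accounting in the hunks above records, for every variable and every (process_mesh, dims_mapping) key, its local byte size and the op ids where it appears, releases that memory once the last recorded use has executed, and keeps a running peak per process. A minimal stdlib-only sketch of that last-use bookkeeping (the function and variable names here are illustrative, not the estimator's real fields):

def peak_memory(op_vars, var_sizes):
    """op_vars: per-op sets of variable names, in execution order.
    var_sizes: mapping from variable name to its local size in bytes."""
    last_use = {}
    for idx, names in enumerate(op_vars):
        for name in names:
            last_use[name] = idx            # remember the last op touching each var

    allocated, current, peak = set(), 0, 0
    for idx, names in enumerate(op_vars):
        for name in names:                  # allocate on first use
            if name not in allocated:
                allocated.add(name)
                current += var_sizes[name]
        peak = max(peak, current)
        for name in names:                  # release after the last use has run
            if last_use[name] == idx:
                current -= var_sizes[name]
    return peak

assert peak_memory([{"x", "w"}, {"x", "y"}, {"y"}], {"x": 4, "w": 2, "y": 8}) == 12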
...@@ -447,7 +486,7 @@ class CostEstimator: ...@@ -447,7 +486,7 @@ class CostEstimator:
memories = [ memories = [
int(item // 1e6) for item in list(self.max_memories.values()) int(item // 1e6) for item in list(self.max_memories.values())
] ]
for memory in (memories + header): for memory in memories + header:
if len(str(memory)) > max_len: if len(str(memory)) > max_len:
max_len = len(str(memory)) max_len = len(str(memory))
max_len += 4 # for pretty print of center max_len += 4 # for pretty print of center
...@@ -477,7 +516,7 @@ class CostEstimator: ...@@ -477,7 +516,7 @@ class CostEstimator:
max_len = 0 max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"] header = ["Execution Time(ms)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)] vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in (vals + header): for memory in vals + header:
if len(str(memory)) > max_len: if len(str(memory)) > max_len:
max_len = len(str(memory)) max_len = len(str(memory))
max_len += 4 # for pretty print of center max_len += 4 # for pretty print of center
...@@ -507,50 +546,73 @@ class CostEstimator: ...@@ -507,50 +546,73 @@ class CostEstimator:
def get_cost_from_engine(engine, mode): def get_cost_from_engine(engine, mode):
from ..utils import to_list from ..utils import to_list
# Construct cost estimator by original main program import copy
serial_main_prog = engine._serial_main_progs[mode].clone(
) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
serial_startup_prog = engine._serial_startup_progs[mode].clone( # Construct cost estimator by original main program
) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone( serial_main_prog = (
engine._fwd_main_progs[mode].clone()
if mode in engine._fwd_main_progs
else engine._orig_main_prog.clone()
) )
losses = to_list(
engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)) else engine._losses
if mode in engine._dist_contexts: serial_startup_prog = (
dist_context = engine._dist_contexts[mode] engine._serial_startup_progs[mode].clone()
completer = engine._planners[mode].completer if mode in engine._serial_startup_progs
else engine._orig_startup_prog.clone()
)
losses = (
to_list(engine._loss)
if (
not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)
)
else engine._losses
)
serial_optimizer = copy.deepcopy(engine._orig_optimizer)
if mode in engine._fwd_dist_contexts:
dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
else: else:
from ..completion import Completer
from ..dist_context import DistributedContext from ..dist_context import DistributedContext
dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
engine._optimizer, losses, {}, dist_context = DistributedContext(
{"loss": losses}, engine._cluster, serial_main_prog,
engine._strategy) serial_startup_prog,
completer = Completer(dist_context) serial_optimizer,
completer.complete_forward_annotation() losses,
dist_context.block_state.parse_forward_blocks( {},
dist_context.serial_main_program) {"loss": losses},
engine._cluster,
engine._strategy,
)
from ..completion import Completer
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program
)
if mode == "eval" or mode == "predict": if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster) cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train": elif mode == "train":
from ..parallelizer_v2 import Parallelizer from ..parallelizer_v2 import Parallelizer
# Get serial main program with backward # Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context) parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward # Generate backward
loss_name = dist_context.serial_loss.name loss_name = dist_context.serial_loss.name
serial_loss = serial_main_prog.global_block()._var_recursive(loss_name) serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
params_grads = parallelizer._generate_backward(serial_main_prog, params_grads = parallelizer._generate_backward(
serial_startup_prog, serial_main_prog, serial_startup_prog, serial_loss
serial_loss) )
# Generate optimizer # Generate optimizer
optimizer_ops = parallelizer._generate_optimizer( optimizer_ops = parallelizer._generate_optimizer(
serial_main_prog, serial_startup_prog, serial_optimizer, serial_main_prog,
params_grads) serial_startup_prog,
serial_optimizer,
params_grads,
)
cost_estimator = CostEstimator(serial_main_prog, engine._cluster) cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
# Estimate global_cost and max memory # Estimate global_cost and max memory
......
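One behavioral change above: get_cost_from_engine now deep-copies the engine's original optimizer and forward distributed context instead of borrowing the live objects, so building the backward and optimizer parts for estimation cannot leak state into a later fit/evaluate. A stdlib sketch of why the copy matters (the Config class and annotate function are made up purely for illustration):

import copy

class Config:                        # hypothetical stand-in for optimizer/dist_context
    def __init__(self):
        self.annotations = []

def annotate_for_estimation(cfg):
    cfg.annotations.append("backward")     # estimation mutates whatever it is given

live = Config()
annotate_for_estimation(copy.deepcopy(live))   # estimate against an isolated copy
assert live.annotations == []                  # the engine's own object stays clean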
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
import os import os
import copy
import logging import logging
import random import random
import numbers
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
...@@ -41,16 +43,20 @@ from .planner_v2 import Planner ...@@ -41,16 +43,20 @@ from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver from .dist_saver import DistributedSaver
from .dist_loader import DistributedDataLoaderFromGenerator, DistributedDataLoader from .dist_loader import (
from .utils import to_list, get_dist_attr, get_lr DistributedDataLoaderFromGenerator,
DistributedDataLoader,
)
from .process_group import new_process_group, get_all_process_groups from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy from .strategy import Strategy
from .interface import CollectionNames, get_collection from .interface import CollectionNames, get_collection
from ..utils.log_utils import get_logger from .utils import to_list, get_dist_attr, get_lr, validate_opt
from .utils import initialize_pg_in_full_mode from .utils import initialize_pg_in_full_mode, get_input_split_info
from .cost.estimate_cost import get_cost_from_engine from .cost.estimate_cost import get_cost_from_engine
from ..utils.log_utils import get_logger
class Engine: class Engine:
""" """
...@@ -115,35 +121,55 @@ class Engine: ...@@ -115,35 +121,55 @@ class Engine:
""" """
def __init__(self, def __init__(
model=None, self,
loss=None, model=None,
optimizer=None, loss=None,
metrics=None, optimizer=None,
cluster=None, metrics=None,
strategy=None): cluster=None,
strategy=None,
if model and not isinstance(model, ):
paddle.nn.Layer) and not callable(model):
if (
model
and not isinstance(model, paddle.nn.Layer)
and not callable(model)
):
raise TypeError( raise TypeError(
"'model must be sub classes of `paddle.nn.Layer` or any callable function." "'model must be sub classes of `paddle.nn.Layer` or any callable function."
) )
self._model = model self._model = model
if (
loss
and not isinstance(loss, (paddle.nn.Layer, Variable))
and not callable(loss)
):
raise TypeError(
"'loss' must be sub classes of `paddle.nn.Layer` or any callable function or a Variable."
)
self._loss = loss self._loss = loss
if optimizer and not isinstance( if optimizer and not isinstance(
optimizer, optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
):
raise TypeError( raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`.") " or `paddle.fluid.optimizer.Optimizer`."
self._optimizer = self._validate_opt(optimizer) )
self._optimizer = validate_opt(optimizer)
self._orig_optimizer = copy.deepcopy(self._optimizer)
metrics = metrics or [] metrics = metrics or []
for metric in to_list(metrics): for metric in to_list(metrics):
assert isinstance(metric, Metric), \ if metric and not isinstance(metric, Metric):
"{} is not sub class of Metric".format( raise TypeError(
metric.__class__.__name__) "{} is not sub class of Metric".format(
metric.__class__.__name__
)
)
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
if cluster and not isinstance(cluster, Cluster): if cluster and not isinstance(cluster, Cluster):
...@@ -158,9 +184,11 @@ class Engine: ...@@ -158,9 +184,11 @@ class Engine:
) )
self._strategy = strategy or Strategy() self._strategy = strategy or Strategy()
self._logger = get_logger(logging.INFO)
if os.getenv("POD_NAME"): if os.getenv("POD_NAME"):
print("Distribute training by paddle.distributed.launch", self._logger.info(
flush=True) "Distribute training by paddle.distributed.launch"
)
fleet.init(is_collective=True) fleet.init(is_collective=True)
self._executor = None self._executor = None
...@@ -168,12 +196,12 @@ class Engine: ...@@ -168,12 +196,12 @@ class Engine:
self._nranks = paddle.distributed.get_world_size() self._nranks = paddle.distributed.get_world_size()
self._saver = DistributedSaver() self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO)
self._orig_main_prog = static.default_main_program() self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program() self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context() self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {} self._dist_contexts = {}
self._fwd_main_progs = {}
self._fwd_dist_contexts = {}
self._serial_main_progs = {} self._serial_main_progs = {}
self._serial_startup_progs = {} self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs self._dist_main_progs = defaultdict(dict) # dist main programs
...@@ -185,19 +213,20 @@ class Engine: ...@@ -185,19 +213,20 @@ class Engine:
self._has_prepared_reader = { self._has_prepared_reader = {
"train": False, "train": False,
"eval": False, "eval": False,
"predict": False "predict": False,
} }
self._inputs_spec = [] self._inputs_spec = []
self._labels_spec = [] self._labels_spec = []
self._inputs = [] self._inputs = []
self._labels = [] self._labels = []
self._losses = []
self._mode = None
self._skip_build = False self._skip_build = False
self._outside_dataloader = False self._outside_dataloader = False
self._planned_mode = None self._planned_mode = None
self._dygraph_mode = False self._dygraph_mode = False
self._tuning = self._strategy.tuning self._tuning = self._strategy.tuning
self._losses = None
self.history = None self.history = None
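The constructor checks above accept a paddle.nn.Layer or callable as `model`, a layer/callable/Variable as `loss`, a `paddle.optimizer.Optimizer`, and `paddle.metric.Metric` instances as `metrics`. A minimal construction sketch consistent with those checks; the `paddle.distributed.fleet.auto` import alias is an assumption here, only the argument types come from the code above:

import paddle
from paddle.distributed.fleet import auto   # assumed public alias for this Engine

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)

model = MLP()                                               # paddle.nn.Layer
loss = paddle.nn.CrossEntropyLoss()                         # Layer-style loss
opt = paddle.optimizer.Adam(parameters=model.parameters())  # paddle.optimizer.Optimizer
engine = auto.Engine(model, loss, opt, metrics=[paddle.metric.Accuracy()])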
...@@ -219,9 +248,11 @@ class Engine: ...@@ -219,9 +248,11 @@ class Engine:
inputs = sample[:split] inputs = sample[:split]
labels = sample[split:] labels = sample[split:]
else: else:
raise ValueError( raise TypeError(
"Data should be a Dataset or IterableDatset, but received {}.". "Data should be a Dataset or IterableDatset, but received {}.".format(
format(type(data).__name__)) type(data).__name__
)
)
inputs = to_list(inputs) inputs = to_list(inputs)
labels = to_list(labels) labels = to_list(labels)
...@@ -240,14 +271,20 @@ class Engine: ...@@ -240,14 +271,20 @@ class Engine:
else: else:
specs.append(spec.batch(batch_size)) specs.append(spec.batch(batch_size))
elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)): elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
_adjust_item_spec(num_shards, spec)
spec = InputSpec.from_tensor(item, name) spec = InputSpec.from_tensor(item, name)
_adjust_item_spec(num_shards, spec)
if batch_size is None: if batch_size is None:
specs.append(spec) specs.append(spec)
else: else:
specs.append(spec.batch(batch_size)) specs.append(spec.batch(batch_size))
else: elif isinstance(item, numbers.Number):
specs.append(InputSpec([batch_size], type(item), name)) specs.append(InputSpec([batch_size], type(item), name))
else:
raise TypeError(
"The sample's dtype returned of dataset should be number, np.ndarray or Tensor, but got {}".format(
type(item).__name__
)
)
if inputs is not None: if inputs is not None:
for i, item in enumerate(inputs): for i, item in enumerate(inputs):
...@@ -264,37 +301,41 @@ class Engine: ...@@ -264,37 +301,41 @@ class Engine:
labels_spec = self._validate_spec(labels_spec) labels_spec = self._validate_spec(labels_spec)
return inputs_spec, labels_spec return inputs_spec, labels_spec
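_prepare_data_spec wraps one un-batched sample into InputSpecs and then prepends the batch dimension with spec.batch(batch_size); numbers now get an explicit InputSpec([batch_size], type(item), name) and anything else raises. A small sketch of the batching step with the public paddle.static.InputSpec API (the shapes and names are arbitrary):

import numpy as np
from paddle.static import InputSpec

sample = np.zeros([16], dtype="float32")        # one un-batched feature vector
spec = InputSpec(shape=list(sample.shape), dtype=str(sample.dtype), name="input0")
batched = spec.batch(8)                         # prepend the batch dimension
print(batched.shape)                            # shape is now (8, 16)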
def _prepare_data_tensor(self, def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
inputs_spec,
labels_spec,
inputs=None,
labels=None):
if _non_static_mode() or self._dygraph_mode: if _non_static_mode() or self._dygraph_mode:
return None, None raise ValueError("Only support static graph mode.")
inputs_spec = inputs_spec if inputs_spec else []
labels_spec = labels_spec if labels_spec else []
if inputs_spec: if inputs_spec:
assert isinstance(inputs_spec, list), \ assert isinstance(
"inputs should be list, but received {}".format(type(inputs_spec)) inputs_spec, list
if inputs is None: ), "inputs should be list, but received {}".format(
inputs = [s._create_feed_layer() for s in inputs_spec] type(inputs_spec)
else: )
assert isinstance(inputs, list), \ assert isinstance(
"inputs should be list, but received {}".format(type(inputs)) inputs, list
for input_spec, input in zip(inputs_spec, inputs): ), "inputs should be list, but received {}".format(type(inputs))
if input_spec.shape != input.shape: assert len(inputs_spec) == len(
input.desc.set_shape(input_spec.shape) inputs
), "the number of `inputs_spec` should be equal to `inputs`'s."
for input_spec, input in zip(inputs_spec, inputs):
if input_spec.shape != input.shape:
input.desc.set_shape(input_spec.shape)
if labels_spec: if labels_spec:
assert isinstance(labels_spec, list), \ assert isinstance(
"labels should be list, but received {}".format(type(labels_spec)) labels_spec, list
if labels is None: ), "labels should be list, but received {}".format(
labels = [s._create_feed_layer() for s in labels_spec] type(labels_spec)
else: )
assert isinstance(labels, list), \ assert isinstance(
"labels should be list, but received {}".format(type(labels)) labels, list
for label_spec, label in zip(labels_spec, labels): ), "labels should be list, but received {}".format(type(labels))
if label_spec.shape != label.shape: assert len(labels_spec) == len(
label.desc.set_shape(label_spec.shape) labels
), "the number of `labels_spec` should be equal to `labels`'s."
for label_spec, label in zip(labels_spec, labels):
if label_spec.shape != label.shape:
label.desc.set_shape(label_spec.shape)
return inputs, labels return inputs, labels
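_prepare_data_tensor now refuses dynamic mode, requires that `inputs`/`labels` be passed in and pair one-to-one with their specs, and then forces each variable's shape to the spec's shape. The pairing-and-reshape step reduces to the loop below (stdlib sketch; ShapeHolder is a stand-in for a Paddle Variable, whose real reshape goes through desc.set_shape):

class ShapeHolder:                    # stand-in for a Paddle Variable's shape handling
    def __init__(self, shape):
        self.shape = shape

spec_shapes = [[8, 16], [8, 1]]       # shapes declared by inputs_spec / labels_spec
feed_vars = [ShapeHolder([-1, 16]), ShapeHolder([-1, 1])]

assert len(spec_shapes) == len(feed_vars), "one variable is required per spec"
for spec_shape, var in zip(spec_shapes, feed_vars):
    if spec_shape != var.shape:
        var.shape = spec_shape        # mirrors input.desc.set_shape(input_spec.shape)
print([v.shape for v in feed_vars])   # [[8, 16], [8, 1]]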
def _prepare_reader(self): def _prepare_reader(self):
...@@ -304,7 +345,9 @@ class Engine: ...@@ -304,7 +345,9 @@ class Engine:
# NOTE: this list may be changed if Paddle changes the existing rules. # NOTE: this list may be changed if Paddle changes the existing rules.
related_reader_ops = [ related_reader_ops = [
"create_py_reader", "create_double_buffer_reader", "read" "create_py_reader",
"create_double_buffer_reader",
"read",
] ]
# remove the first three ops if multiple run fit/evaluate/predict # remove the first three ops if multiple run fit/evaluate/predict
if dist_main_block.ops[0].type == 'create_py_reader': if dist_main_block.ops[0].type == 'create_py_reader':
...@@ -322,9 +365,9 @@ class Engine: ...@@ -322,9 +365,9 @@ class Engine:
for idx in reversed(reader_op_indices): for idx in reversed(reader_op_indices):
new_op_desc = dist_main_block.desc._prepend_op() new_op_desc = dist_main_block.desc._prepend_op()
new_op_desc.copy_from(dist_main_block.ops[idx].desc) new_op_desc.copy_from(dist_main_block.ops[idx].desc)
new_op = Operator(dist_main_block, new_op = Operator(
new_op_desc, dist_main_block, new_op_desc, type=new_op_desc.type()
type=new_op_desc.type()) )
new_reader_ops.append(new_op) new_reader_ops.append(new_op)
dist_op = DistributedOperator(new_op) dist_op = DistributedOperator(new_op)
dist_context.add_dist_op_for_program(dist_op) dist_context.add_dist_op_for_program(dist_op)
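_prepare_reader gathers the indices of the reader-related ops and prepends fresh copies in reverse index order, which moves them to the top of the block while keeping their relative order, before the stale originals are removed further down. The same reordering trick on a plain list, for intuition (stdlib only):

ops = ["scale", "create_py_reader", "matmul", "read"]      # toy op-type list
reader_types = {"create_py_reader", "create_double_buffer_reader", "read"}
reader_idx = [i for i, op in enumerate(ops) if op in reader_types]

front = []
for idx in reversed(reader_idx):       # prepend last-first, as desc._prepend_op does
    front.insert(0, ops[idx])          # relative order of the reader ops is preserved
rest = [op for i, op in enumerate(ops) if i not in reader_idx]
print(front + rest)                    # ['create_py_reader', 'read', 'scale', 'matmul']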
...@@ -355,16 +398,22 @@ class Engine: ...@@ -355,16 +398,22 @@ class Engine:
else: else:
raise ValueError("Unsupported data {}".format(data)) raise ValueError("Unsupported data {}".format(data))
if user_feeds is not None: if user_feeds is not None:
assert isinstance(user_feeds, dict), \ assert isinstance(
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__) user_feeds, dict
), "user_feeds must be a dict, but receive {}".format(
type(user_feeds).__name__
)
for name, data in user_feeds.items(): for name, data in user_feeds.items():
feeds[name] = data feeds[name] = data
return feeds return feeds
def _prepare_fetch(self, user_fetches, mode): def _prepare_fetch(self, user_fetches, mode):
if user_fetches is not None: if user_fetches is not None:
assert isinstance(user_fetches, list), \ assert isinstance(
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__) user_fetches, list
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
fetch_names = [] fetch_names = []
fetch_indices = [] fetch_indices = []
...@@ -396,14 +445,16 @@ class Engine: ...@@ -396,14 +445,16 @@ class Engine:
_process_fetch_group("fetches", var_list) _process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices return fetch_names, fetch_indices
def _prepare_logger(self, def _prepare_logger(
outs, self,
epoch=None, outs,
step=None, epoch=None,
lr=None, step=None,
fetch_names=None, lr=None,
fetch_indices=None, fetch_names=None,
mode=None): fetch_indices=None,
mode=None,
):
logs = {} logs = {}
if epoch is not None: if epoch is not None:
logs["epoch"] = epoch logs["epoch"] = epoch
...@@ -468,11 +519,13 @@ class Engine: ...@@ -468,11 +519,13 @@ class Engine:
self._dygraph_mode = True self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.") self._logger.info("Building model with 'to_static' method.")
inputs_spec = self._inputs_spec self.program_helper = ProgramHelper(
labels_spec = self._labels_spec if self._labels_spec else [] self._model,
self.program_helper = ProgramHelper(self._model, self._loss, self._loss,
self._metrics, inputs_spec, self._metrics,
labels_spec) self._inputs_spec,
self._labels_spec,
)
# build forward main program # build forward main program
self.program_helper.build_program(mode) self.program_helper.build_program(mode)
...@@ -480,16 +533,12 @@ class Engine: ...@@ -480,16 +533,12 @@ class Engine:
serial_main_prog = self.program_helper.main_program serial_main_prog = self.program_helper.main_program
serial_startup_prog = self.program_helper.startup_program serial_startup_prog = self.program_helper.startup_program
inputs = self.program_helper.input_vars self._inputs = self.program_helper.input_vars
self._labels = self.program_helper.label_vars
outputs = self.program_helper.output_vars outputs = self.program_helper.output_vars
labels = self.program_helper.label_vars self._losses = self.program_helper.loss_vars
losses = self.program_helper.loss_vars
self._losses = losses
metrics = self.program_helper.metric_vars metrics = self.program_helper.metric_vars
self._inputs = inputs
self._labels = labels
paddle.enable_static() paddle.enable_static()
else: else:
# build program in static mode # build program in static mode
...@@ -498,27 +547,45 @@ class Engine: ...@@ -498,27 +547,45 @@ class Engine:
return return
outputs = [] outputs = []
losses = []
metrics = [] metrics = []
inputs = self._inputs if self._inputs else [] self._losses = []
labels = self._labels if self._labels else []
serial_main_prog = self._orig_main_prog.clone() serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone() serial_startup_prog = self._orig_startup_prog.clone()
if not self._skip_build: if not self._skip_build:
with static.program_guard(serial_main_prog, serial_startup_prog), \ with static.program_guard(
utils.unique_name.guard(): serial_main_prog, serial_startup_prog
outputs = to_list(self._model(*inputs)) ), utils.unique_name.guard():
if mode != "predict" and self._loss: self._inputs = [
losses = to_list(self._loss(*(outputs + labels))) s._create_feed_layer() for s in self._inputs_spec
self._losses = losses ]
self._labels = [
s._create_feed_layer() for s in self._labels_spec
]
outputs = to_list(self._model(*self._inputs))
if mode != "predict" and (outputs or labels): if mode != "predict" and self._loss:
assert isinstance(
self._loss, paddle.nn.Layer
) or callable(
self._loss
), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
self._losses = to_list(
self._loss(*(outputs + self._labels))
)
if mode != "predict" and (outputs or self._labels):
for metric in self._metrics: for metric in self._metrics:
metrics.append( metrics.append(
to_list(metric.compute(*(outputs + labels)))) to_list(
metric.compute(*(outputs + self._labels))
)
)
else: else:
losses = to_list(self._loss) assert isinstance(
self.losses = losses self._loss, Variable
), "the type of `loss` of the Engine arguments should be Variable."
self._losses = to_list(self._loss)
default_ctx = get_default_distributed_context() default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation: if not default_ctx.has_annotation:
...@@ -527,12 +594,12 @@ class Engine: ...@@ -527,12 +594,12 @@ class Engine:
new_process_group(list(range(self._nranks))) new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True default_ctx.data_parallel = True
feed_vars = {"inputs": inputs, "labels": labels} feed_vars = {"inputs": self._inputs, "labels": self._labels}
fetch_vars = { fetch_vars = {
"outputs": flatten(outputs), "outputs": flatten(outputs),
"loss": losses, "loss": self._losses,
"metrics": metrics "metrics": metrics,
} }
if mode != "train": if mode != "train":
...@@ -540,9 +607,27 @@ class Engine: ...@@ -540,9 +607,27 @@ class Engine:
self._set_recompute_ckpts() self._set_recompute_ckpts()
self._dist_contexts[mode] = DistributedContext( self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses, serial_main_prog,
feed_vars, fetch_vars, self._cluster, self._strategy) serial_startup_prog,
self._optimizer,
self._losses,
feed_vars,
fetch_vars,
self._cluster,
self._strategy,
)
self._fwd_dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
self._optimizer,
self._losses,
feed_vars,
fetch_vars,
self._cluster,
self._strategy,
)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._fwd_main_progs[mode] = serial_main_prog.clone()
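In the static branch of _build above, the engine now creates its own feed variables from self._inputs_spec/self._labels_spec inside a program_guard, runs the model on them, and derives losses and metrics from the outputs, instead of relying on pre-made tensors. A reduced sketch of that build pattern with public APIs only (the fc layer and cross-entropy here are placeholders for self._model and self._loss):

import paddle
from paddle import static, utils

paddle.enable_static()
main_prog, startup_prog = static.Program(), static.Program()

with static.program_guard(main_prog, startup_prog), utils.unique_name.guard():
    x = static.data("x", [8, 16], "float32")    # stands in for s._create_feed_layer()
    y = static.data("y", [8, 1], "int64")
    logits = static.nn.fc(x=x, size=2)          # stands in for self._model(*inputs)
    loss = paddle.nn.functional.cross_entropy(logits, y)   # stands in for self._loss

print(len(main_prog.global_block().ops))        # forward ops recorded in the program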
def _optimization_tuning(self, mode, dataset, batch_size): def _optimization_tuning(self, mode, dataset, batch_size):
if not self._tuning.enable: if not self._tuning.enable:
...@@ -558,20 +643,24 @@ class Engine: ...@@ -558,20 +643,24 @@ class Engine:
dataset.dp_rank = self._dp_ranks dataset.dp_rank = self._dp_ranks
from .tuner.optimization_tuner import OptimizationTuner from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(),
self._dist_contexts[mode], self._optimization_tuner = OptimizationTuner(
dataset, self._tuning.to_dict(),
self._inputs_spec, self._dist_contexts[mode],
self._labels_spec, dataset,
batch_size=batch_size, self._inputs_spec,
rank=self._cur_rank) self._labels_spec,
batch_size=batch_size,
rank=self._cur_rank,
)
self._optimization_tuner.tune() self._optimization_tuner.tune()
if self._tuning.run_after_tuning: if self._tuning.run_after_tuning:
# update the strategy # update the strategy
self._dist_contexts[ self._dist_contexts[
mode]._strategy = self._optimization_tuner.get_best_config() mode
]._strategy = self._optimization_tuner.get_best_config()
def _plan(self, mode): def _plan(self, mode):
if self._planned_mode is None: if self._planned_mode is None:
...@@ -595,8 +684,9 @@ class Engine: ...@@ -595,8 +684,9 @@ class Engine:
self._dp_world_sizes = [] self._dp_world_sizes = []
self._dp_ranks = [] self._dp_ranks = []
for feed_var in feed_list: for feed_var in feed_list:
dp_world_size, dp_rank = self._get_input_split_info( dp_world_size, dp_rank = get_input_split_info(
feed_var, self._dist_contexts[mode]) self._cur_rank, feed_var, self._dist_contexts[mode]
)
self._dp_world_sizes.append(dp_world_size) self._dp_world_sizes.append(dp_world_size)
self._dp_ranks.append(dp_rank) self._dp_ranks.append(dp_rank)
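Each feed variable's data-parallel degree and rank are now deduced by the shared get_input_split_info helper: the batch axis of the dims_mapping selects one process-mesh dimension, whose size is the DP world size, and the current rank's coordinate along that axis is its DP rank. A stdlib sketch of that deduction for a hypothetical 2x2 mesh (the row-major rank layout is an assumption for illustration):

topology = [2, 2]          # hypothetical 2x2 process mesh; ranks laid out row-major
dims_mapping = [0, -1]     # the batch dimension is sharded along mesh axis 0
cur_rank = 3

batch_axis = dims_mapping[0]
if batch_axis > -1 and topology[batch_axis] > 1:
    dp_world_size = topology[batch_axis]
    coord = (cur_rank // topology[1], cur_rank % topology[1])   # (row, col) of rank 3
    dp_rank = coord[batch_axis]
else:
    dp_world_size, dp_rank = 1, 0
print(dp_world_size, dp_rank)   # 2 1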
...@@ -604,8 +694,9 @@ class Engine: ...@@ -604,8 +694,9 @@ class Engine:
# Parallelize program based on the planner's results # Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner, # For now, the completer has to be passed to the planner,
# because we may use it to complete the annotation of the backward and update. # because we may use it to complete the annotation of the backward and update.
parallelizer = Parallelizer(mode, self._planners[mode].completer, parallelizer = Parallelizer(
self._dist_contexts[mode]) mode, self._planners[mode].completer, self._dist_contexts[mode]
)
if not all_ranks: if not all_ranks:
parallelizer.parallel(self._cur_rank) parallelizer.parallel(self._cur_rank)
else: else:
...@@ -623,22 +714,30 @@ class Engine: ...@@ -623,22 +714,30 @@ class Engine:
for ib, block in enumerate(origin_main_prog.blocks): for ib, block in enumerate(origin_main_prog.blocks):
for iop, op in enumerate(block.ops): for iop, op in enumerate(block.ops):
ref_op = ref_blocks[ib].ops[iop] ref_op = ref_blocks[ib].ops[iop]
assert op.type == ref_op.type, \ assert (
"'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type) op.type == ref_op.type
ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program( ), "'{}' mode op '{}' is different with '{}' op '{}'. ".format(
ref_op) mode, op.type, ref_mode, ref_op.type
)
ref_op_dist_attr = (
ref_dist_context.get_op_dist_attr_for_program(ref_op)
)
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode): def _initialize(self, mode):
# Get the current content from the distributed context # Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[ self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program mode
].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[ self._serial_startup_progs[mode] = self._dist_contexts[
mode].serial_startup_program mode
].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[ self._dist_main_progs[mode] = self._dist_contexts[
mode].dist_main_programs mode
].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[ self._dist_startup_progs[mode] = self._dist_contexts[
mode].dist_startup_programs mode
].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode]._serial_optimizer self._optimizer = self._dist_contexts[mode]._serial_optimizer
...@@ -684,30 +783,33 @@ class Engine: ...@@ -684,30 +783,33 @@ class Engine:
self._executor.run(prune_startup_prog) self._executor.run(prune_startup_prog)
if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"): if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
self._set_state_dict(mode, self._strict, self._state_dict, self._set_state_dict(
self._dist_attr) mode, self._strict, self._state_dict, self._dist_attr
)
if self._strategy.reinit: if self._strategy.reinit:
self._logger.info("NOTE: parameters will be re-initialized.") self._logger.info("NOTE: parameters will be re-initialized.")
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
self._executor.run(dist_startup_prog) self._executor.run(dist_startup_prog)
def fit(self, def fit(
train_data, self,
train_sample_split=None, train_data,
batch_size=1, train_sample_split=None,
epochs=1, batch_size=1,
steps_per_epoch=None, epochs=1,
log_freq=10, steps_per_epoch=None,
save_dir=None, log_freq=10,
save_freq=1, save_dir=None,
valid_data=None, save_freq=1,
valid_sample_split=None, valid_data=None,
valid_freq=1, valid_sample_split=None,
valid_steps=None, valid_freq=1,
collate_fn=None, valid_steps=None,
callbacks=None, collate_fn=None,
verbose=2): callbacks=None,
verbose=2,
):
""" """
Trains the model for a fixed number of epochs. If `valid_data` is set, Trains the model for a fixed number of epochs. If `valid_data` is set,
evaluation will be done at the end of each epoch. evaluation will be done at the end of each epoch.
...@@ -776,17 +878,13 @@ class Engine: ...@@ -776,17 +878,13 @@ class Engine:
""" """
self._mode = 'train' self._mode = 'train'
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
train_data, train_sample_split, batch_size) train_data, train_sample_split, batch_size
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
self._switch_mode(self._mode) self._switch_mode(self._mode)
assert self._mode in self._dist_main_progs, \
"train model is not ready, please call `engine._prepare_program('train')` first."
train_dataloader = self._prepare_dataloader_from_generator( train_dataloader = self._prepare_dataloader_from_generator(
dataset=train_data, dataset=train_data,
capacity=70, capacity=70,
...@@ -794,7 +892,8 @@ class Engine: ...@@ -794,7 +892,8 @@ class Engine:
batch_size=batch_size, batch_size=batch_size,
epochs=epochs, epochs=epochs,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
collate_fn=collate_fn) collate_fn=collate_fn,
)
fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
...@@ -823,21 +922,35 @@ class Engine: ...@@ -823,21 +922,35 @@ class Engine:
self.main_program, self.main_program,
fetch_list=fetch_names, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy,
)
except core.EOFException: except core.EOFException:
break break
lr = get_lr(self._optimizer) lr = get_lr(self._optimizer)
logs = self._prepare_logger(outs, epoch, step, lr, fetch_names, logs = self._prepare_logger(
fetch_indices, self._mode) outs,
epoch,
step,
lr,
fetch_names,
fetch_indices,
self._mode,
)
cbks.on_batch_end('train', step, logs) cbks.on_batch_end('train', step, logs)
if valid_data and (epoch + 1) % valid_freq == 0: if valid_data and (epoch + 1) % valid_freq == 0:
val_logs = self.evaluate(valid_data, valid_sample_split, val_logs = self.evaluate(
batch_size, valid_steps, log_freq, valid_data,
collate_fn, callbacks, verbose) valid_sample_split,
batch_size,
valid_steps,
log_freq,
collate_fn,
callbacks,
verbose,
)
val_logs = { val_logs = {
"val_" + name: val "val_" + name: val for name, val in val_logs.items()
for name, val in val_logs.items()
} }
logs.update(val_logs) logs.update(val_logs)
self._switch_mode("train") self._switch_mode("train")
...@@ -849,15 +962,17 @@ class Engine: ...@@ -849,15 +962,17 @@ class Engine:
cbks.on_end('train', logs) cbks.on_end('train', logs)
return self.history return self.history
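For reference, a minimal end-to-end fit call consistent with the signature above, assuming a single-card run and the auto.Engine alias used in the constructor sketch earlier; the dataset, sizes, and hyper-parameters are arbitrary:

import numpy as np
import paddle
from paddle.io import Dataset
from paddle.distributed.fleet import auto   # assumed alias, as in the earlier sketch

class RandomDataset(Dataset):
    def __init__(self, num_samples=64):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        x = np.random.rand(16).astype("float32")
        y = np.random.randint(0, 2, [1]).astype("int64")
        return x, y            # (input, label) pair, split by train_sample_split

    def __len__(self):
        return self.num_samples

model = paddle.nn.Sequential(paddle.nn.Linear(16, 2))
opt = paddle.optimizer.Adam(parameters=model.parameters())
engine = auto.Engine(model, paddle.nn.CrossEntropyLoss(), opt)
history = engine.fit(RandomDataset(), batch_size=8, epochs=1, log_freq=10)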
def evaluate(self, def evaluate(
valid_data, self,
valid_sample_split=None, valid_data,
batch_size=1, valid_sample_split=None,
steps=None, batch_size=1,
log_freq=10, steps=None,
collate_fn=None, log_freq=10,
callbacks=None, collate_fn=None,
verbose=2): callbacks=None,
verbose=2,
):
""" """
Evaluate the loss and metrics of the model on evaluation data. Evaluate the loss and metrics of the model on evaluation data.
...@@ -906,23 +1021,21 @@ class Engine: ...@@ -906,23 +1021,21 @@ class Engine:
""" """
self._mode = 'eval' self._mode = 'eval'
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
valid_data, valid_sample_split, batch_size) valid_data, valid_sample_split, batch_size
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
self._switch_mode(self._mode) self._switch_mode(self._mode)
assert self._mode in self._dist_main_progs, \
"eval model is not ready, please call `engine._prepare_program('eval')` first."
valid_dataloader = self._prepare_dataloader_from_generator( valid_dataloader = self._prepare_dataloader_from_generator(
dataset=valid_data, dataset=valid_data,
capacity=70, capacity=70,
iterable=False, iterable=False,
batch_size=batch_size, batch_size=batch_size,
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn,
)
fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
...@@ -936,10 +1049,9 @@ class Engine: ...@@ -936,10 +1049,9 @@ class Engine:
) )
eval_steps = valid_dataloader._steps eval_steps = valid_dataloader._steps
cbks.on_begin('eval', { cbks.on_begin(
'steps': eval_steps, 'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
'metrics': self._metrics_name() )
})
logs = {} logs = {}
for step, _ in enumerate(valid_dataloader): for step, _ in enumerate(valid_dataloader):
cbks.on_batch_begin('eval', step, logs) cbks.on_batch_begin('eval', step, logs)
...@@ -948,24 +1060,28 @@ class Engine: ...@@ -948,24 +1060,28 @@ class Engine:
self.main_program, self.main_program,
fetch_list=fetch_names, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy,
)
except core.EOFException: except core.EOFException:
break break
logs = self._prepare_logger(outs, None, step, None, fetch_names, logs = self._prepare_logger(
fetch_indices, self._mode) outs, None, step, None, fetch_names, fetch_indices, self._mode
)
cbks.on_batch_end('eval', step, logs) cbks.on_batch_end('eval', step, logs)
cbks.on_end('eval', logs) cbks.on_end('eval', logs)
self._reset_metrics() self._reset_metrics()
return logs return logs
def predict(self, def predict(
test_data, self,
test_sample_split=None, test_data,
batch_size=1, test_sample_split=None,
steps=None, batch_size=1,
collate_fn=None, steps=None,
callbacks=None, collate_fn=None,
verbose=2): callbacks=None,
verbose=2,
):
""" """
Compute the output predictions on testing data. Compute the output predictions on testing data.
...@@ -1011,24 +1127,21 @@ class Engine: ...@@ -1011,24 +1127,21 @@ class Engine:
""" """
self._mode = 'predict' self._mode = 'predict'
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
test_data, test_sample_split, batch_size) test_data, test_sample_split, batch_size
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
self._switch_mode(self._mode) self._switch_mode(self._mode)
assert self._mode in self._dist_main_progs, \
"predict model is not ready, please call `engine._prepare_program('predict')` first."
test_dataloader = self._prepare_dataloader_from_generator( test_dataloader = self._prepare_dataloader_from_generator(
dataset=test_data, dataset=test_data,
capacity=70, capacity=70,
iterable=False, iterable=False,
batch_size=batch_size, batch_size=batch_size,
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn,
)
fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
...@@ -1044,41 +1157,45 @@ class Engine: ...@@ -1044,41 +1157,45 @@ class Engine:
self.main_program, self.main_program,
fetch_list=fetch_names, fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache, use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy,
)
except core.EOFException: except core.EOFException:
break break
logs = self._prepare_logger(outs, None, step, None, fetch_names, logs = self._prepare_logger(
fetch_indices, self._mode) outs, None, step, None, fetch_names, fetch_indices, self._mode
)
cbks.on_batch_end('predict', step, logs) cbks.on_batch_end('predict', step, logs)
outputs.append(list(logs["outputs"].values())) outputs.append(list(logs["outputs"].values()))
cbks.on_end('predict', logs) cbks.on_end('predict', logs)
return outputs return outputs
def dataloader(self, def dataloader(
dataset, self,
batch_size=1, dataset,
shuffle=False, batch_size=1,
drop_last=False, shuffle=False,
collate_fn=None, drop_last=False,
num_workers=0, collate_fn=None,
use_buffer_reader=True, num_workers=0,
use_shared_memory=True, use_buffer_reader=True,
timeout=0, use_shared_memory=True,
worker_init_fn=None, timeout=0,
epochs=1, worker_init_fn=None,
steps_per_epoch=None, epochs=1,
sample_split=1, steps_per_epoch=None,
mode=None): sample_split=1,
mode=None,
):
if mode is not None: if mode is not None:
self.to_mode(mode) self.to_mode(mode)
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size) dataset, sample_split, batch_size
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
self._switch_mode(self._mode) self._switch_mode(self._mode)
dataloader = self._prepare_dataloader( dataloader = self._prepare_dataloader(
dataset, dataset,
return_list=False, return_list=False,
...@@ -1092,32 +1209,35 @@ class Engine: ...@@ -1092,32 +1209,35 @@ class Engine:
timeout=timeout, timeout=timeout,
worker_init_fn=worker_init_fn, worker_init_fn=worker_init_fn,
epochs=epochs, epochs=epochs,
steps_per_epoch=steps_per_epoch) steps_per_epoch=steps_per_epoch,
)
return dataloader return dataloader
def dataloader_from_generator(self, def dataloader_from_generator(
dataset, self,
capacity=70, dataset,
use_double_buffer=True, capacity=70,
iterable=True, use_double_buffer=True,
use_multiprocess=False, iterable=True,
drop_last=True, use_multiprocess=False,
batch_size=1, drop_last=True,
epochs=1, batch_size=1,
steps_per_epoch=None, epochs=1,
collate_fn=None, steps_per_epoch=None,
sample_split=1, collate_fn=None,
mode=None): sample_split=1,
mode=None,
):
if mode is not None: if mode is not None:
self.to_mode(mode) self.to_mode(mode)
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size) dataset, sample_split, batch_size
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
self._switch_mode(self._mode) self._switch_mode(self._mode)
dataloader = self._prepare_dataloader_from_generator( dataloader = self._prepare_dataloader_from_generator(
dataset=dataset, dataset=dataset,
capacity=capacity, capacity=capacity,
...@@ -1129,95 +1249,114 @@ class Engine: ...@@ -1129,95 +1249,114 @@ class Engine:
batch_size=batch_size, batch_size=batch_size,
epochs=epochs, epochs=epochs,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
collate_fn=collate_fn) collate_fn=collate_fn,
)
return dataloader return dataloader
def prepare(self, def prepare(
inputs_spec=None, self,
labels_spec=None, inputs_spec=None,
inputs=None, labels_spec=None,
labels=None, inputs=None,
main_program=None, labels=None,
startup_program=None, main_program=None,
mode=None): startup_program=None,
mode=None,
):
if mode is not None: if mode is not None:
self.to_mode(mode) self.to_mode(mode)
if not self._mode:
raise ValueError(
"Please set mode to be prepared with `prepare(mode=...)`"
)
if self._has_prepared[self._mode]:
return
inputs_spec = self._validate_spec(inputs_spec)
labels_spec = self._validate_spec(labels_spec)
inputs = self._validate_vars(inputs)
labels = self._validate_vars(labels)
self._orig_main_prog = main_program
self._orig_startup_prog = startup_program
if inputs or labels: if inputs or labels:
self._skip_build = True self._skip_build = True
self._inputs_spec = inputs_spec inputs, labels = self._prepare_data_tensor(
self._labels_spec = labels_spec inputs_spec, labels_spec, inputs, labels
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec, inputs, labels)
self._orig_main_prog = main_program
if self._orig_main_prog is None: if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program() self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = startup_program
if self._orig_startup_prog is None: if self._orig_startup_prog is None:
self._orig_startup_prog = static.default_startup_program() self._orig_startup_prog = static.default_startup_program()
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
elif inputs_spec or labels_spec: elif inputs_spec or labels_spec:
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._outside_dataloader = True self._outside_dataloader = True
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
self._orig_main_prog = main_program
if self._orig_main_prog is None: if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program() self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = startup_program
if self._orig_startup_prog is None: if self._orig_startup_prog is None:
self._orig_startup_prog = static.default_startup_program() self._orig_startup_prog = static.default_startup_program()
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
else: else:
assert self._inputs_spec and self._labels_spec, \ assert (
"Please call the dataloader(...) before calling prepare(...)" self._inputs_spec and self._labels_spec
), "Please call the dataloader(...) before calling prepare(...)"
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = inputs, labels
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
def run(self, data=None, feed=None, fetch_list=None, mode=None): def run(self, data=None, feed=None, fetch_list=None, mode=None):
if mode is not None: if mode is not None:
self.to_mode(mode) self.to_mode(mode)
feed_dict = self._prepare_feed(data, feed, self._mode) feed_dict = self._prepare_feed(data, feed, self._mode)
fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode) fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)
if self._outside_dataloader and not self._has_prepared_reader[ if (
self._mode]: self._outside_dataloader
and not self._has_prepared_reader[self._mode]
):
self._prepare_reader() self._prepare_reader()
outs = self._executor.run(self.main_program, outs = self._executor.run(
feed=feed_dict, self.main_program,
fetch_list=fetch_names, feed=feed_dict,
use_program_cache=self._strategy.use_cache, fetch_list=fetch_names,
return_numpy=self._strategy.return_numpy) use_program_cache=self._strategy.use_cache,
logs = self._prepare_logger(outs, None, None, None, fetch_names, return_numpy=self._strategy.return_numpy,
fetch_indices, self._mode) )
logs = self._prepare_logger(
outs, None, None, None, fetch_names, fetch_indices, self._mode
)
return logs return logs
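prepare can now be driven purely by specs (or by user-built programs plus concrete variables), after which run feeds data against the prepared mode. A hedged sketch of the spec-only path; the feed key is assumed to follow the spec's name, and the auto.Engine alias is the same assumption as above:

import numpy as np
import paddle
from paddle.static import InputSpec
from paddle.distributed.fleet import auto   # assumed alias for this Engine

model = paddle.nn.Sequential(paddle.nn.Linear(16, 2))
engine = auto.Engine(model)                 # predict-only: no loss/optimizer required

engine.prepare(inputs_spec=[InputSpec([8, 16], "float32", "x")], mode="predict")
# The feed key "x" is assumed to match the spec name used to build the feed variable.
logs = engine.run(feed={"x": np.random.rand(8, 16).astype("float32")}, mode="predict")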
def _prepare_dataloader(self, def _prepare_dataloader(
dataset, self,
return_list=True, dataset,
batch_size=1, return_list=True,
shuffle=False, batch_size=1,
drop_last=False, shuffle=False,
collate_fn=None, drop_last=False,
num_workers=0, collate_fn=None,
use_buffer_reader=True, num_workers=0,
use_shared_memory=True, use_buffer_reader=True,
timeout=0, use_shared_memory=True,
worker_init_fn=None, timeout=0,
epochs=1, worker_init_fn=None,
steps_per_epoch=None): epochs=1,
steps_per_epoch=None,
):
if self._strategy.gradient_merge and batch_size is not None: if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \ assert (
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape. # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
...@@ -1256,31 +1395,36 @@ class Engine: ...@@ -1256,31 +1395,36 @@ class Engine:
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
split_data=self._strategy.split_data, split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes, data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks) data_parallel_rank=self._dp_ranks,
)
return dataloader return dataloader
def _prepare_dataloader_from_generator(self, def _prepare_dataloader_from_generator(
dataset, self,
capacity=None, dataset,
use_double_buffer=True, capacity=None,
iterable=True, use_double_buffer=True,
return_list=False, iterable=True,
use_multiprocess=False, return_list=False,
drop_last=True, use_multiprocess=False,
batch_size=1, drop_last=True,
epochs=1, batch_size=1,
steps_per_epoch=None, epochs=1,
collate_fn=None): steps_per_epoch=None,
collate_fn=None,
):
if self._strategy.gradient_merge and batch_size is not None: if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \ assert (
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape. # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
...@@ -1316,16 +1460,16 @@ class Engine: ...@@ -1316,16 +1460,16 @@ class Engine:
collate_fn=collate_fn, collate_fn=collate_fn,
split_data=self._strategy.split_data, split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes, data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks) data_parallel_rank=self._dp_ranks,
)
self._prepare_reader() self._prepare_reader()
return dataloader return dataloader
def _tune(self, tune_data, tune_sample_split=None, batch_size=1): def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self._mode = 'train' self._mode = 'train'
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
tune_data, tune_sample_split, batch_size) tune_data, tune_sample_split, batch_size
self._inputs, self._labels = self._prepare_data_tensor( )
self._inputs_spec, self._labels_spec)
self._optimization_tuning(self._mode, tune_data, batch_size) self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_spec(self, specs): def _validate_spec(self, specs):
...@@ -1333,46 +1477,39 @@ class Engine: ...@@ -1333,46 +1477,39 @@ class Engine:
self._k_steps = self._strategy.gradient_merge.k_steps self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None: if specs is not None:
for i, spec in enumerate(specs): for i, spec in enumerate(specs):
assert isinstance(spec, InputSpec) if not isinstance(spec, InputSpec):
raise TypeError(
"'spec' must be object of class `paddle.static.InputSpec`."
)
if spec.name is None: if spec.name is None:
raise ValueError( raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}." "Requires Input[{}].name != None, but receive `None` with {}.".format(
.format(i, spec)) i, spec
)
)
if self._k_steps > 1: if self._k_steps > 1:
shape = list(spec.shape) shape = list(spec.shape)
assert shape[0] % self._k_steps == 0, \ assert (
"Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps) shape[0] % self._k_steps == 0
), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
spec.shape[0], self._k_steps
)
shape[0] //= self._k_steps shape[0] //= self._k_steps
spec.shape = shape spec.shape = shape
return specs return specs or []
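Both _validate_spec and the two dataloader helpers divide the user-facing batch size by k_steps when gradient merge is enabled, so each of the k_steps micro-steps consumes batch_size // k_steps samples and the merged update still covers the full batch. A short worked example:

batch_size, k_steps = 32, 4
assert batch_size % k_steps == 0, "batch_size must be divisible by k_steps"
micro_batch = batch_size // k_steps
print(micro_batch, micro_batch * k_steps)   # 8 samples per micro-step, 32 per update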
def _validate_vars(self, vars):
vars = to_list(vars)
if vars is not None:
for i, var in enumerate(vars):
if not isinstance(var, Variable):
raise TypeError("'var' must be a `Variable`.")
return vars or []
def _is_local_var(self, var): def _is_local_var(self, var):
var_name = _to_name_str(var) var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars return var_name in self.main_program.global_block().vars
def _get_input_split_info(self, var, dist_context):
# deduce how the input data is split among the cluster
from .utils import _get_comm_group, _get_corresponding_rank
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if self._cur_rank not in process_mesh.processes:
rank_id = _get_corresponding_rank(dist_context, process_mesh,
self._cur_rank)
else:
rank_id = self._cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
return len(group_ranks), group_ranks.index(rank_id)
return 1, 0
def _set_recompute_ckpts(self): def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3 # NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here # TODO support more PaddleNLP/CV models here
...@@ -1381,10 +1518,12 @@ class Engine: ...@@ -1381,10 +1518,12 @@ class Engine:
# extract ckpts by specific model # extract ckpts by specific model
if isinstance(self._model, paddle.nn.Layer): if isinstance(self._model, paddle.nn.Layer):
if hasattr(self._model, if hasattr(
"gpt") and self._model.__class__.__name__ in [ self._model, "gpt"
'GPTForPretraining', 'GPTForPretrainingAuto' ) and self._model.__class__.__name__ in [
]: 'GPTForPretraining',
'GPTForPretrainingAuto',
]:
exact_ckpts = self._model.gpt.checkpoints exact_ckpts = self._model.gpt.checkpoints
else: else:
exact_ckpts = recompute.checkpoints exact_ckpts = recompute.checkpoints
...@@ -1396,16 +1535,10 @@ class Engine: ...@@ -1396,16 +1535,10 @@ class Engine:
recompute.checkpoints = exact_ckpts[:] recompute.checkpoints = exact_ckpts[:]
logs = { logs = {
'Model Class': self._model.__class__.__name__, 'Model Class': self._model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts 'Applied Recompute ckpts': exact_ckpts,
} }
self._logger.info(logs) self._logger.info(logs)
def _validate_opt(self, optimizer):
if optimizer is not None:
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
def _reset_metrics(self): def _reset_metrics(self):
for metric in self._metrics: for metric in self._metrics:
metric.reset() metric.reset()
...@@ -1417,12 +1550,18 @@ class Engine: ...@@ -1417,12 +1550,18 @@ class Engine:
return metrics_name return metrics_name
def _switch_mode(self, mode): def _switch_mode(self, mode):
assert (
mode in self._dist_main_progs
), "{} model is not ready, please call `prepare()` first.".format(mode)
self.to_mode(mode) self.to_mode(mode)
self._optimizer = self._dist_contexts[mode]._serial_optimizer self._optimizer = self._dist_contexts[mode]._serial_optimizer
def to_mode(self, mode): def to_mode(self, mode):
assert mode in ["train", "eval", "predict"], \ assert mode in [
"mode {} should be one of ['train', 'eval', 'predict']".format(mode) "train",
"eval",
"predict",
], "mode {} should be one of ['train', 'eval', 'predict']".format(mode)
self._mode = mode self._mode = mode
    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
@@ -1483,20 +1622,24 @@ class Engine:
            serial_program = self._serial_main_progs[self._mode]
            dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
            dist_context = self._dist_contexts[self._mode]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert "predict" in self._dist_main_progs
            feed_vars = self._feed_vars["predict"]['inputs']
            fetch_vars = self._fetch_vars["predict"]['outputs']
            dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )
    def load(self, path, strict=True, load_optimizer=True):
        """
@@ -1508,10 +1651,10 @@ class Engine:
            strict (bool, optional): Whether to skip loading mismatched
                parameters or to raise an error when a mismatch happens (the
                parameter is not found in the file storing model states, or
                the loaded shape does not match). Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.
        Returns:
            None
@@ -1546,10 +1689,11 @@ class Engine:
        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr
    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and print the cost, including the memory of every rank,
        the max memory among all ranks, and the global cost of one step based on
@@ -1560,7 +1704,7 @@ class Engine:
        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
            mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.
        Returns:
            Return the global execution time (ms) and max memory (B).
@@ -1568,29 +1712,44 @@ class Engine:
        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        if inputs_spec is not None and not self._has_prepared[mode]:
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        else:
            if _non_static_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost is to be estimated must be the static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                if (
                    not program.global_block().ops
                    or not program.global_block().ops
                ) and not self._has_prepared[mode]:
                    raise ValueError(
                        "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)
@@ -23,12 +23,18 @@ from functools import reduce
import paddle.fluid.core as core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.process_group import (
    get_all_process_groups,
)
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import (
    TensorDistributedAttribute,
    OperatorDistributedAttribute,
)

__not_shape_var_type__ = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
]

@@ -39,7 +45,8 @@ def get_logger(log_level, name="auto_parallel"):
    logger.setLevel(log_level)
    log_handler = logging.StreamHandler()
    log_format = logging.Formatter(
        '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
    )
    log_handler.setFormatter(log_format)
    logger.addHandler(log_handler)
    return logger
@@ -114,8 +121,11 @@ def verify_shard_spec(shard_spec, tensor_shape, process_mesh):
    if not verify_dims_mapping(dims_mapping, process_mesh):
        return False
    for i in range(len(tensor_shape)):
        if (
            dims_mapping[i] != -1
            and tensor_shape[i] > 0
            and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0
        ):
            return False
    return True
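# Standalone illustration of the divisibility rule enforced above: a tensor
# dimension that is mapped to a mesh axis must divide evenly by that axis's
# extent (dynamic dimensions, reported as <= 0, are skipped). Shapes below are
# made up for the example.
def _dim_is_shardable(dim_size, mesh_axis_size):
    return dim_size <= 0 or dim_size % mesh_axis_size == 0

assert _dim_is_shardable(1024, 4)      # 1024 rows split evenly over 4 ranks
assert not _dim_is_shardable(1022, 4)  # 1022 % 4 != 0, so this spec is rejected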
@@ -141,14 +151,17 @@ def compute_compatible_dims_mapping(dims_mapping_list):
        return None
    length = len(dims_mapping_list[0])
    for dims_mapping in dims_mapping_list:
        assert (
            dims_mapping is not None
        ), "Dims mapping must not be None for compatible computation"
        assert (
            len(dims_mapping) == length
        ), "The length of dims_mapping in list must be same for compatible computation."
    compatible_result = []
    for dim_mappings in zip(*dims_mapping_list):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            list(dim_mappings)
        )
        if compatible_dim_mapping is None:
            return None
        compatible_result.append(compatible_dim_mapping)
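# A rough sketch of the per-dimension rule the loop above relies on; the real
# compute_compatible_dim_mapping is defined elsewhere in this module, so the
# semantics assumed here (-1 means replicated and is compatible with anything,
# conflicting shardings are incompatible) are an assumption for illustration.
def _compatible_dim_sketch(dim_mappings):
    sharded = {d for d in dim_mappings if d != -1}
    if len(sharded) > 1:
        return None  # two different mesh axes claim the same tensor dimension
    return sharded.pop() if sharded else -1

assert _compatible_dim_sketch([-1, 0, -1]) == 0
assert _compatible_dim_sketch([1, 1]) == 1
assert _compatible_dim_sketch([0, 1]) is None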
@@ -161,7 +174,10 @@ def compute_compatible_process_mesh(process_mesh_list):
        return compatible_process_mesh
    for process_mesh in process_mesh_list:
        if process_mesh is not None:
            if (
                compatible_process_mesh is None
                or compatible_process_mesh == process_mesh
            ):
                compatible_process_mesh = process_mesh
            else:
                return None
@@ -201,15 +217,18 @@ def remove_distributed_attr_suffix(name):
def check_distributed_attr_for_program(program, dist_context=None):
    from .dist_context import get_default_distributed_context

    if dist_context is None:
        dist_context = get_default_distributed_context()
    assert (
        dist_context.is_initialized_for_program()
    ), "Distributed attributes must be initialized before check."
    for block in program.blocks:
        for tensor in block.vars.values():
            dist_tensor = dist_context.get_dist_tensor_for_graph(tensor)
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                tensor
            )
            if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()):
                return False
        for op in block.ops:
@@ -229,6 +248,7 @@ def print_program_with_dist_attr(program, dist_context=None):
    lock.acquire()
    from .dist_context import get_default_distributed_context
    from .dist_context import set_default_distributed_context

    if dist_context is None:
        dist_context = get_default_distributed_context()
    print(program, flush=True)
@@ -242,7 +262,7 @@ def print_program_with_dist_attr(program, dist_context=None):
def _get_comm_group(processes, shape, axis, rank):
    """
    Given a rank and the processes mesh the rank belongs to,
    compute the communication peers of the rank based on the given axis in the mesh.

    Example: 16 processes managed in a 4-Dimensional mesh with shape of [2, 2, 2, 2].
@@ -256,7 +276,8 @@ def _get_comm_group(processes, shape, axis, rank):
    # NOTE _linear_idx2coordinate assumes the processes mesh starts with 0 and is continuous
    # tricks to support processes mesh when it does not start with 0 or is not continuous
    assert rank in processes, "rank [{}] is NOT in processes group {}".format(
        rank, processes
    )
    rank_relatvie = processes.index(rank)
    coordinate = _linear_idx2coordinate(shape, rank_relatvie)
    coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
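# A standalone numpy sketch of the peer computation documented above: reshape
# the process list into the mesh, then take the slice that varies only along
# the requested axis. The [2, 2, 2, 2] mesh matches the docstring's example;
# the helper name is illustrative.
import numpy as np

def _comm_group_sketch(processes, shape, axis, rank):
    mesh = np.array(processes).reshape(shape)
    coord = [int(c) for c in np.argwhere(mesh == rank)[0]]
    coord[axis] = slice(None)  # peers differ only along `axis`
    return mesh[tuple(coord)].tolist()

# rank 5 sits at coordinate (0, 1, 0, 1) in the [2, 2, 2, 2] mesh:
assert _comm_group_sketch(list(range(16)), [2, 2, 2, 2], axis=3, rank=5) == [4, 5]
assert _comm_group_sketch(list(range(16)), [2, 2, 2, 2], axis=0, rank=5) == [5, 13]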
@@ -276,7 +297,7 @@ def _get_comm_group(processes, shape, axis, rank):
def _get_idx_in_axis(processes, shape, axis, rank):
    """
    Given a rank and the processes mesh the rank belongs to,
    compute the index of the rank in the given axis.

    Example: 27 processes managed in a 3-Dimensional mesh with shape of [3, 3, 3].
@@ -297,20 +318,20 @@ def _coordinate2linear_idx(mesh_shape, coordinate):
""" """
convert a coordinate in multidimensional mesh space into a scala idx in linear space. convert a coordinate in multidimensional mesh space into a scala idx in linear space.
it use Row-major order for dimension conversion. it use Row-major order for dimension conversion.
so it has: [most_significant_dim, ..., least_significant_dim] so it has: [most_significant_dim, ..., least_significant_dim]
assume: assume:
the size of i-th dimension to be: S[i] the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j] the index of j-th dimension is: I[j]
linear_idx of a n dimensional coordinate is: linear_idx of a n dimensional coordinate is:
I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) + I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) +
I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) + I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) +
I[n-3] * ( S[n-4] * .... S[0]) + I[n-3] * ( S[n-4] * .... S[0]) +
... ...
I[1] * ( S[0]) + I[1] * ( S[0]) +
I[0] I[0]
""" """
...@@ -325,14 +346,19 @@ def _coordinate2linear_idx(mesh_shape, coordinate): ...@@ -325,14 +346,19 @@ def _coordinate2linear_idx(mesh_shape, coordinate):
assert len(mesh_shape) == len( assert len(mesh_shape) == len(
coordinate coordinate
), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format( ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
mesh_shape, coordinate) mesh_shape, coordinate
)
for i in range(len(mesh_shape)): for i in range(len(mesh_shape)):
assert coordinate[ assert (
i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format( coordinate[i] >= 0
i, coordinate) ), "index in dimension [{}] is least than zero. coordinate: {}".format(
assert coordinate[i] < mesh_shape[ i, coordinate
i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format( )
i, mesh_shape, coordinate) assert (
coordinate[i] < mesh_shape[i]
), "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
i, mesh_shape, coordinate
)
base = mesh_shape[-1] base = mesh_shape[-1]
linear_idx = coordinate[-1] linear_idx = coordinate[-1]
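# The docstring above spells out the usual row-major (C-order) strides, with
# dimensions numbered from the least-significant end; numpy's
# ravel_multi_index computes the same mapping and serves as a cross-check.
import numpy as np

mesh_shape = [2, 2, 2, 2]
coordinate = [1, 0, 1, 1]

linear_idx = 0
for size, idx in zip(mesh_shape, coordinate):
    linear_idx = linear_idx * size + idx  # Horner form of the docstring formula

assert linear_idx == 11
assert linear_idx == int(np.ravel_multi_index(coordinate, mesh_shape))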
@@ -350,7 +376,7 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
    mapping a linear scalar into multidimensional mesh space, return its coordinate in that space.

    it is the inverse function of _coordinate2linear_idx.
    assume:
        the size of i-th dimension to be:  S[i]
        the index of j-th dimension is: I[j]
@@ -365,11 +391,13 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
    """
    assert linear_idx >= 0, "linear index [{}] is less than zero".format(
        linear_idx
    )
    assert linear_idx < np.prod(
        mesh_shape
    ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
        mesh_shape, linear_idx
    )

    base = 1
    coordinate = [-1] * len(mesh_shape)
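# The inverse mapping can likewise be cross-checked against numpy's
# unravel_index (C order): packing the recovered coordinate back must return
# the original linear index for every position in the mesh.
import numpy as np

mesh_shape = [2, 2, 2, 2]
for linear_idx in range(int(np.prod(mesh_shape))):
    coord = [int(c) for c in np.unravel_index(linear_idx, mesh_shape)]
    packed = 0
    for size, idx in zip(mesh_shape, coord):
        packed = packed * size + idx
    assert packed == linear_idx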
@@ -392,15 +420,17 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
    coordinate = None
    for mesh in dist_context.process_meshes:
        if rank in mesh.processes and mesh.topology == target_mesh.topology:
            coordinate = _linear_idx2coordinate(
                mesh.topology, mesh.processes.index(rank)
            )
            break
    # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
    #     rank)
    if coordinate is not None:
        return target_mesh.processes[
            _coordinate2linear_idx(mesh.topology, coordinate)
        ]
    else:
        return target_mesh.processes[0]
@@ -412,7 +442,8 @@ def _get_unshard_dist_shape(var, dist_attr):
    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
        var_shape, mapping
    )
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
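# Presumably the loop above keeps dynamic (-1) and replicated dimensions
# unchanged and multiplies a sharded dimension back by the extent of the mesh
# axis it was split over; a tiny sketch of that recovery under that assumption
# (helper name and shapes are illustrative):
def _unshard_shape_sketch(local_shape, dims_mapping, topology):
    full = []
    for dim, mapping in zip(local_shape, dims_mapping):
        if dim == -1 or mapping == -1:
            full.append(dim)                      # dynamic or replicated
        else:
            full.append(dim * topology[mapping])  # sharded: restore full extent
    return full

# A [8, 1024] local shard mapped as [0, -1] on a [2, 2] mesh came from [16, 1024]:
assert _unshard_shape_sketch([8, 1024], [0, -1], [2, 2]) == [16, 1024]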
@@ -425,13 +456,15 @@ def _get_unshard_dist_shape(var, dist_attr):
def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None):
    from .dist_context import get_default_distributed_context

    if dist_context is None:
        dist_context = get_default_distributed_context()
    for var in dist_main_prog.list_vars():
        if var.is_data:
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                var
            )
            inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
            var.desc.set_shape(inverse_shape)
            dim_mapping = tensor_dist_attr.dims_mapping
@@ -441,62 +474,76 @@ def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None):
def _update_addition_info(addition_info):
    """Update default addition_info with inputs"""
    add_info = {"epoch": 0, "batch": 0, "batch_size": 0}
    if not addition_info:
        return add_info
    elif not isinstance(addition_info, dict):
        raise TypeError(
            "The type of 'addition_info' should be 'dict', "
            "but got '{}'.".format(str(type(addition_info)))
        )
    else:
        for item, value in addition_info.items():
            if item not in ["epoch", "batch", "batch_size"]:
                raise ValueError(
                    "The key of 'addition_info' should be one of the "
                    "['epoch', 'batch', 'batch_size'], but got '{}'.".format(
                        str(item)
                    )
                )
            if not isinstance(value, int):
                raise ValueError(
                    "The value of 'addition_info' should be 'int', "
                    "but got '{}'.".format(str(type(value)))
                )
            add_info[item] = value
        return add_info

def _check_valid_path(file_path):
    """Validity check of input file path"""
    if not file_path:
        return file_path
    elif isinstance(file_path, list):
        for file in file_path:
            if not isinstance(file, str):
                raise TypeError(
                    "The type of file path should be 'str', "
                    "but got '{}'.".format(str(type(file)))
                )
            if not os.path.exists(file):
                raise ValueError(
                    "The file path '{}' does not exist.".format(file)
                )
        return file_path
    else:
        raise TypeError(
            "The type of file path should be 'list', "
            "but got '{}'.".format(str(type(file_path)))
        )
def _check_param_dict(param_dict):
    if not param_dict:
        raise ValueError("'param_dict' cannot be None.")
    elif not isinstance(param_dict, dict):
        raise TypeError(
            "The type of 'param_dict' should be 'dict', "
            "but got '{}'.".format(str(type(param_dict)))
        )
    else:
        for name, value in param_dict.items():
            if not isinstance(name, str):
                raise TypeError(
                    "The type of key of 'param_dict' should be 'str', "
                    "but got '{}'.".format(str(type(name)))
                )
            if not isinstance(value, paddle.fluid.LoDTensor):
                raise TypeError(
                    "The type of value of 'param_dict' should be 'LoDTensor', "
                    "but got '{}'.".format(str(type(value)))
                )
        return param_dict
@@ -504,35 +551,42 @@ def _check_dist_attr(dist_attr):
    if not dist_attr:
        return dist_attr
    elif not isinstance(dist_attr, dict):
        raise TypeError(
            "The type of 'dist_attr' should be 'dict', "
            "but got '{}'.".format(str(type(dist_attr)))
        )
    else:
        for name, value in dist_attr.items():
            if not isinstance(name, str):
                raise TypeError(
                    "The type of param name of 'dist_attr' should be 'str', "
                    "but got '{}'.".format(str(type(name)))
                )
            if not isinstance(value, dict):
                raise TypeError(
                    "The type of distributed attribute should be 'dict', "
                    "but got '{}'".format(str(type(value)))
                )
            attr = ['process_shape', 'process_group', 'dims_mapping']
            if list(value.keys()) != attr:
                raise ValueError(
                    "The key of distributed attribute should be "
                    "'['process_shape', 'process_group', 'dims_mapping']', "
                    "but got {}.".format(str(value.keys()))
                )
        return dist_attr
def save_distributed_checkpoint(
    program,
    checkpoint_path,
    dist_attr_path,
    addition_info=None,
    is_integrated=False,
    dist_context=None,
):
    """
    Save model parameter state, optimizer state, distributed attribute and
    additional information of each rank.

    Args:
@@ -569,11 +623,12 @@ def save_distributed_checkpoint(program,
    else:
        # TODO: integrate param before save
        raise NotImplementedError(
            "Integrating parameter has not been implemented."
        )
def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
    """
    Load parameter, optimizer, distributed attribute and addition_info.

    Args:
@@ -583,7 +638,7 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
    Returns:
        param_dict(dict): parameters' value of all ranks.
        dist_attr(dict): parameters' distributed attribute.
        addition_info(dict): additional information user saved in last training.

    Notes:
        The return, 'addition_info', belongs to the first file of checkpoint_path by default.
@@ -591,16 +646,16 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
    Examples:
        .. code-block:: python

            ckpt_path = ['./model_state_rank0.pdmodel',
                         './model_state_rank1.pdmodel']
            dist_attr_path = ['./dist_attr_rank0.pdattr',
                              './dist_attr_rank1.pdattr']
            param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path)
    """
    assert _check_valid_path(
        checkpoint_path
    ), "'checkpoint_path' cannot be None."
    assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None."
    state_dict_info = _load_distributed_state_dict(checkpoint_path)
    dist_attr = _load_distributed_attribute(dist_attr_path)
@@ -609,11 +664,10 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
    return param_dict, dist_attr, addition_info
def load_checkpoint_into_program(
    checkpoint_path, dist_attr_path, program, dist_context=None
):
    """
    Load parameter, optimizer, distributed attribute and addition_info into model.

    Args:
@@ -624,7 +678,7 @@ def load_checkpoint_into_program(checkpoint_path,
    Returns:
        addition_info(dict): user saved in last train.

    Notes:
        The return, 'addition_info', belongs to the first file of checkpoint_path by default.
@@ -632,19 +686,19 @@ def load_checkpoint_into_program(checkpoint_path,
        .. code-block:: python

            exe.run(startup_program)
            ckpt_path = ['./model_state_rank0.pdmodel',
                         './model_state_rank1.pdmodel']
            dist_attr_path = ['./dist_attr_rank0.pdattr',
                              './dist_attr_rank1.pdattr']
            load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program)
    """
    from .dist_context import get_default_distributed_context

    assert isinstance(program, paddle.fluid.framework.Program)
    assert _check_valid_path(
        checkpoint_path
    ), "'checkpoint_path' cannot be None."
    assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None."
    if dist_context is None:
        dist_context = get_default_distributed_context()
    all_state_dict_info = _load_distributed_state_dict(checkpoint_path)
@@ -652,16 +706,16 @@ def load_checkpoint_into_program(checkpoint_path,
    all_cur_dist_attr = get_dist_attr(program, dist_context)
    all_param_dict = all_state_dict_info["model"]
    addition_info = all_state_dict_info["addition_info"]
    sliced_param_dict = merge_and_slice_parameter(
        all_param_dict, all_pre_dist_attr, all_cur_dist_attr
    )
    load_parameter_into_program(sliced_param_dict, program)
    return addition_info
def load_parameter_into_program(param_dict, program):
    """
    Load parameters into program.

    Args:
@@ -676,28 +730,31 @@ def load_parameter_into_program(param_dict, program):
def _save_distributed_attribute(program, dist_attr_path, dist_context):
    """Save distributed attribute of all parameters"""
    # TODO: just save a complete distributed attribute file
    rank_id = paddle.distributed.get_rank()
    dist_attr_name = os.path.join(
        dist_attr_path, "dist_attr_rank{}.pdattr".format(rank_id)
    )
    dist_attr_dict = {
        "model": get_dist_attr(program, dist_context),
        "world_size": paddle.distributed.get_world_size(),
    }
    paddle.save(dist_attr_dict, dist_attr_name)
    logging.info(
        "Already saved distributed attribute to '{}'.".format(dist_attr_path)
    )

def _load_distributed_attribute(dist_attr_path):
    """Load parameters' distributed attribute from dist_attr_path"""
    total_dist_attr = {}
    for dist_attr_file in dist_attr_path:
        dist_attr = paddle.load(dist_attr_file)
        pre_world_size = dist_attr["world_size"]
        assert pre_world_size == len(
            dist_attr_path
        ), "The number of 'dist_attr_path' must be equal to the last training world size."
        for name, attr in dist_attr["model"].items():
            if name not in total_dist_attr:
                total_dist_attr[name] = attr
@@ -706,27 +763,29 @@ def _load_distributed_attribute(dist_attr_path):
def _save_distributed_state_dict(program, addition_info, checkpoint_path):
    """Save parameters' state_dict"""
    rank = paddle.distributed.get_rank()
    ckpt_file_name = os.path.join(
        checkpoint_path, "model_state_rank{}.pdmodel".format(rank)
    )
    state_dict = {
        "model": program.state_dict(),
        "world_size": paddle.distributed.get_world_size(),
        "addition_info": addition_info,
    }
    paddle.save(state_dict, ckpt_file_name)
    logging.info("Already saved model to '{}'.".format(checkpoint_path))

def _load_distributed_state_dict(checkpoint_path):
    """Load parameters' state_dict from checkpoint_path"""
    all_state_dict = {}
    for idx, ckpt_file in enumerate(checkpoint_path):
        state_dict_info = paddle.load(ckpt_file, return_numpy=True)
        pre_world_size = state_dict_info["world_size"]
        assert pre_world_size == len(
            checkpoint_path
        ), "The number of 'checkpoint_path' must be equal to the last training world size."
        if idx == 0:
            addition_info = state_dict_info["addition_info"]
        for name, value in state_dict_info["model"].items():
@@ -737,13 +796,13 @@ def _load_distributed_state_dict(checkpoint_path):
    all_state_dict_info = {
        "model": all_state_dict,
        "addition_info": addition_info,
    }
    return all_state_dict_info
def get_dist_attr(program, dist_context=None):
    """
    Get distributed attribute of current rank.

    Args:
@@ -758,13 +817,14 @@ def get_dist_attr(program, dist_context=None):
    for var in program.list_vars():
        if is_parameter(var) or is_belong_to_optimizer(var):
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                var
            )
            process_mesh = tensor_dist_attr.process_mesh
            dims_mapping = tensor_dist_attr.dims_mapping
            dist_attr[var.name] = {
                "process_shape": process_mesh.topology,
                "process_group": process_mesh.processes,
                "dims_mapping": dims_mapping,
            }
    return dist_attr
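# The per-parameter record built above therefore has exactly the three keys
# that _check_dist_attr expects, in this order; the parameter name and values
# below are illustrative.
example_dist_attr = {
    "linear_0.w_0": {
        "process_shape": [2, 2],        # mesh topology
        "process_group": [0, 1, 2, 3],  # ranks in the mesh
        "dims_mapping": [-1, 0],        # dim 1 sharded over mesh axis 0
    }
}
assert list(example_dist_attr["linear_0.w_0"].keys()) == [
    "process_shape",
    "process_group",
    "dims_mapping",
]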
@@ -782,19 +842,26 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
        dist_param_dict(dict): parameters' value of current rank.
    """
    assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None."
    assert isinstance(
        dist_param_dict, dict
    ), "The type of 'dist_param_dict' should be 'dict', but got {}.".format(
        str(type(dist_param_dict))
    )
    for name, value in dist_param_dict.items():
        if not isinstance(name, str):
            raise TypeError(
                "The key of 'dist_param_dict' is parameter's name, "
                "and its type should be 'str', but got {}.".format(
                    str(type(name))
                )
            )
        if not isinstance(value, list) or not all(
            isinstance(v, np.ndarray) for v in value
        ):
            raise TypeError(
                "The value of 'dist_param_dict' is parameter's value of all ranks, "
                "and its type should be 'list(numpy.ndarray)'."
            )
    if cur_dist_attr is None:
        return {}
@@ -822,7 +889,8 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
            cur_dims_mapping = cur_attr["dims_mapping"]
            if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping:
                complete_param = _merge_parameter_with_dist_attr(
                    pre_param, pre_attr
                )
                dist_param_dict[var_name] = complete_param
            else:
                complete_param = pre_param[0]
@@ -830,7 +898,8 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
            if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping:
                sliced_param = _slice_parameter_with_dist_attr(
                    complete_param, cur_attr
                )
                dist_param_dict[var_name] = sliced_param
    for var_name in pre_dist_attr:
@@ -841,67 +910,81 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
    if param_not_in_pre:
        warnings.warn(
            "Parameters '{}' are not found in last training process.".format(
                str(param_not_in_pre)
            )
        )
    if param_not_in_cur:
        warnings.warn(
            "Parameters '{}' are not found in current training process.".format(
                str(param_not_in_cur)
            )
        )
    return dist_param_dict
def _merge_parameter_with_dist_attr(param_list, dist_attr):
    """Merge parameter with distributed attribute"""
    from .reshard import Resharder

    dims_mapping = dist_attr["dims_mapping"]
    process_shape = dist_attr["process_shape"]
    process_group = dist_attr["process_group"]
    # get the complete shape of the parameter
    complete_shape = Resharder.compute_complete_shape(
        param_list[0].shape, process_shape, dims_mapping
    )
    # merge the parameter with dist_attr
    partition_param_list = []
    merged_partiton = []
    for process in process_group:
        partition_index = Resharder.compute_partition_index(
            process, complete_shape, dims_mapping, process_shape, process_group
        )
        index = process_group.index(process)
        if partition_index not in merged_partiton:
            merged_partiton.append(partition_index)
            _merge_parameter(
                partition_param_list,
                param_list[index],
                partition_index,
                complete_shape,
            )
    assert (
        len(partition_param_list) == 1 or not partition_param_list
    ), "Fail to merge parameter"
    complete_param = partition_param_list[0][0]
    return complete_param
def _slice_parameter_with_dist_attr(param, dist_attr):
    """Slice parameter with distributed attribute"""
    param = (
        np.array(param) if isinstance(param, paddle.fluid.LoDTensor) else param
    )
    dims_mapping = dist_attr["dims_mapping"]
    process_shape = dist_attr["process_shape"]
    process_group = dist_attr["process_group"]
    # slice the parameter with dist_attr
    partition_index_list = _get_split_indices(
        param.shape, dims_mapping, process_shape, process_group
    )
    sliced_param_list = _slice_parameter(
        param, partition_index_list, len(partition_index_list)
    )
    # get the current parameter's index in sliced_param_list
    rank_id = paddle.distributed.get_rank()
    sliced_param_index = _get_sliced_param_index(
        rank_id, param.shape, dims_mapping, process_shape, process_group
    )
    sliced_param = sliced_param_list[sliced_param_index]
    return sliced_param
def _merge_parameter(
    partition_param_list, param, partition_index, complete_shape
):
    """
    Merge partial parameters to a complete one.

@@ -935,19 +1018,30 @@ def _merge_parameter(partition_param_list, param, partition_index,
    else:
        i = 0
        while i < len(partition_param_list):
            (
                concat_axis,
                first_order,
                new_partition,
            ) = Resharder.compute_concat_info(
                partition_param_list[i][1], partition_index
            )
            if concat_axis != -1:
                if first_order == 0:
                    new_param = np.concatenate(
                        (partition_param_list[i][0], param), axis=concat_axis
                    )
                else:
                    new_param = np.concatenate(
                        (param, partition_param_list[i][0]), axis=concat_axis
                    )
                partition_param_list.pop(i)
                _merge_parameter(
                    partition_param_list,
                    new_param,
                    new_partition,
                    complete_shape,
                )
                break
            i += 1
@@ -975,19 +1069,21 @@ def _slice_parameter(complete_param, partition_index_list, length):
    """
    sliced_param_list = []
    axis = len(complete_param.shape) - length
    sliced_param = np.split(
        complete_param, partition_index_list[axis], axis=axis
    )
    if length == 1:
        return sliced_param
    for param in sliced_param:
        sliced_param_list.extend(
            _slice_parameter(param, partition_index_list, length - 1)
        )
    return sliced_param_list
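# A self-contained run of the same recursive np.split scheme on a 4x4 array,
# splitting both axes at index 2; it yields the four 2x2 blocks in row-major
# block order (the helper name is illustrative).
import numpy as np

def _slice_sketch(complete_param, partition_index_list, length):
    axis = len(complete_param.shape) - length
    pieces = np.split(complete_param, partition_index_list[axis], axis=axis)
    if length == 1:
        return pieces
    out = []
    for piece in pieces:
        out.extend(_slice_sketch(piece, partition_index_list, length - 1))
    return out

blocks = _slice_sketch(np.arange(16).reshape(4, 4), [[2], [2]], 2)
assert [b.shape for b in blocks] == [(2, 2)] * 4
assert blocks[0].tolist() == [[0, 1], [4, 5]]      # top-left block
assert blocks[3].tolist() == [[10, 11], [14, 15]]  # bottom-right block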
def _get_sliced_param_index(
    rank, complete_shape, dims_mapping, process_shape, process_group
):
    """
    Get sliced_param's index of current rank in all sliced parameters list.

@@ -1006,7 +1102,7 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
            process_group = [0, 1, 2]
            slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3)
            # slice_param:
            # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])]
            index = _get_sliced_param_index(rank, complete_shape, dims_mapping
@@ -1015,10 +1111,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
    """
    from .reshard import Resharder

    partition_index = Resharder.compute_partition_index(
        rank, complete_shape, dims_mapping, process_shape, process_group
    )
    sliced_param_index = 0
    for i, shape in enumerate(complete_shape):
        if dims_mapping[i] == -1:
@@ -1033,8 +1128,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
    return sliced_param_index
def _get_split_indices(
    complete_shape, dims_mapping, process_shape, process_group
):
    """
    Get split indices of every dimension.

@@ -1059,15 +1155,20 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
    split_indices_list = []
    for process in process_group:
        partition_index = Resharder.compute_partition_index(
            process, complete_shape, dims_mapping, process_shape, process_group
        )
        if split_indices_list:
            for dim in range(len(partition_index)):
                split_indices_list[dim].extend(partition_index[dim])
        else:
            split_indices_list = partition_index
    split_indices_list = list(
        map(
            lambda x, y: list(set(x) - set([y]) - set([0])),
            split_indices_list,
            complete_shape,
        )
    )
    split_indices_list = [sorted(x) for x in split_indices_list]
    return split_indices_list
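# Assuming Resharder.compute_partition_index returns per-dimension
# [start, end] ranges for each rank's shard, the final map/sort step above
# turns the collected endpoints into interior split points by dropping 0 and
# each dimension's full extent. A sketch of just that step:
def _split_indices_sketch(per_rank_partition_indices, complete_shape):
    collected = [[] for _ in complete_shape]
    for partition_index in per_rank_partition_indices:
        for dim, endpoints in enumerate(partition_index):
            collected[dim].extend(endpoints)
    return [
        sorted(set(endpoints) - {extent} - {0})
        for endpoints, extent in zip(collected, complete_shape)
    ]

# Two ranks sharding dim 0 of a [4, 4] tensor in halves: split rows at 2 only.
assert _split_indices_sketch(
    [[[0, 2], [0, 4]], [[2, 4], [0, 4]]], [4, 4]
) == [[2], []]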
...@@ -1086,8 +1187,10 @@ def set_grad_var_shape(program, dist_context): ...@@ -1086,8 +1187,10 @@ def set_grad_var_shape(program, dist_context):
if int(op.attr('op_role')) != int(OpRole.Backward): if int(op.attr('op_role')) != int(OpRole.Backward):
continue continue
if int(block.ops[idx-1].attr('op_role')) == int(OpRole.Forward) or \ if (
int(block.ops[idx-1].attr('op_role')) == 257: int(block.ops[idx - 1].attr('op_role')) == int(OpRole.Forward)
or int(block.ops[idx - 1].attr('op_role')) == 257
):
appended_grad_times += 1 appended_grad_times += 1
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
...@@ -1105,61 +1208,102 @@ def set_grad_var_shape(program, dist_context): ...@@ -1105,61 +1208,102 @@ def set_grad_var_shape(program, dist_context):
continue continue
if var_name in grad_var_to_var[appended_grad_times]: if var_name in grad_var_to_var[appended_grad_times]:
forward_var_name = grad_var_to_var[appended_grad_times][ forward_var_name = grad_var_to_var[appended_grad_times][
var_name] var_name
]
else: else:
forward_var_name = var_name[:var_name.find("@GRAD")] forward_var_name = var_name[: var_name.find("@GRAD")]
if op.type in [ if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast", "c_allreduce_sum",
"fill_any_like" "c_identity",
"scale",
"cast",
"fill_any_like",
]: ]:
forward_var_name = op.input_arg_names[0] forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": elif (
op.type == "matmul_v2_grad"
or op.type == "matmul_grad"
or op.type == "mul_grad"
):
forward_var_name = None forward_var_name = None
for output_name in op.output_names: for output_name in op.output_names:
if var_name in op.output(output_name): if var_name in op.output(output_name):
assert "@GRAD" in output_name assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")] input_name = output_name[: output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1 assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0] forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None assert forward_var_name is not None
need_set_shape_list = [ need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad", "reshape2_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2", "softmax_with_cross_entropy_grad",
"dropout_grad", "tanh_grad", "slice", "assign", "transpose2_grad",
"matmul_v2_triple_grad", "elementwise_add_triple_grad", "softmax_grad",
"fill_constant", "sqrt_grad", "cross_entropy_grad2",
"dropout_grad",
"tanh_grad",
"slice",
"assign",
"matmul_v2_triple_grad",
"elementwise_add_triple_grad",
"fill_constant",
"sqrt_grad",
"fused_softmax_mask_upper_triangle_grad", "fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad", "relu_grad" "flatten_contiguous_range_grad",
"relu_grad",
] ]
forward_list = [ forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2", "reshape2",
"softmax", "cross_entropy2", "dropout", "tanh", "softmax_with_cross_entropy",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", "transpose2",
"elementwise_add_grad_grad", "shape", "sqrt", "softmax",
"fused_softmax_mask_upper_triangle", "flatten_contiguous_range", "cross_entropy2",
"relu" "dropout",
"tanh",
["slice_grad", "c_allgather"],
"assign",
"matmul_v2_grad_grad",
"elementwise_add_grad_grad",
"shape",
"sqrt",
"fused_softmax_mask_upper_triangle",
"flatten_contiguous_range",
"relu",
] ]
if op.type in need_set_shape_list: if op.type in need_set_shape_list:
for forward_op in block.ops: for forward_op in block.ops:
idx = need_set_shape_list.index(op.type) idx = need_set_shape_list.index(op.type)
forward_op_name = forward_list[idx] forward_op_name = forward_list[idx]
if forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names: if (
op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op.type in forward_op_name
forward_op) and forward_var_name in forward_op.input_arg_names
):
op_dist_attr = (
dist_context.get_op_dist_attr_for_program(
forward_op
)
)
break break
forward_input_dist_attr = op_dist_attr.get_input_dist_attr( forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name) forward_var_name
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}" )
assert (
forward_input_dist_attr is not None
), f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name] forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( forward_var_dist_attr = (
forward_var) dist_context.get_tensor_dist_attr_for_program(forward_var)
)
assert forward_var_dist_attr is not None assert forward_var_dist_attr is not None
grad_var = vars[var_name] grad_var = vars[var_name]
ref_shape = infer_shape(block, forward_var, forward_var_dist_attr, ref_shape = infer_shape(
forward_input_dist_attr) block,
forward_var,
forward_var_dist_attr,
forward_input_dist_attr,
)
if list(grad_var.shape) != ref_shape: if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape) grad_var.desc.set_shape(ref_shape)
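The block above keeps need_set_shape_list and forward_list index-aligned: the grad op type at position i maps to the forward op type(s) at position i, and the grad var's shape is then reset from the shape inferred for the forward var. A minimal sketch of that paired lookup, with truncated copies of the lists used purely for illustration (not the Paddle implementation):

# Illustrative sketch only; the lists are shortened copies of the ones above.
need_set_shape_list = ["reshape2_grad", "dropout_grad", "slice"]
forward_list = ["reshape2", "dropout", ["slice_grad", "c_allgather"]]

def forward_type_for(grad_op_type):
    # an entry may be a single op type or a list of accepted types, which is
    # why the pass checks `forward_op.type in forward_op_name`
    idx = need_set_shape_list.index(grad_op_type)
    return forward_list[idx]

print(forward_type_for("dropout_grad"))  # dropout
print(forward_type_for("slice"))         # ['slice_grad', 'c_allgather']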
...@@ -1171,28 +1315,33 @@ OpRole = core.op_proto_and_checker_maker.OpRole ...@@ -1171,28 +1315,33 @@ OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op): def is_forward_op(op):
op_role = int(op.attr('op_role')) op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward) return OP_ROLE_KEY in op.attr_names and (
or op_role == int(OpRole.Loss)) op_role == int(OpRole.Forward) or op_role == int(OpRole.Loss)
)
def is_backward_op(op): def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and \ return OP_ROLE_KEY in op.attr_names and int(
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) op.all_attrs()[OP_ROLE_KEY]
) & int(OpRole.Backward)
def is_optimize_op(op): def is_optimize_op(op):
return OP_ROLE_KEY in op.attr_names and \ return OP_ROLE_KEY in op.attr_names and int(
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) op.all_attrs()[OP_ROLE_KEY]
) & int(OpRole.Optimize)
def is_lr_sched_op(op): def is_lr_sched_op(op):
return OP_ROLE_KEY in op.attr_names and \ return OP_ROLE_KEY in op.attr_names and int(
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched) op.all_attrs()[OP_ROLE_KEY]
) & int(OpRole.Optimize.LRSched)
def is_loss_op(op): def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \ return OP_ROLE_KEY in op.attr_names and int(
int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss)) op.all_attrs()[OP_ROLE_KEY]
) == (int(OpRole.Forward) | int(OpRole.Loss))
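These helpers treat op_role as a bit field: forward and loss are compared by equality, while the backward, optimize and LR-sched roles are tested with a bitwise AND so that combined roles still match. A small sketch with assumed flag values (the real values come from OpRole; 257 below matches the `op_role == 257` check used later for the loss-grad op):

# Sketch with assumed bit values mirroring the checks above; not Paddle API.
FORWARD, BACKWARD, OPTIMIZE, LOSS = 0x0000, 0x0001, 0x0002, 0x0100

loss_role = FORWARD | LOSS        # 256: what the is_loss_op check compares against
loss_grad_role = BACKWARD | LOSS  # 257: carries the Backward bit plus Loss

assert loss_role == (FORWARD | LOSS)    # equality check, like is_loss_op
assert loss_grad_role & BACKWARD        # bitwise check, like is_backward_op
assert not (loss_grad_role & OPTIMIZE)  # and it is not an optimizer op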
def is_loss_grad_op(op): def is_loss_grad_op(op):
...@@ -1203,8 +1352,9 @@ def is_loss_grad_op(op): ...@@ -1203,8 +1352,9 @@ def is_loss_grad_op(op):
def is_gradient_clip_op(op): def is_gradient_clip_op(op):
return op.desc.has_attr("op_namescope") \ return op.desc.has_attr("op_namescope") and op.desc.attr(
and op.desc.attr("op_namescope").startswith("/gradient_clip") "op_namescope"
).startswith("/gradient_clip")
def is_prim_op(op): def is_prim_op(op):
...@@ -1215,8 +1365,9 @@ def get_loss_op(block): ...@@ -1215,8 +1365,9 @@ def get_loss_op(block):
loss_ops = [] loss_ops = []
for op in block.ops: for op in block.ops:
if is_loss_op(op): if is_loss_op(op):
assert len(op.desc.output_arg_names() assert (
) == 1, "loss op should only output loss var" len(op.desc.output_arg_names()) == 1
), "loss op should only output loss var"
loss_ops.append(op) loss_ops.append(op)
assert len(loss_ops) == 1, "num of loss op is not equal to one" assert len(loss_ops) == 1, "num of loss op is not equal to one"
...@@ -1236,7 +1387,8 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): ...@@ -1236,7 +1387,8 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, process_mesh, ref_mapping, ctx): new_op, process_mesh, ref_mapping, ctx
):
assert process_mesh is not None assert process_mesh is not None
assert ref_mapping is not None assert ref_mapping is not None
...@@ -1270,9 +1422,11 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op): ...@@ -1270,9 +1422,11 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1: if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]): for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \ assert (
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ mapping == -1
.format(op_desc.type(), idx, mapping) ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[0]) batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
...@@ -1282,23 +1436,31 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op): ...@@ -1282,23 +1436,31 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
if arg_name not in xshape_arg_names: if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1: if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]): for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \ assert (
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ mapping == -1
.format(op_desc.type(), idx, mapping) ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[0]) batch_dim_mappings.append(dims_mapping[0])
else: else:
assert dims_mapping[0] == -1, \ assert (
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ dims_mapping[0] == -1
.format(op_desc.type(), mapping) ), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part.".format(
op_desc.type(), mapping
)
if len(dims_mapping) > 2: if len(dims_mapping) > 2:
for idx, mapping in enumerate(dims_mapping[2:]): for idx, mapping in enumerate(dims_mapping[2:]):
assert mapping == -1, \ assert (
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ mapping == -1
.format(op_desc.type(), idx, mapping) ), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[1]) batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping." assert (
compatible_dim_mapping is not None
), "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name) serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter: if serial_tensor.is_parameter:
...@@ -1344,8 +1506,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): ...@@ -1344,8 +1506,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]): for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len - new_idx = (
input_dims_mapping_lens[arg_name]) + i max_dims_mapping_len - input_dims_mapping_lens[arg_name]
) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
dims_mapping_list.append(new_dims_mapping) dims_mapping_list.append(new_dims_mapping)
else: else:
...@@ -1357,7 +1520,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): ...@@ -1357,7 +1520,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
dims_mapping_list.append(dims_mapping) dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping." assert (
compatible_dims_mapping is not None
), "There is no compatible dim mapping."
for arg_name in input_arg_names: for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
...@@ -1365,55 +1530,64 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): ...@@ -1365,55 +1530,64 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
-1 for _ in range(input_dims_mapping_lens[arg_name]) -1 for _ in range(input_dims_mapping_lens[arg_name])
] ]
for i in range(input_dims_mapping_lens[arg_name]): for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len - new_idx = (
input_dims_mapping_lens[arg_name]) + i max_dims_mapping_len - input_dims_mapping_lens[arg_name]
) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx] new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]: if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
changed = True changed = True
else: else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, op_dist_attr.set_input_dims_mapping(
compatible_dims_mapping) arg_name, compatible_dims_mapping
)
changed = True changed = True
for arg_name in output_arg_names: for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping: if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name, op_dist_attr.set_output_dims_mapping(
compatible_dims_mapping) arg_name, compatible_dims_mapping
)
changed = True changed = True
return changed return changed
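The elementwise-like rule above right-aligns each input's dims_mapping against the longest one, padding the missing leading dimensions with -1 (replicated), before a compatible mapping is computed and written back. A minimal sketch of that alignment step, with made-up mappings:

# Illustrative sketch only; -1 means "replicated", other values are mesh axes.
def right_align(dims_mapping, max_len):
    return [-1] * (max_len - len(dims_mapping)) + list(dims_mapping)

print(right_align([0, -1], 3))  # [-1, 0, -1]
print(right_align([-1], 3))     # [-1, -1, -1]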
def get_all_distributed_main_program(serial_program_info, dist_context, def get_all_distributed_main_program(
parallelizer): serial_program_info, dist_context, parallelizer
):
"Get all distributed main programs by dist_context." "Get all distributed main programs by dist_context."
from .dist_context import DistributedOperatorContext, DistributedContext from .dist_context import DistributedOperatorContext, DistributedContext
cluster = serial_program_info.cluster cluster = serial_program_info.cluster
copied_parallelizer = copy.deepcopy(parallelizer) copied_parallelizer = copy.deepcopy(parallelizer)
all_dist_main_program = [] all_dist_main_program = []
ranks = paddle.distributed.get_world_size() if cluster is None else len( ranks = (
cluster.get_all_devices("GPU")) paddle.distributed.get_world_size()
if cluster is None
else len(cluster.get_all_devices("GPU"))
)
for rank_id in range(ranks): for rank_id in range(ranks):
used_dist_context = copy.deepcopy(dist_context) used_dist_context = copy.deepcopy(dist_context)
used_dist_context._dist_op_context = DistributedOperatorContext() used_dist_context._dist_op_context = DistributedOperatorContext()
_, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program( (
rank_id, used_dist_context) _,
_,
dist_startup_program,
dist_main_program,
_,
) = copied_parallelizer._get_dist_program(rank_id, used_dist_context)
all_dist_main_program.append(dist_main_program) all_dist_main_program.append(dist_main_program)
return all_dist_main_program return all_dist_main_program
class SerialProgramInfo: class SerialProgramInfo:
def __init__(
def __init__(self, self, train_program, satrtup_program, loss, optimizer, cluster=None
train_program, ):
satrtup_program,
loss,
optimizer,
cluster=None):
self._train_program = train_program self._train_program = train_program
self._startup_program = satrtup_program self._startup_program = satrtup_program
self._loss = loss self._loss = loss
...@@ -1442,7 +1616,6 @@ class SerialProgramInfo: ...@@ -1442,7 +1616,6 @@ class SerialProgramInfo:
def get_standalone_cost_data(distributed_programs): def get_standalone_cost_data(distributed_programs):
def _compute_runtime(op_cost, op, vars): def _compute_runtime(op_cost, op, vars):
runtime = 0 runtime = 0
try: try:
...@@ -1455,32 +1628,47 @@ def get_standalone_cost_data(distributed_programs): ...@@ -1455,32 +1628,47 @@ def get_standalone_cost_data(distributed_programs):
parsed_info = op_config.split("\n") parsed_info = op_config.split("\n")
variable = "(Variable)" variable = "(Variable)"
for info in parsed_info: for info in parsed_info:
variable = "(Variable)" if "(Variable)" in info else "(list<Variable>" variable = (
"(Variable)" if "(Variable)" in info else "(list<Variable>"
)
if variable in info: if variable in info:
arg_name_lower = info[:info.find(variable) - 1] arg_name_lower = info[: info.find(variable) - 1]
shape_left_boundary = info.find("[") shape_left_boundary = info.find("[")
shape_right_boundary = info.find("]") shape_right_boundary = info.find("]")
assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed." assert (
shape = info[shape_left_boundary + shape_left_boundary > 0
1:shape_right_boundary].split(",") and shape_right_boundary > 0
and shape_right_boundary > shape_left_boundary
), "Get shape failed."
shape = info[
shape_left_boundary + 1 : shape_right_boundary
].split(",")
shape = list(map(lambda x: int(x.strip()), shape)) shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1 dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape) total_static_input_size += reduce(lambda x, y: x * y, shape)
if op.type == "c_embedding": if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids" arg_name_lower = (
"w" if arg_name_lower == "weight" else "ids"
)
for arg_name in op.input_names: for arg_name in op.input_names:
if arg_name.lower() == arg_name_lower: if arg_name.lower() == arg_name_lower:
for var_name in op.input(arg_name): for var_name in op.input(arg_name):
var = vars[var_name] var = vars[var_name]
total_actual_input_size += reduce( total_actual_input_size += reduce(
lambda x, y: x * y, var.shape) lambda x, y: x * y, var.shape
)
break break
assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed." assert (
total_static_input_size > 0 and total_actual_input_size > 0
), "Get input size failed."
actual_runtime = total_actual_input_size / total_static_input_size * runtime actual_runtime = (
total_actual_input_size / total_static_input_size * runtime
)
return actual_runtime return actual_runtime
import paddle.cost_model as cm import paddle.cost_model as cm
cost_model = cm.CostModel() cost_model = cm.CostModel()
cost_model.static_cost_data() cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2 DEFAULT_MULTIPLE = 2
...@@ -1491,13 +1679,16 @@ def get_standalone_cost_data(distributed_programs): ...@@ -1491,13 +1679,16 @@ def get_standalone_cost_data(distributed_programs):
"reshape2": "reshape", "reshape2": "reshape",
"unsqueeze2": "unsqueeze", "unsqueeze2": "unsqueeze",
"reduce_sum": "sum", "reduce_sum": "sum",
"elementwise_div": "divide" "elementwise_div": "divide",
} }
standalone_cost_data = [] standalone_cost_data = []
# skip ops # skip ops
not_enum_ops = [ not_enum_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "assign" "create_py_reader",
"create_double_buffer_reader",
"read",
"assign",
] ]
for distributed_program in distributed_programs: for distributed_program in distributed_programs:
cost_data = {} cost_data = {}
...@@ -1507,26 +1698,33 @@ def get_standalone_cost_data(distributed_programs): ...@@ -1507,26 +1698,33 @@ def get_standalone_cost_data(distributed_programs):
if op.type in not_enum_ops: if op.type in not_enum_ops:
cost_data[op.desc.id()] = runtime cost_data[op.desc.id()] = runtime
continue continue
dtype = str(vars[op.input_arg_names[0]].dtype dtype = (
) if op.input_arg_names else "float32" str(vars[op.input_arg_names[0]].dtype)
if op.input_arg_names
else "float32"
)
if int(op.attr('op_role')) == int(OpRole.Backward): if int(op.attr('op_role')) == int(OpRole.Backward):
if "_grad" in op.type: if "_grad" in op.type:
forward_op_name = op.type[:-5] forward_op_name = op.type[:-5]
if forward_op_name in OP_NAME_MAPPING.keys(): if forward_op_name in OP_NAME_MAPPING.keys():
forward_op_name = OP_NAME_MAPPING[forward_op_name] forward_op_name = OP_NAME_MAPPING[forward_op_name]
op_cost = cost_model.get_static_op_time(forward_op_name, op_cost = cost_model.get_static_op_time(
forward=False, forward_op_name, forward=False, dtype=dtype
dtype=dtype) )
if op_cost: if op_cost:
runtime = _compute_runtime(op_cost, op, vars) runtime = _compute_runtime(op_cost, op, vars)
else: else:
op_cost = cost_model.get_static_op_time(forward_op_name, op_cost = cost_model.get_static_op_time(
dtype=dtype) forward_op_name, dtype=dtype
)
if op_cost: if op_cost:
runtime = 2 * _compute_runtime(op_cost, op, vars) runtime = 2 * _compute_runtime(op_cost, op, vars)
elif int(op.attr('op_role')) == int(OpRole.Forward): elif int(op.attr('op_role')) == int(OpRole.Forward):
op_name = OP_NAME_MAPPING[ op_name = (
op.type] if op.type in OP_NAME_MAPPING.keys() else op.type OP_NAME_MAPPING[op.type]
if op.type in OP_NAME_MAPPING.keys()
else op.type
)
op_cost = cost_model.get_static_op_time(op_name) op_cost = cost_model.get_static_op_time(op_name)
if op_cost: if op_cost:
runtime = _compute_runtime(op_cost, op, vars) runtime = _compute_runtime(op_cost, op, vars)
...@@ -1565,7 +1763,8 @@ def to_list(value): ...@@ -1565,7 +1763,8 @@ def to_list(value):
def debug_program(program, path, name): def debug_program(program, path, name):
filename = os.path.join( filename = os.path.join(
path, name + '_program' + ".%d" % (paddle.distributed.get_rank())) path, name + '_program' + ".%d" % (paddle.distributed.get_rank())
)
with open(filename, 'w') as f: with open(filename, 'w') as f:
f.write(str(program)) f.write(str(program))
...@@ -1599,9 +1798,11 @@ def get_lr(optimizer): ...@@ -1599,9 +1798,11 @@ def get_lr(optimizer):
return optimizer._learning_rate() return optimizer._learning_rate()
else: else:
raise TypeError( raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer)) " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(
type(optimizer)
) )
)
def initialize_pg_in_full_mode(all_process_groups, cur_rank): def initialize_pg_in_full_mode(all_process_groups, cur_rank):
...@@ -1631,21 +1832,28 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): ...@@ -1631,21 +1832,28 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
if is_send: if is_send:
recv_rank = process_group.ranks[1] recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[ recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":") recv_rank
].split(":")
connect_port = int(recv_rank_port) + magic_num connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET, client_socket = socket.socket(
socket.SOCK_STREAM) socket.AF_INET, socket.SOCK_STREAM
)
client_socket.connect((recv_rank_ip, connect_port)) client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(cur_rank).encode('utf-8')) client_socket.send(str(cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8') rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank) rank = int(rank)
if rank != recv_rank: if rank != recv_rank:
raise ValueError( raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}." "Please check comm pair, the recv rank should be {} but got {}.".format(
.format(recv_rank, rank)) recv_rank, rank
)
)
else: else:
print("It is able to instantiate {} as sender now.".format( print(
process_group.ranks)) "It is able to instantiate {} as sender now.".format(
process_group.ranks
)
)
client_socket.close() client_socket.close()
else: else:
send_rank = process_group.ranks[0] send_rank = process_group.ranks[0]
...@@ -1657,10 +1865,45 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): ...@@ -1657,10 +1865,45 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
has_recv_by_socket.append(rank) has_recv_by_socket.append(rank)
else: else:
client_sockets[send_rank].send( client_sockets[send_rank].send(
str(cur_rank).encode("utf-8")) str(cur_rank).encode("utf-8")
)
client_sockets[send_rank].close() client_sockets[send_rank].close()
print("It is able to instantiate {} as recver now.". print(
format(process_group.ranks)) "It is able to instantiate {} as recver now.".format(
process_group.ranks
)
)
break break
process_group.instantiate() process_group.instantiate()
server_socket.close() server_socket.close()
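In full mode, each send/recv pair is instantiated only after a plain TCP handshake in which each side announces its rank, on a port derived from the trainer endpoint plus a magic offset. A stripped-down sketch of the sender side of that handshake (host, port and rank arguments below are placeholders, not values from this codebase):

import socket

# Hedged sketch, not the Paddle implementation: connect to the receiver,
# announce our rank, and verify that the peer is the rank we expected.
def handshake_as_sender(recv_ip, recv_port, cur_rank, recv_rank, buff_size=1024):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((recv_ip, recv_port))
    sock.send(str(cur_rank).encode("utf-8"))
    peer = int(sock.recv(buff_size).decode("utf-8"))
    sock.close()
    if peer != recv_rank:
        raise ValueError(f"expected recv rank {recv_rank} but got {peer}")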
def get_input_split_info(cur_rank, var, dist_context):
# deduce how the input data is split among the cluster
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if cur_rank not in process_mesh.processes:
rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank)
else:
rank_id = cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
batch_size_axis,
rank_id,
)
return len(group_ranks), group_ranks.index(rank_id)
return 1, 0
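get_input_split_info returns how many shards the batch dimension is split into and which shard the current rank owns (1, 0 meaning no split). A hedged sketch of how such a pair could be used to slice a global batch; the helper and data below are made up for illustration and are not part of the auto-parallel API:

def shard_batch(global_batch, num_splits, split_index):
    per_rank = len(global_batch) // num_splits
    return global_batch[split_index * per_rank : (split_index + 1) * per_rank]

print(shard_batch(list(range(8)), num_splits=2, split_index=1))  # [4, 5, 6, 7]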
def validate_opt(optimizer):
if optimizer is not None:
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
...@@ -20,12 +20,31 @@ from paddle.fluid.framework import default_main_program, default_startup_program ...@@ -20,12 +20,31 @@ from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name from paddle.fluid import unique_name
from .pass_base import register_pass from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.process_group import get_world_process_group set_var_dist_attr,
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str )
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op, OP_ROLE_KEY, OpRole get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_layer_norm_scale_bias_to_fp32,
_need_keep_fp32,
_valid_types,
_dtype_to_str,
)
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.utils import (
is_forward_op,
is_backward_op,
OP_ROLE_KEY,
OpRole,
)
from .auto_parallel_amp import AMPPass from .auto_parallel_amp import AMPPass
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
...@@ -39,11 +58,15 @@ __amp_skip_ops__ = [ ...@@ -39,11 +58,15 @@ __amp_skip_ops__ = [
def set_op_dtype_to_fp16(op): def set_op_dtype_to_fp16(op):
if op.has_attr('in_dtype') and op.attr( if (
'in_dtype') == core.VarDesc.VarType.FP32: op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16) op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype') and op.attr( if (
'out_dtype') == core.VarDesc.VarType.FP32: op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16) op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16) op._set_attr('dtype', core.VarDesc.VarType.FP16)
...@@ -63,7 +86,12 @@ def _keep_fp32_input(op, in_name): ...@@ -63,7 +86,12 @@ def _keep_fp32_input(op, in_name):
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'} return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
if op_type in ['fused_attention', 'fused_feedforward']: if op_type in ['fused_attention', 'fused_feedforward']:
return in_name in { return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" 'LnScale',
'LnBias',
'Ln2Scale',
'Ln2Bias',
"Ln1Scale",
"Ln1Bias",
} }
# backward # backward
if op_type in ['batch_norm_grad']: if op_type in ['batch_norm_grad']:
...@@ -83,8 +111,12 @@ def _keep_fp32_output(op, out_name): ...@@ -83,8 +111,12 @@ def _keep_fp32_output(op, out_name):
return out_name not in {'Y', 'ConvX', 'ConvZ'} return out_name not in {'Y', 'ConvX', 'ConvZ'}
if op_type in ['fused_attention', 'fused_feedforward']: if op_type in ['fused_attention', 'fused_feedforward']:
return out_name in { return out_name in {
'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean', 'LnMean',
'Ln1Variance' 'LnVariance',
'Ln2Mean',
'Ln2Variance',
'Ln1Mean',
'Ln1Variance',
} }
# backward # backward
if op_type in ['layer_norm_grad']: if op_type in ['layer_norm_grad']:
...@@ -95,24 +127,28 @@ def _keep_fp32_output(op, out_name): ...@@ -95,24 +127,28 @@ def _keep_fp32_output(op, out_name):
class FP16State(object): class FP16State(object):
def __init__(
def __init__(self, self,
program, program,
amp_list, amp_list,
dist_context, dist_context,
use_fp16_guard, use_fp16_guard,
input_data_var_names=None): input_data_var_names=None,
):
self.program = program self.program = program
self.amp_list = amp_list self.amp_list = amp_list
self.use_fp16_guard = use_fp16_guard self.use_fp16_guard = use_fp16_guard
self.dist_context = dist_context self.dist_context = dist_context
self.grad_op_to_op_map = self.dist_context.dist_op_context.grad_op_id_to_op_id self.grad_op_to_op_map = (
self.dist_context.dist_op_context.grad_op_id_to_op_id
)
if input_data_var_names: if input_data_var_names:
self.input_data_var_names = input_data_var_names self.input_data_var_names = input_data_var_names
else: else:
self.input_data_var_names = [] self.input_data_var_names = []
self._op_fp16_dict = { self._op_fp16_dict = (
} # op_id --> True/False. 'True' means that the op is should run in fp16 mode. {}
} # op_id --> True/False. 'True' means that the op should run in fp16 mode. {}
) # op_id --> True/False. 'True' means that the op should run in fp16 mode.
self.forward_non_leaf_tensors = {} self.forward_non_leaf_tensors = {}
# record the cast ops that are inserted for a forward # record the cast ops that are inserted for a forward
...@@ -126,7 +162,7 @@ class FP16State(object): ...@@ -126,7 +162,7 @@ class FP16State(object):
def _build_state(self): def _build_state(self):
""" """
mark the execution mode (fp16 or fp32) for ops in all blocks mark the execution mode (fp16 or fp32) for ops in all blocks
include forward ops & backward ops include forward ops & backward ops
""" """
# mark op dtype # mark op dtype
...@@ -156,8 +192,9 @@ class FP16State(object): ...@@ -156,8 +192,9 @@ class FP16State(object):
if op.type == "assign" and "array_" in op.input_arg_names[0]: if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
return return
if _need_keep_fp32(op, self.amp_list.unsupported_list, if _need_keep_fp32(
self.use_fp16_guard): op, self.amp_list.unsupported_list, self.use_fp16_guard
):
self._op_fp16_dict[op.desc.original_id()] = False self._op_fp16_dict[op.desc.original_id()] = False
else: else:
self._op_fp16_dict[op.desc.original_id()] = True self._op_fp16_dict[op.desc.original_id()] = True
...@@ -170,8 +207,9 @@ class FP16State(object): ...@@ -170,8 +207,9 @@ class FP16State(object):
if op.desc.original_id() in self.grad_op_to_op_map: if op.desc.original_id() in self.grad_op_to_op_map:
fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()] fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op)) assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
self._op_fp16_dict[ self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
op.desc.original_id()] = self._op_fp16_dict[fwd_op_id] fwd_op_id
]
if int(op.attr('op_role')) == 257: if int(op.attr('op_role')) == 257:
self.is_train = True self.is_train = True
...@@ -181,7 +219,8 @@ class FP16State(object): ...@@ -181,7 +219,8 @@ class FP16State(object):
try: try:
var = block.var(var_name) var = block.var(var_name)
except ValueError as e: except ValueError as e:
var = self.program.global_block().var(var_name) var = block._var_recursive(var_name)
# var = self.program.global_block().var(var_name)
# NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is
# a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY
...@@ -196,13 +235,18 @@ class FP16State(object): ...@@ -196,13 +235,18 @@ class FP16State(object):
for op in block.ops: for op in block.ops:
if is_forward_op(op): if is_forward_op(op):
# NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in Python # NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in Python
if self._is_fp16_op(op.desc.original_id()) == True \ if (
or op.type == "cast": self._is_fp16_op(op.desc.original_id()) == True
or op.type == "cast"
):
for in_name in op.input_names: for in_name in op.input_names:
if _keep_fp32_input(op, in_name): if _keep_fp32_input(op, in_name):
continue continue
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
if in_var_name not in self.forward_non_leaf_tensors and in_var_name not in self.input_data_var_names: if (
in_var_name not in self.forward_non_leaf_tensors
and in_var_name not in self.input_data_var_names
):
self.set_var_to_fp16(in_var_name, block) self.set_var_to_fp16(in_var_name, block)
for out_name in op.output_names: for out_name in op.output_names:
if _keep_fp32_output(op, out_name): if _keep_fp32_output(op, out_name):
...@@ -248,22 +292,42 @@ class FP16State(object): ...@@ -248,22 +292,42 @@ class FP16State(object):
elif is_forward_op(op): elif is_forward_op(op):
if self._is_fp16_op(op.desc.original_id()) == False: if self._is_fp16_op(op.desc.original_id()) == False:
num_cast_ops = self._insert_forward_cast_ops( num_cast_ops = self._insert_forward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP16, op,
core.VarDesc.VarType.FP32, self.dist_context) idx,
block,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
self.dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) == True: elif self._is_fp16_op(op.desc.original_id()) == True:
num_cast_ops = self._insert_forward_cast_ops( num_cast_ops = self._insert_forward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP32, op,
core.VarDesc.VarType.FP16, self.dist_context) idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
self.dist_context,
)
elif is_backward_op(op): elif is_backward_op(op):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(op.desc.original_id()) == False: if self._is_fp16_op(op.desc.original_id()) == False:
num_cast_ops = self._insert_backward_cast_ops( num_cast_ops = self._insert_backward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP16, op,
core.VarDesc.VarType.FP32, self.dist_context) idx,
block,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
self.dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) == True: elif self._is_fp16_op(op.desc.original_id()) == True:
num_cast_ops = self._insert_backward_cast_ops( num_cast_ops = self._insert_backward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP32, op,
core.VarDesc.VarType.FP16, self.dist_context) idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
self.dist_context,
)
elif op.type == "sum": elif op.type == "sum":
# all input dtypes of sum should be equal and the output dtype should follow the input # all input dtypes of sum should be equal and the output dtype should follow the input
out_var_name = op.output_arg_names[0] out_var_name = op.output_arg_names[0]
...@@ -271,41 +335,51 @@ class FP16State(object): ...@@ -271,41 +335,51 @@ class FP16State(object):
out_var = block.var(out_var_name) out_var = block.var(out_var_name)
in_var = block._find_var_recursive(in_var_name) in_var = block._find_var_recursive(in_var_name)
for in_var_name in op.input_arg_names: for in_var_name in op.input_arg_names:
assert in_var.dtype == block.var( assert (
in_var_name).dtype, "{}, {}, {}".format( in_var.dtype == block.var(in_var_name).dtype
in_var, block.var(in_var_name), str(op)) ), "{}, {}, {}".format(
in_var, block.var(in_var_name), str(op)
)
out_var.desc.set_dtype(in_var.dtype) out_var.desc.set_dtype(in_var.dtype)
idx += num_cast_ops + 1 idx += num_cast_ops + 1
block._sync_with_cpp() block._sync_with_cpp()
def _insert_forward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, def _insert_forward_cast_ops(
dist_context): self, op, idx, block, src_dtype, dst_dtype, dist_context
):
num_cast_ops = 0 num_cast_ops = 0
for in_name in op.input_names: for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name): op, in_name
):
continue continue
consume_op_attr = dist_context.get_op_dist_attr_for_program(op) consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
assert consume_op_attr is not None assert consume_op_attr is not None
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name) in_var = block._find_var_recursive(in_var_name)
if in_var is None or in_var.type not in _valid_types or in_var.dtype == dst_dtype: if (
in_var is None
or in_var.type not in _valid_types
or in_var.dtype == dst_dtype
):
continue continue
if in_var.dtype == src_dtype: if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str( cast_name = (
dst_dtype) in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
cast_var = block.vars.get(cast_name) cast_var = block.vars.get(cast_name)
self.forward_input_cast_ops[op.desc.original_id()] += [ self.forward_input_cast_ops[op.desc.original_id()] += [
(cast_name, in_var.name, dst_dtype, src_dtype, in_name) (cast_name, in_var.name, dst_dtype, src_dtype, in_name)
] ]
in_var_dist_attr = consume_op_attr.get_input_dist_attr( in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name) in_var.name
)
assert in_var_dist_attr is not None assert in_var_dist_attr is not None
# truly insert cast op # truly insert cast op
if cast_var is None or cast_var.dtype != dst_dtype: if cast_var is None or cast_var.dtype != dst_dtype:
...@@ -319,9 +393,11 @@ class FP16State(object): ...@@ -319,9 +393,11 @@ class FP16State(object):
name=cast_name, name=cast_name,
dtype=dst_dtype, dtype=dst_dtype,
persistable=False, persistable=False,
stop_gradient=in_var.stop_gradient) stop_gradient=in_var.stop_gradient,
set_var_dist_attr(dist_context, cast_var, ref_mapping, )
ref_mesh) set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
cast_op = block._insert_op_without_sync( cast_op = block._insert_op_without_sync(
idx, idx,
...@@ -331,23 +407,27 @@ class FP16State(object): ...@@ -331,23 +407,27 @@ class FP16State(object):
attrs={ attrs={
"in_dtype": in_var.dtype, "in_dtype": in_var.dtype,
"out_dtype": cast_var.dtype, "out_dtype": cast_var.dtype,
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward,
}) },
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context) cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1 num_cast_ops += 1
op._rename_input(in_var.name, cast_name) op._rename_input(in_var.name, cast_name)
consume_op_attr.set_input_dist_attr(cast_name, consume_op_attr.set_input_dist_attr(
in_var_dist_attr) cast_name, in_var_dist_attr
)
if op.has_attr('out_dtype') and op.attr('out_dtype') != -1: if op.has_attr('out_dtype') and op.attr('out_dtype') != -1:
assert op.attr('out_dtype') == dst_dtype assert op.attr('out_dtype') == dst_dtype
return num_cast_ops return num_cast_ops
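The forward pass above names each inserted cast variable after its source plus a dtype suffix, and the backward pass later reuses that name (with a @GRAD suffix and a uniquifier) to wire the grad of the cast. A hedged sketch of the naming convention, assuming _dtype_to_str maps FP16 to "fp16":

# Illustrative only; the real code goes through _dtype_to_str and unique_name.
def cast_name(var_name, dtype_str):
    return var_name + '.cast_' + dtype_str

print(cast_name("linear0.w_0", "fp16"))            # linear0.w_0.cast_fp16
print(cast_name("linear0.w_0", "fp16") + '@GRAD')  # roughly the grad name used later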
def _insert_backward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, def _insert_backward_cast_ops(
dist_context): self, op, idx, block, src_dtype, dst_dtype, dist_context
):
num_cast_ops = 0 num_cast_ops = 0
op_id = op.desc.id() op_id = op.desc.id()
...@@ -363,15 +443,21 @@ class FP16State(object): ...@@ -363,15 +443,21 @@ class FP16State(object):
if _keep_fp32_output(op, out_var.name): if _keep_fp32_output(op, out_var.name):
continue continue
assert out_var.dtype == dst_dtype, "{}, {}".format( assert out_var.dtype == dst_dtype, "{}, {}".format(
str(out_var), dst_dtype) str(out_var), dst_dtype
)
for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[ for (
forward_op_id]: cast_name,
src_name,
dst_dtype,
src_dtype,
slot_name,
) in self.forward_input_cast_ops[forward_op_id]:
# rename input # rename input
assert src_name in op.input( assert src_name in op.input(
slot_name), "var: {} not in op's {}. {}".format( slot_name
src_name, slot_name, str(op)) ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name) op._rename_input(src_name, cast_name)
...@@ -393,15 +479,18 @@ class FP16State(object): ...@@ -393,15 +479,18 @@ class FP16State(object):
ref_mapping = grad_dist_attr.dims_mapping ref_mapping = grad_dist_attr.dims_mapping
cast_grad = block.create_var( cast_grad = block.create_var(
name=unique_name.generate_with_ignorable_key("".join( name=unique_name.generate_with_ignorable_key(
[cast_name, '@GRAD'])), "".join([cast_name, '@GRAD'])
),
dtype=dst_dtype, dtype=dst_dtype,
shape=grad.shape, shape=grad.shape,
type=grad.type, type=grad.type,
persistable=grad.persistable, persistable=grad.persistable,
stop_gradient=grad.stop_gradient) stop_gradient=grad.stop_gradient,
)
dist_context.set_tensor_dist_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
cast_grad, grad_dist_attr) cast_grad, grad_dist_attr
)
op._rename_output(grad_name, cast_grad.name) op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
...@@ -414,12 +503,14 @@ class FP16State(object): ...@@ -414,12 +503,14 @@ class FP16State(object):
attrs={ attrs={
"in_dtype": dst_dtype, "in_dtype": dst_dtype,
"out_dtype": src_dtype, "out_dtype": src_dtype,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward,
}) },
)
grad.desc.set_dtype(src_dtype) grad.desc.set_dtype(src_dtype)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context) cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1 num_cast_ops += 1
return num_cast_ops return num_cast_ops
...@@ -432,26 +523,34 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): ...@@ -432,26 +523,34 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale') check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads: for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], check_variable_and_dtype(
'check_finite_and_unscale') e,
"x",
['float16', 'float32', 'float64'],
'check_finite_and_unscale',
)
found_inf = main_block.create_var( found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(
['find_infinite_scale', name])), ".".join(['find_infinite_scale', name])
),
shape=[1], shape=[1],
dtype='bool', dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=False) stop_gradient=False,
)
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
inputs = {'X': grads, 'Scale': loss_scaling} inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf} outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize} attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale', new_op = main_block.append_op(
inputs=inputs, type='check_finite_and_unscale',
outputs=outputs, inputs=inputs,
attrs=attrs) outputs=outputs,
attrs=attrs,
)
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks new_op_dist_attr.process_mesh = world_process_group.ranks
...@@ -461,10 +560,12 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): ...@@ -461,10 +560,12 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
for g in grads: for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g) g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name, new_op_dist_attr.set_input_dims_mapping(
g_dist_attr.dims_mapping) g.name, g_dist_attr.dims_mapping
new_op_dist_attr.set_output_dims_mapping(g.name, )
g_dist_attr.dims_mapping) new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf return grads, found_inf
...@@ -473,8 +574,9 @@ def _split_grads(params_grads): ...@@ -473,8 +574,9 @@ def _split_grads(params_grads):
grads = [g for _, g in params_grads] grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
assert len(fp32_grads) + len(fp16_grads) == len(grads), \ assert len(fp32_grads) + len(fp16_grads) == len(
"Data types of all grads must be either fp16 or fp32." grads
), "Data types of all grads must be either fp16 or fp32."
return grads, fp32_grads, fp16_grads return grads, fp32_grads, fp16_grads
...@@ -486,37 +588,45 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context): ...@@ -486,37 +588,45 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
var = block.var(var_name) var = block.var(var_name)
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
assert var_dist_attr is not None assert var_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(var_name, new_op_dist_attr.set_input_dims_mapping(
var_dist_attr.dims_mapping) var_name, var_dist_attr.dims_mapping
)
for var_name in new_op.output_arg_names: for var_name in new_op.output_arg_names:
var = block.var(var_name) var = block.var(var_name)
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
assert var_dist_attr is not None assert var_dist_attr is not None
new_op_dist_attr.set_output_dims_mapping(var_name, new_op_dist_attr.set_output_dims_mapping(
var_dist_attr.dims_mapping) var_name, var_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def _get_memcopy_idx(block, found_inf_var): def _get_memcopy_idx(block, found_inf_var):
# use reduce_any op for check_nan_inf as the anchor for now # use reduce_any op for check_nan_inf as the anchor for now
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if op.type == 'reduce_any' and op.output_arg_names[ if (
0] == found_inf_var.name: op.type == 'reduce_any'
and op.output_arg_names[0] == found_inf_var.name
):
return idx + 1 return idx + 1
raise RuntimeError( raise RuntimeError(
"not found the correct location for memcopy for found_inf_var.") "not found the correct location for memcopy for found_inf_var."
)
def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
src_name = src_var.name src_name = src_var.name
output_var = block.create_var(name=unique_name.generate_with_ignorable_key( output_var = block.create_var(
src_name.join(['memcopy_'])), name=unique_name.generate_with_ignorable_key(
dtype=src_var.dtype, src_name.join(['memcopy_'])
shape=src_var.shape, ),
type=core.VarDesc.VarType.LOD_TENSOR, dtype=src_var.dtype,
persistable=False, shape=src_var.shape,
stop_gradient=src_var.stop_gradient) type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=src_var.stop_gradient,
)
set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)
...@@ -527,16 +637,20 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): ...@@ -527,16 +637,20 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
dst_place_type = 1 dst_place_type = 1
else: else:
raise NotImplementedError( raise NotImplementedError(
"direction [{}] is not supported yet.".format(direction)) "direction [{}] is not supported yet.".format(direction)
)
attrs = {'dst_place_type': dst_place_type} attrs = {'dst_place_type': dst_place_type}
new_op = block._insert_op_without_sync(index=idx, new_op = block._insert_op_without_sync(
type='memcpy', index=idx,
inputs={'X': [src_var]}, type='memcpy',
outputs={'Out': [output_var]}, inputs={'X': [src_var]},
attrs=attrs) outputs={'Out': [output_var]},
_set_op_dist_attr_with_ranks(new_op, world_process_group.ranks, block, attrs=attrs,
dist_context) )
_set_op_dist_attr_with_ranks(
new_op, world_process_group.ranks, block, dist_context
)
block._sync_with_cpp() block._sync_with_cpp()
return output_var return output_var
...@@ -564,19 +678,21 @@ def cast_startup_program(): ...@@ -564,19 +678,21 @@ def cast_startup_program():
for op in startup_program.global_block().ops: for op in startup_program.global_block().ops:
if is_initialization_op(op): if is_initialization_op(op):
output_name = op.output_arg_names[0] output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name, if (
None) == core.VarDesc.VarType.FP16: param_to_dtype.get(output_name, None)
== core.VarDesc.VarType.FP16
):
assert op.has_attr( assert op.has_attr(
'dtype' 'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format( ), "initialization op is supported to has dtype attribute but got {}.".format(
str(op)) str(op)
)
if op.attr('dtype') == core.VarDesc.VarType.FP32: if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16) op._set_attr('dtype', core.VarDesc.VarType.FP16)
@register_pass("auto_parallel_fp16") @register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass): class FP16Pass(AMPPass):
def __init__(self): def __init__(self):
super(FP16Pass, self).__init__() super(FP16Pass, self).__init__()
...@@ -589,16 +705,22 @@ class FP16Pass(AMPPass): ...@@ -589,16 +705,22 @@ class FP16Pass(AMPPass):
amp_list = AutoMixedPrecisionLists( amp_list = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")), set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")), None) set(self.get_attr("custom_black_list")),
None,
)
# NOTE do not change the input data dtype, since it is controlled by the dataloader # NOTE do not change the input data dtype, since it is controlled by the dataloader
# and is out of the control of the FP16 Pass # and is out of the control of the FP16 Pass
input_data_var_names = [var.name for var in self.get_attr("input_data")] input_data_var_names = [var.name for var in self.get_attr("input_data")]
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
fp16_state = FP16State(main_program, amp_list, self.dist_context, fp16_state = FP16State(
self.get_attr("use_fp16_guard"), main_program,
input_data_var_names) amp_list,
self.dist_context,
self.get_attr("use_fp16_guard"),
input_data_var_names,
)
is_train = fp16_state._build_state() is_train = fp16_state._build_state()
cast_startup_program() cast_startup_program()
...@@ -611,41 +733,63 @@ class FP16Pass(AMPPass): ...@@ -611,41 +733,63 @@ class FP16Pass(AMPPass):
grads, fp32_grads, fp16_grads = _split_grads(params_grads) grads, fp32_grads, fp16_grads = _split_grads(params_grads)
if self.get_attr("use_dynamic_loss_scaling" if (
) or self.get_attr("init_loss_scaling") != 1.0: self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
found_infs = [] found_infs = []
if fp32_grads: if fp32_grads:
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient( _, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32", fp32_grads,
self.dist_context) self._loss_scaling,
"@fp32",
self.dist_context,
)
found_infs.append(found_inf_fp32) found_infs.append(found_inf_fp32)
if fp16_grads: if fp16_grads:
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient( _, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16", fp16_grads,
self.dist_context) self._loss_scaling,
"@fp16",
self.dist_context,
)
found_infs.append(found_inf_fp16) found_infs.append(found_inf_fp16)
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
block = main_program.global_block() block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs) all_infs = paddle.fluid.layers.concat(found_infs)
set_var_dist_attr(self.dist_context, all_infs, [-1], set_var_dist_attr(
world_process_group.ranks) self.dist_context,
all_infs,
[-1],
world_process_group.ranks,
)
new_op = block.ops[-1] new_op = block.ops[-1]
assert new_op.type == "concat" assert new_op.type == "concat"
_set_op_dist_attr_with_ranks(new_op, _set_op_dist_attr_with_ranks(
world_process_group.ranks, new_op,
block, self.dist_context) world_process_group.ranks,
block,
self.dist_context,
)
found_inf = paddle.fluid.layers.reduce_any(all_infs) found_inf = paddle.fluid.layers.reduce_any(all_infs)
set_var_dist_attr(self.dist_context, found_inf, [-1], set_var_dist_attr(
world_process_group.ranks) self.dist_context,
found_inf,
[-1],
world_process_group.ranks,
)
new_op = block.ops[-1] new_op = block.ops[-1]
assert new_op.type == "reduce_any" assert new_op.type == "reduce_any"
_set_op_dist_attr_with_ranks(new_op, _set_op_dist_attr_with_ranks(
world_process_group.ranks, new_op,
block, self.dist_context) world_process_group.ranks,
block,
self.dist_context,
)
if self.get_attr("use_dynamic_loss_scaling"): if self.get_attr("use_dynamic_loss_scaling"):
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
...@@ -660,14 +804,15 @@ class FP16Pass(AMPPass): ...@@ -660,14 +804,15 @@ class FP16Pass(AMPPass):
if self.get_attr("use_optimizer_fp16"): if self.get_attr("use_optimizer_fp16"):
base_opt._multi_precision = False base_opt._multi_precision = False
if isinstance( if isinstance(
base_opt, base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)): ):
with main_program._optimized_guard([]): with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy( # found_inf = paddle.tensor.creation._memcpy(
# found_inf, paddle.CPUPlace()) # found_inf, paddle.CPUPlace())
insert_idx = _get_memcopy_idx(block, found_inf) insert_idx = _get_memcopy_idx(block, found_inf)
found_inf = _insert_memcopy(block, insert_idx, found_inf, found_inf = _insert_memcopy(
self.dist_context) block, insert_idx, found_inf, self.dist_context
)
base_opt._set_auxiliary_var('found_inf', found_inf.name) base_opt._set_auxiliary_var('found_inf', found_inf.name)
elif hasattr(base_opt, "_set_auxiliary_var"): elif hasattr(base_opt, "_set_auxiliary_var"):
base_opt._set_auxiliary_var('found_inf', found_inf.name) base_opt._set_auxiliary_var('found_inf', found_inf.name)
@@ -63,6 +63,15 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
  set_tests_properties(test_engine_callbacks
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
                  ${dist_ENVS})
  set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
  py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
                  ENVS ${dist_ENVS})
  set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
  py_test_modules(test_parallel_tuner_predict MODULES
                  test_parallel_tuner_predict ENVS ${dist_ENVS})
  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
  py_test_modules(test_while_op_completion MODULES test_while_op_completion
                  ENVS ${dist_ENVS})
@@ -90,6 +99,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
  py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
  py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
  py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
  py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
  py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
@@ -99,20 +109,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_interface MODULES test_interface)
  py_test_modules(test_strategy MODULES test_strategy)
  py_test_modules(test_pass_quantization MODULES test_pass_quantization)
  py_test_modules(test_dist_shape MODULES test_dist_shape)
  py_test_modules(test_dist_assign MODULES test_dist_assign)
  py_test_modules(test_conditional_block_reshard MODULES
                  test_conditional_block_reshard)
  py_test_modules(test_engine_api_error MODULES test_engine_api_error)
endif()
@@ -31,7 +31,10 @@ from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.interface import (
    get_collection,
    CollectionNames,
)
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
@@ -56,7 +59,6 @@ my_feed_vars = []


class MyDataset(Dataset):
    def __init__(self, num_samples):
        super(MyDataset, self).__init__()
        self.num_samples = num_samples
@@ -77,38 +79,38 @@ def get_random_inputs_and_labels(image_shape, label_shape):


def batch_generator_creator():
    def __reader__():
        for _ in range(batch_num):
            batch_input, batch_label = get_random_inputs_and_labels(
                [batch_size, image_size], [batch_size, 1]
            )
            yield batch_input, batch_label

    return __reader__


class MLPLayer(nn.Layer):
    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
@@ -132,16 +134,20 @@ class MLPLayer(nn.Layer):
def train_high_level(fetch):
    global is_fetch
    is_fetch = fetch
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
@@ -153,11 +159,13 @@ def train_high_level(fetch):
    train_dataset = MyDataset(batch_num * batch_size)
    eval_dataset1 = MyDataset(5 * batch_size)

    history = engine.fit(
        train_data=train_dataset,
        epochs=2,
        batch_size=batch_size,
        valid_data=eval_dataset1,
        log_freq=1,
    )

    # eval
    eval_dataset2 = MyDataset(batch_size)
@@ -176,16 +184,20 @@ def train_high_level(fetch):


def train_low_level():
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
@@ -200,18 +212,18 @@ def train_low_level():
    # Build normal normal dataloader
    # train
    train_dataset = MyDataset(batch_num * batch_size)
    train_dataloader = engine.dataloader(
        train_dataset, batch_size=batch_size, mode="train"
    )
    engine.prepare(mode="train")
    for data in train_dataloader:
        outs = engine.run(data, feed=feed_dict, mode="train")

    # eval
    eval_dataset2 = MyDataset(batch_size)
    eval_dataloader = engine.dataloader(
        eval_dataset2, batch_size=batch_size, mode="eval"
    )
    engine.prepare(mode="eval")
    for data in eval_dataloader:
        outs = engine.run(data, feed=feed_dict, mode="eval")
@@ -234,9 +246,9 @@ def train_low_level():
    # Build dataloader from generator
    # train
    train_dataset = MyDataset(batch_num * batch_size)
    train_dataloader = engine.dataloader_from_generator(
        train_dataset, batch_size=batch_size, mode="train"
    )
    engine.prepare(mode="train")
    for data in train_dataloader:
        outs = engine.run(data, feed=feed_dict, mode="train")
@@ -244,17 +256,18 @@ def train_low_level():
    # eval
    engine.to_mode("eval")
    eval_dataset2 = MyDataset(batch_size)
    eval_dataloader = engine.dataloader_from_generator(
        eval_dataset2, batch_size=batch_size
    )
    engine.prepare()
    for data in eval_dataloader:
        outs = engine.run(data, feed=feed_dict)

    # predict
    test_dataset = MyDataset(batch_size)
    predict_dataloader = engine.dataloader_from_generator(
        test_dataset, batch_size=batch_size, mode="predict"
    )
    engine.prepare(mode="predict")
    for data in predict_dataloader:
        outs = engine.run(data, feed=feed_dict, mode="predict")
@@ -268,16 +281,20 @@ def train_low_level():


def train_builtin_data_vars():
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
@@ -295,9 +312,9 @@ def train_builtin_data_vars():
    with static.program_guard(engine.main_program, engine.startup_program):
        feed_list = engine.inputs + engine.labels
        print(feed_list)
        loader = paddle.io.DataLoader.from_generator(
            feed_list=feed_list, capacity=4 * batch_size, iterable=False
        )

        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)
@@ -308,36 +325,40 @@ def train_builtin_data_vars():
            while True:
                engine.run()
        except paddle.fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException


def train_non_builtin_data_vars():
    main_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(
        main_program, startup_program
    ), utils.unique_name.guard():
        input = static.data(
            name="input", shape=[batch_size, image_size], dtype='float32'
        )
        label = static.data(name="label", shape=[batch_size, 1], dtype='int64')

        loader = paddle.io.DataLoader.from_generator(
            feed_list=[input, label], capacity=4 * batch_size, iterable=False
        )
        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        loss = paddle.nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )
        metric = paddle.metric.Accuracy()
        predict = mlp(input)
        loss_var = loss(predict, label)
@@ -345,53 +366,109 @@ def train_non_builtin_data_vars():
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )

    # train
    engine.to_mode("train")
    engine.prepare(
        inputs=[input],
        labels=[label],
        main_program=main_program,
        startup_program=startup_program,
    )
    for _ in range(epoch_num):
        loader.start()  # call DataLoader.start() before each epoch starts
        try:
            while True:
                engine.run()
        except paddle.fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException


def get_cost():
    main_program = static.Program()
    startup_program = static.Program()
    with static.program_guard(
        main_program, startup_program
    ), utils.unique_name.guard():
        input = static.data(
            name="input", shape=[batch_size, image_size], dtype='float32'
        )
        label = static.data(name="label", shape=[batch_size, 1], dtype='int64')

        loader = paddle.io.DataLoader.from_generator(
            feed_list=[input, label], capacity=4 * batch_size, iterable=False
        )
        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        loss = paddle.nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )
        metric = paddle.metric.Accuracy()
        predict = mlp(input)
        loss_var = loss(predict, label)

    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )
    engine.prepare(
        main_program=main_program,
        startup_program=startup_program,
        inputs=[input],
        labels=[label],
        mode="train",
    )
    engine.cost()


def get_cost_by_default_program():
    main_program = static.default_main_program()
    startup_program = static.default_startup_program()
    with static.program_guard(
        main_program, startup_program
    ), utils.unique_name.guard():
        input = static.data(
            name="input", shape=[batch_size, image_size], dtype='float32'
        )
        label = static.data(name="label", shape=[batch_size, 1], dtype='int64')

        loader = paddle.io.DataLoader.from_generator(
            feed_list=[input, label], capacity=4 * batch_size, iterable=False
        )
        places = static.cuda_places()
        loader.set_batch_generator(batch_generator_creator(), places=places)

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        loss = paddle.nn.CrossEntropyLoss()
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )
        metric = paddle.metric.Accuracy()
        predict = mlp(input)
        loss_var = loss(predict, label)
@@ -399,24 +476,27 @@ def get_cost():
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    engine = auto.Engine(
        loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
    )
    engine.cost(mode="train")


def get_cost_by_spec():
    mlp = MLPLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        dropout_ratio=0.1,
        initializer_range=0.02,
    )
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None,
    )
    metric = paddle.metric.Accuracy()

    strategy = auto.Strategy()
@@ -436,4 +516,5 @@ if __name__ == "__main__":
    train_builtin_data_vars()
    train_non_builtin_data_vars()
    get_cost()
    get_cost_by_default_program()
    get_cost_by_spec()
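Both train_builtin_data_vars() and train_non_builtin_data_vars() above drive the non-iterable static DataLoader with the same loop shape: call start() before each epoch, run() until an EOFException is raised, then reset(). Below is a minimal, framework-free sketch of that shape; FakeLoader, FakeEngine, and the local EOFException are hypothetical stand-ins for the Paddle DataLoader and Engine, not their real implementations.

# Framework-free sketch of the start()/run()/reset() epoch loop used above.
# FakeLoader and FakeEngine are hypothetical stand-ins, not Paddle classes.


class EOFException(Exception):
    pass


class FakeLoader:
    def __init__(self, num_batches):
        self.num_batches = num_batches
        self._cursor = None

    def start(self):  # must be called before each epoch starts
        self._cursor = 0

    def next_batch(self):
        if self._cursor >= self.num_batches:
            raise EOFException()
        self._cursor += 1

    def reset(self):  # must be called after catching EOFException
        self._cursor = None


class FakeEngine:
    def __init__(self, loader):
        self.loader = loader
        self.steps = 0

    def run(self):
        self.loader.next_batch()
        self.steps += 1


loader = FakeLoader(num_batches=10)
engine = FakeEngine(loader)
for _ in range(2):  # epochs
    loader.start()
    try:
        while True:
            engine.run()
    except EOFException:
        loader.reset()
assert engine.steps == 20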
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static as static
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset
from paddle.distributed.fleet import auto
paddle.enable_static()
epoch_num = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10
is_fetch = True
is_feed = True
my_feed_vars = []
class TrainDataset(Dataset):
def __init__(self, num_samples):
super(TrainDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
def __len__(self):
return self.num_samples
class TestDataset(Dataset):
def __init__(self, num_samples):
super(TestDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
if is_feed:
my_feed_vars.append((out, out.shape))
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
if is_feed:
my_feed_vars.append((out, out.shape))
if is_fetch:
auto.fetch(out, "my_fetch", logging=True)
return out
class TestEngineErrorRaise(unittest.TestCase):
def setUp(self):
class NoSupportData1:
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
class NoSupportData2(TrainDataset):
def __getitem__(self, index):
input = [
list(np.random.uniform(size=image_size).astype("float32"))
]
label = [np.random.randint(0, class_num - 1, dtype="int64")]
return input, label
class NoSupportData3:
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
class NoSupportData4(TestDataset):
def __getitem__(self, index):
input = [
list(np.random.uniform(size=image_size).astype("float32"))
]
return input
self.no_support_data_1 = NoSupportData1()
self.no_support_data_2 = NoSupportData2(10)
self.no_support_data_3 = NoSupportData3()
self.no_support_data_4 = NoSupportData4(10)
def test_Engine(self):
with self.assertRaises(TypeError):
auto.Engine(model=paddle.static.Program())
with self.assertRaises(TypeError):
auto.Engine(loss="CrossEntropyLoss")
with self.assertRaises(TypeError):
auto.Engine(optimizer="adam")
with self.assertRaises(TypeError):
auto.Engine(metrics=["acc"])
with self.assertRaises(TypeError):
auto.Engine(cluster="cluster")
with self.assertRaises(TypeError):
auto.Engine(strategy="strategy")
def test_fit(self):
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
optimizer=paddle.optimizer.AdamW(0.00001),
)
engine.fit(train_data=self.no_support_data_1)
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
optimizer=paddle.optimizer.AdamW(0.00001),
)
engine.fit(train_data=self.no_support_data_2)
def test_evaluate(self):
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
engine.evaluate(valid_data=self.no_support_data_3)
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
engine.evaluate(
valid_data=self.no_support_data_4, valid_sample_split=1
)
def test_predict(self):
with self.assertRaises(TypeError):
engine = auto.Engine(model=MLPLayer())
engine.predict(
test_data=self.no_support_data_3, test_sample_split=1
)
with self.assertRaises(TypeError):
engine = auto.Engine(model=MLPLayer())
engine.predict(
test_data=self.no_support_data_4, test_sample_split=1
)
def build_program(self):
main_prog = static.Program()
startup_prog = static.Program()
with static.program_guard(main_prog, startup_prog):
input = static.data(
name="input",
shape=[batch_size // 2, image_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size // 2, 1], dtype='int64'
)
mlp = MLPLayer()
loss = paddle.nn.CrossEntropyLoss()
predict = mlp(input)
loss_var = loss(predict, label)
return main_prog, startup_prog, input, label, loss_var
def test_prepare(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.prepare()
with self.assertRaises(AssertionError):
engine = auto.Engine(model=MLPLayer())
engine.prepare(mode="train")
with self.assertRaises(TypeError):
input = static.data(
name="input",
shape=[batch_size / 2, image_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size / 2, 1], dtype='int64'
)
engine = auto.Engine(model=MLPLayer())
engine.prepare(inputs_spec=input, labels_spec=label, mode="eval")
input_spec = static.InputSpec(
shape=[batch_size, image_size], dtype="float32", name="input"
)
label_spec = static.InputSpec(
shape=[batch_size, image_size], dtype="float32", name="input"
)
(
main_prog,
startup_prog,
input_var,
label_var,
loss_var,
) = self.build_program()
with self.assertRaises(TypeError):
engine = auto.Engine(loss=loss_var)
engine.prepare(
inputs=input_spec,
labels=label_spec,
main_program=main_prog,
startup_program=startup_prog,
mode="eval",
)
with self.assertRaises(AssertionError):
engine = auto.Engine(loss=loss_var)
engine.prepare(
inputs_spec=[input_spec, input_spec],
labels_spec=[label_spec, label_spec],
inputs=input_var,
labels=label_var,
main_program=main_prog,
startup_program=startup_prog,
mode="predict",
)
def test_cost(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.cost(mode="predict")
class TestEngineDynamicErrorRaise(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def tearDown(self):
paddle.enable_static()
def test_cost(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.cost(mode="predict")
if __name__ == "__main__":
unittest.main()
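The NoSupportData cases above fail for two different reasons: NoSupportData1 and NoSupportData3 are plain classes rather than paddle.io.Dataset subclasses, while NoSupportData2 and NoSupportData4 return nested Python lists instead of numpy values. The self-contained sketch below only illustrates that second, sample-type distinction with plain numpy; it does not reproduce the Engine's actual validation logic, and is_array_like is a hypothetical helper.

# Hedged illustration of the sample types used by the datasets above.
import numpy as np

image_size = 1024
class_num = 10

# Shape of a sample from TrainDataset: numpy array plus numpy integer label.
supported_sample = (
    np.random.uniform(size=image_size).astype("float32"),
    np.random.randint(0, class_num - 1, dtype="int64"),
)
# Shape of a sample from NoSupportData2: nested Python lists instead of arrays.
unsupported_sample = (
    [list(np.random.uniform(size=image_size).astype("float32"))],
    [np.random.randint(0, class_num - 1, dtype="int64")],
)


def is_array_like(field):
    # numpy arrays and numpy scalars both count as array-like here
    return isinstance(field, (np.ndarray, np.generic))


assert all(is_array_like(field) for field in supported_sample)
assert not all(is_array_like(field) for field in unsupported_sample)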