diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py index c5bdc85e1b5b10ec6d0a265f74a757fe22334a04..1217a0b4d0bf7b3fac53e51e97f91191fa4563fc 100644 --- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py @@ -20,9 +20,28 @@ class AdamOpCost(CompOpCost): OP_TYPE = "adam" def __init__(self, op=None, op_desc=None, cluster=None): - super(AdamOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(AdamOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + +@register_op_cost +class ArgsortOpCost(CompOpCost): + OP_TYPE = "argsort" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(ArgsortOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -39,9 +58,9 @@ class AssignOpCost(CompOpCost): OP_TYPE = "assign" def __init__(self, op=None, op_desc=None, cluster=None): - super(AssignOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(AssignOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -58,9 +77,9 @@ class AssignValueOpCost(CompOpCost): OP_TYPE = "assign_value" def __init__(self, op=None, op_desc=None, cluster=None): - super(AssignValueOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(AssignValueOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -77,9 +96,9 @@ class BeamSearchOpCost(CompOpCost): OP_TYPE = "beam_search" def __init__(self, op=None, op_desc=None, cluster=None): - super(BeamSearchOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(BeamSearchOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -96,9 +115,9 @@ class BeamSearchDecodeOpCost(CompOpCost): OP_TYPE = "beam_search_decode" def __init__(self, op=None, op_desc=None, cluster=None): - super(BeamSearchDecodeOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(BeamSearchDecodeOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -115,9 +134,9 @@ class CastOpCost(CompOpCost): OP_TYPE = "cast" def __init__(self, op=None, op_desc=None, cluster=None): - super(CastOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(CastOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -134,9 +153,9 @@ class ConcatOpCost(CompOpCost): OP_TYPE = "concat" def __init__(self, op=None, op_desc=None, cluster=None): - super(ConcatOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + 
super(ConcatOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -153,9 +172,9 @@ class DropoutOpCost(CompOpCost): OP_TYPE = "dropout" def __init__(self, op=None, op_desc=None, cluster=None): - super(DropoutOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(DropoutOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -172,9 +191,9 @@ class DropoutGradOpCost(CompOpCost): OP_TYPE = "dropout_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(DropoutGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(DropoutGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -191,9 +210,9 @@ class ElementwiseAddOpCost(CompOpCost): OP_TYPE = "elementwise_add" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseAddOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseAddOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -210,9 +229,9 @@ class ElementwiseAddGradOpCost(CompOpCost): OP_TYPE = "elementwise_add_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseAddGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseAddGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -229,9 +248,9 @@ class ElementwiseDivOpCost(CompOpCost): OP_TYPE = "elementwise_div" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseDivOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseDivOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -248,9 +267,9 @@ class ElementwiseDivGradOpCost(CompOpCost): OP_TYPE = "elementwise_div_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseDivGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseDivGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -267,9 +286,9 @@ class ElementwiseMulOpCost(CompOpCost): OP_TYPE = "elementwise_mul" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseMulOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseMulOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -286,9 +305,9 @@ class ElementwiseMulGradOpCost(CompOpCost): OP_TYPE = "elementwise_mul_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseMulGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseMulGradOpCost, self).__init__( + op=op, op_desc=op_desc, 
cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -305,9 +324,9 @@ class ElementwiseSubOpCost(CompOpCost): OP_TYPE = "elementwise_sub" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseSubOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseSubOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -324,9 +343,28 @@ class ElementwiseSubGradOpCost(CompOpCost): OP_TYPE = "elementwise_sub_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(ElementwiseSubGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ElementwiseSubGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + +@register_op_cost +class EqualOpCost(CompOpCost): + OP_TYPE = "equal" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(EqualOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -343,9 +381,9 @@ class EmbeddingOpCost(CompOpCost): OP_TYPE = "c_embedding" def __init__(self, op=None, op_desc=None, cluster=None): - super(EmbeddingOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(EmbeddingOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -362,9 +400,9 @@ class EmbeddingGradOpCost(CompOpCost): OP_TYPE = "c_embedding_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(EmbeddingGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(EmbeddingGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -381,9 +419,9 @@ class FillConstantOpCost(CompOpCost): OP_TYPE = "fill_constant" def __init__(self, op=None, op_desc=None, cluster=None): - super(FillConstantOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(FillConstantOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -400,9 +438,9 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost): OP_TYPE = "fill_constant_batch_size_like" def __init__(self, op=None, op_desc=None, cluster=None): - super(FillConstantBatchSizeLikeOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(FillConstantBatchSizeLikeOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -419,8 +457,9 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost): OP_TYPE = "fused_softmax_mask_upper_triangle" def __init__(self, op=None, op_desc=None, cluster=None): - super(FusedSoftmaxMaskUpperTriangleOpCost, - self).__init__(op=op, op_desc=op_desc, cluster=cluster) + 
super(FusedSoftmaxMaskUpperTriangleOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -437,8 +476,9 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost): OP_TYPE = "fused_softmax_mask_upper_triangle_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(FusedSoftmaxMaskUpperTriangleGradOpCost, - self).__init__(op=op, op_desc=op_desc, cluster=cluster) + super(FusedSoftmaxMaskUpperTriangleGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -455,9 +495,9 @@ class GatherOpCost(CompOpCost): OP_TYPE = "gather" def __init__(self, op=None, op_desc=None, cluster=None): - super(GatherOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(GatherOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -474,9 +514,9 @@ class GeluOpCost(CompOpCost): OP_TYPE = "gelu" def __init__(self, op=None, op_desc=None, cluster=None): - super(GeluOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(GeluOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -493,9 +533,9 @@ class GeluGradOpCost(CompOpCost): OP_TYPE = "gelu_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(GeluGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(GeluGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -512,9 +552,9 @@ class GreaterEqualOpCost(CompOpCost): OP_TYPE = "greater_equal" def __init__(self, op=None, op_desc=None, cluster=None): - super(GreaterEqualOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(GreaterEqualOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -531,9 +571,9 @@ class IncrementOpCost(CompOpCost): OP_TYPE = "increment" def __init__(self, op=None, op_desc=None, cluster=None): - super(IncrementOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(IncrementOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -546,9 +586,9 @@ class IsEmptyOpCost(CompOpCost): OP_TYPE = "is_empty" def __init__(self, op=None, op_desc=None, cluster=None): - super(IsEmptyOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(IsEmptyOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -561,9 +601,9 @@ class LayerNormOpCost(CompOpCost): OP_TYPE = "layer_norm" def __init__(self, op=None, op_desc=None, cluster=None): - super(LayerNormOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LayerNormOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be 
overrided def calc_flops(self): @@ -580,9 +620,9 @@ class LayerNormGradOpCost(CompOpCost): OP_TYPE = "layer_norm_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(LayerNormGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LayerNormGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -599,9 +639,9 @@ class LessThanOpCost(CompOpCost): OP_TYPE = "less_than" def __init__(self, op=None, op_desc=None, cluster=None): - super(LessThanOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LessThanOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -618,9 +658,9 @@ class LogicalNotOpCost(CompOpCost): OP_TYPE = "logical_not" def __init__(self, op=None, op_desc=None, cluster=None): - super(LogicalNotOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LogicalNotOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -637,9 +677,9 @@ class LogicalAndOpCost(CompOpCost): OP_TYPE = "logical_and" def __init__(self, op=None, op_desc=None, cluster=None): - super(LogicalAndOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LogicalAndOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -656,9 +696,9 @@ class LodResetOpCost(CompOpCost): OP_TYPE = "lod_reset" def __init__(self, op=None, op_desc=None, cluster=None): - super(LodResetOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LodResetOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -692,9 +732,9 @@ class LookupTableV2OpCost(CompOpCost): OP_TYPE = "lookup_table_v2" def __init__(self, op=None, op_desc=None, cluster=None): - super(LookupTableV2OpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LookupTableV2OpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -711,9 +751,9 @@ class LookupTableV2GradOpCost(CompOpCost): OP_TYPE = "lookup_table_v2_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(LookupTableV2GradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(LookupTableV2GradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -730,9 +770,9 @@ class MatmulOpCost(CompOpCost): OP_TYPE = "matmul" def __init__(self, op=None, op_desc=None, cluster=None): - super(MatmulOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(MatmulOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -749,9 +789,9 @@ class MatmulGradOpCost(CompOpCost): OP_TYPE = "matmul_grad" def __init__(self, op=None, op_desc=None, cluster=None): - 
super(MatmulGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(MatmulGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -768,9 +808,9 @@ class MatmulV2OpCost(CompOpCost): OP_TYPE = "matmul_v2" def __init__(self, op=None, op_desc=None, cluster=None): - super(MatmulV2OpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(MatmulV2OpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -787,9 +827,9 @@ class MatmulV2GradOpCost(CompOpCost): OP_TYPE = "matmul_v2_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(MatmulV2GradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(MatmulV2GradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -806,9 +846,9 @@ class MemcpyOpCost(CompOpCost): OP_TYPE = "memcpy" def __init__(self, op=None, op_desc=None, cluster=None): - super(MemcpyOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(MemcpyOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -842,9 +882,9 @@ class MulGradOpCost(CompOpCost): OP_TYPE = "mul_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(MulGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(MulGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -861,9 +901,9 @@ class OneHotOpCost(CompOpCost): OP_TYPE = "one_hot" def __init__(self, op=None, op_desc=None, cluster=None): - super(OneHotOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(OneHotOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -880,9 +920,9 @@ class ReadFromArrayOpCost(CompOpCost): OP_TYPE = "read_from_array" def __init__(self, op=None, op_desc=None, cluster=None): - super(ReadFromArrayOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ReadFromArrayOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -899,9 +939,9 @@ class ReduceSumOpCost(CompOpCost): OP_TYPE = "reduce_sum" def __init__(self, op=None, op_desc=None, cluster=None): - super(ReduceSumOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ReduceSumOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -918,9 +958,9 @@ class ReduceSumGradOpCost(CompOpCost): OP_TYPE = "reduce_sum_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(ReduceSumGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ReduceSumGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops 
function need to be overrided def calc_flops(self): @@ -937,9 +977,9 @@ class Reshape2OpCost(CompOpCost): OP_TYPE = "reshape2" def __init__(self, op=None, op_desc=None, cluster=None): - super(Reshape2OpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(Reshape2OpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -956,9 +996,9 @@ class Reshape2GradOpCost(CompOpCost): OP_TYPE = "reshape2_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(Reshape2GradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(Reshape2GradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -975,9 +1015,9 @@ class ReduceMeanOpCost(CompOpCost): OP_TYPE = "reduce_mean" def __init__(self, op=None, op_desc=None, cluster=None): - super(ReduceMeanOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ReduceMeanOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -994,9 +1034,9 @@ class ReduceMeanGradOpCost(CompOpCost): OP_TYPE = "reduce_mean_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(ReduceMeanGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ReduceMeanGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1013,9 +1053,9 @@ class SamplingIdOpCost(CompOpCost): OP_TYPE = "sampling_id" def __init__(self, op=None, op_desc=None, cluster=None): - super(SamplingIdOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SamplingIdOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1032,9 +1072,9 @@ class ScaleOpCost(CompOpCost): OP_TYPE = "scale" def __init__(self, op=None, op_desc=None, cluster=None): - super(ScaleOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(ScaleOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1051,9 +1091,9 @@ class SliceOpCost(CompOpCost): OP_TYPE = "slice" def __init__(self, op=None, op_desc=None, cluster=None): - super(SliceOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SliceOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1070,9 +1110,9 @@ class SoftmaxOpCost(CompOpCost): OP_TYPE = "softmax" def __init__(self, op=None, op_desc=None, cluster=None): - super(SoftmaxOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SoftmaxOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1089,9 +1129,9 @@ class SoftmaxGradOpCost(CompOpCost): OP_TYPE = "softmax_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(SoftmaxGradOpCost, 
self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SoftmaxGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1108,9 +1148,9 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost): OP_TYPE = "softmax_with_cross_entropy" def __init__(self, op=None, op_desc=None, cluster=None): - super(SoftmaxWithCrossEntropyOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SoftmaxWithCrossEntropyOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1127,9 +1167,9 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost): OP_TYPE = "softmax_with_cross_entropy_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(SoftmaxWithCrossEntropyGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SoftmaxWithCrossEntropyGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1146,9 +1186,9 @@ class SplitOpCost(CompOpCost): OP_TYPE = "split" def __init__(self, op=None, op_desc=None, cluster=None): - super(SplitOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SplitOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1165,9 +1205,9 @@ class Squeeze2OpCost(CompOpCost): OP_TYPE = "squeeze2" def __init__(self, op=None, op_desc=None, cluster=None): - super(Squeeze2OpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(Squeeze2OpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1184,9 +1224,9 @@ class SquareOpCost(CompOpCost): OP_TYPE = "square" def __init__(self, op=None, op_desc=None, cluster=None): - super(SquareOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SquareOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1203,9 +1243,9 @@ class SquareGradOpCost(CompOpCost): OP_TYPE = "square_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(SquareGradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(SquareGradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1239,9 +1279,9 @@ class TopKOpCost(CompOpCost): OP_TYPE = "top_k" def __init__(self, op=None, op_desc=None, cluster=None): - super(TopKOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(TopKOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1258,9 +1298,9 @@ class Transpose2OpCost(CompOpCost): OP_TYPE = "transpose2" def __init__(self, op=None, op_desc=None, cluster=None): - super(Transpose2OpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(Transpose2OpCost, self).__init__( + op=op, op_desc=op_desc, 
cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1277,9 +1317,9 @@ class Transpose2GradOpCost(CompOpCost): OP_TYPE = "transpose2_grad" def __init__(self, op=None, op_desc=None, cluster=None): - super(Transpose2GradOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(Transpose2GradOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1296,9 +1336,9 @@ class Unsqueeze2OpCost(CompOpCost): OP_TYPE = "unsqueeze2" def __init__(self, op=None, op_desc=None, cluster=None): - super(Unsqueeze2OpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(Unsqueeze2OpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): @@ -1315,9 +1355,9 @@ class WriteToArrayOpCost(CompOpCost): OP_TYPE = "write_to_array" def __init__(self, op=None, op_desc=None, cluster=None): - super(WriteToArrayOpCost, self).__init__(op=op, - op_desc=op_desc, - cluster=cluster) + super(WriteToArrayOpCost, self).__init__( + op=op, op_desc=op_desc, cluster=cluster + ) # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided def calc_flops(self): diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py index 94dc5287910f215bbacfce83302b6d9e1e2caf1c..a3d737769d01c8ac8019ae7f2891670ceb6389bf 100644 --- a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py @@ -27,12 +27,9 @@ from ..dist_tensor import DistributedTensor class CostEstimator: _sepical_op_type = ["fused_attention", "fused_feedforward"] - def __init__(self, - program, - cluster, - mode="modeling", - rank=None, - loop_count=10): + def __init__( + self, program, cluster, mode="modeling", rank=None, loop_count=10 + ): self._program = program self._cluster = cluster self._check_mode(mode) @@ -41,7 +38,8 @@ class CostEstimator: self._loop_count = loop_count self._global_cost = Cost() self._local_cost_mapping = {} - self._detailed_cost = OrderedDict( + self._detailed_cost = ( + OrderedDict() ) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}} self._bubble_time_mapping = {} self._ordered_ops = [] @@ -106,7 +104,8 @@ class CostEstimator: def _check_mode(self, mode): if mode not in ["modeling", "profiling"]: raise ValueError( - "Just support modeling and profiling, but got {}".format(mode)) + "Just support modeling and profiling, but got {}".format(mode) + ) def _is_special_var_name(self, var_name): special_var_name = ["lod_tensor_blocking_queue_0"] @@ -116,6 +115,7 @@ class CostEstimator: def _estimate_core(self, dist_context, resharder, block): from ..reshard import get_var_with_recursion + ops = block.ops loop_count = None if block.desc.id != self.program.global_block().desc.id: @@ -132,8 +132,9 @@ class CostEstimator: if int(op.attr('op_role')) == int(OpRole.Optimize): continue if op.type in [ - "create_py_reader", "create_double_buffer_reader", - "read" + "create_py_reader", + "create_double_buffer_reader", + "read", ]: continue @@ -172,14 +173,16 @@ class CostEstimator: max_time = rank_cost.time for rank in group_ranks: - self.local_cost( - rank).time = max_time + cost.time + self.local_cost(rank).time = ( + 
max_time + cost.time + ) if rank not in self._bubble_time_mapping: self._bubble_time_mapping[rank] = 0 self._bubble_time_mapping[rank] += ( - max_time - cost_time[rank]) + max_time - cost_time[rank] + ) for rank in local_comp_cost: for comp_cost in local_comp_cost[rank]: @@ -191,15 +194,19 @@ class CostEstimator: processes = op_dist_attr.process_mesh.processes container = get_distributed_operator_impl_container( - op_dist_attr.impl_type) + op_dist_attr.impl_type + ) dist_impl = container.impls[op_dist_attr.impl_idx] - dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, - dist_context, self.cluster) + dist_op_cost = dist_impl.calc_cost( + op.attr('op_role'), dist_op, dist_context, self.cluster + ) detail["dist_op_cost"] = dist_op_cost if dist_op_cost is None: - assert dist_op.serial_op.type in CostEstimator._sepical_op_type + assert ( + dist_op.serial_op.type in CostEstimator._sepical_op_type + ) continue for item in dist_op_cost: if isinstance(item, list): @@ -217,12 +224,14 @@ class CostEstimator: if max_time < rank_cost.time: max_time = rank_cost.time for rank in group_ranks: - self.local_cost( - rank).time = max_time + comm_op_cost.time + self.local_cost(rank).time = ( + max_time + comm_op_cost.time + ) if rank not in self._bubble_time_mapping: self._bubble_time_mapping[rank] = 0 self._bubble_time_mapping[rank] += ( - max_time - cost_time[rank]) + max_time - cost_time[rank] + ) elif isinstance(item, dict): # Op just one for rank in processes: @@ -247,8 +256,11 @@ class CostEstimator: dtype_factor = 8 elif dtype == paddle.float32 or dtype == paddle.int32: dtype_factor = 4 - elif dtype == paddle.float16 or dtype == paddle.bfloat16 \ - or dtype == paddle.int16: + elif ( + dtype == paddle.float16 + or dtype == paddle.bfloat16 + or dtype == paddle.int16 + ): dtype_factor = 2 elif dtype == paddle.int8 or dtype == paddle.uint8: dtype_factor = 1 @@ -270,8 +282,9 @@ class CostEstimator: memories = {} self.max_memories = {} - var_info = { - } # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} + var_info = ( + {} + ) # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} for block in self.program.blocks: for op in block.ops: @@ -280,18 +293,22 @@ class CostEstimator: for op_id, op in self._ordered_ops: if op.type in [ - "create_py_reader", "create_double_buffer_reader", "read" + "create_py_reader", + "create_double_buffer_reader", + "read", ]: continue dist_op = dist_context.get_dist_op_for_program(op) process_mesh = dist_op.dist_attr.process_mesh for var_name in op.input_arg_names: input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( - var_name) + var_name + ) if var_name not in var_info: var_info[var_name] = {} - key = _convert_pm_and_dm_to_str(process_mesh, - input_dims_mapping) + key = _convert_pm_and_dm_to_str( + process_mesh, input_dims_mapping + ) if key not in var_info[var_name]: var_info[var_name][key] = {} # It is even partition now @@ -300,21 +317,27 @@ class CostEstimator: global_sizes = var.shape dtype = var.dtype sizes = DistributedTensor.get_local_sizes( - global_sizes, input_dims_mapping, process_mesh.topology, - process_mesh.processes) + global_sizes, + input_dims_mapping, + process_mesh.topology, + process_mesh.processes, + ) var_info[var_name][key]["memory"] = self._calculate_bytes( - sizes, dtype) + sizes, dtype + ) if "position" not in var_info[var_name][key]: var_info[var_name][key]["position"] = [] var_info[var_name][key]["position"].append(op_id) for var_name in op.output_arg_names: 
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( - var_name) + var_name + ) if var_name not in var_info: var_info[var_name] = {} - key = _convert_pm_and_dm_to_str(process_mesh, - output_dims_mapping) + key = _convert_pm_and_dm_to_str( + process_mesh, output_dims_mapping + ) if key not in var_info[var_name]: var_info[var_name][key] = {} if "memory" not in var_info[var_name][key]: @@ -322,10 +345,14 @@ class CostEstimator: global_sizes = var.shape dtype = var.dtype sizes = DistributedTensor.get_local_sizes( - global_sizes, output_dims_mapping, - process_mesh.topology, process_mesh.processes) + global_sizes, + output_dims_mapping, + process_mesh.topology, + process_mesh.processes, + ) var_info[var_name][key]["memory"] = self._calculate_bytes( - sizes, dtype) + sizes, dtype + ) if "position" not in var_info[var_name][key]: var_info[var_name][key]["position"] = [] var_info[var_name][key]["position"].append(op_id) @@ -333,7 +360,9 @@ class CostEstimator: has_used_vars = set() for op_id, op in self._ordered_ops: if op.type in [ - "create_py_reader", "create_double_buffer_reader", "read" + "create_py_reader", + "create_double_buffer_reader", + "read", ]: continue can_free_memories = {} @@ -342,9 +371,11 @@ class CostEstimator: process_mesh = dist_op.dist_attr.process_mesh for var_name in op.input_arg_names: input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( - var_name) - key = _convert_pm_and_dm_to_str(process_mesh, - input_dims_mapping) + var_name + ) + key = _convert_pm_and_dm_to_str( + process_mesh, input_dims_mapping + ) has_used_var = var_name + key var = dist_op.get_serial_input(var_name) # Not used @@ -364,13 +395,16 @@ class CostEstimator: if process not in can_free_memories: can_free_memories[process] = 0 can_free_memories[process] += var_info[ - var_name][key]["memory"] + var_name + ][key]["memory"] for var_name in op.output_arg_names: output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( - var_name) - key = _convert_pm_and_dm_to_str(process_mesh, - output_dims_mapping) + var_name + ) + key = _convert_pm_and_dm_to_str( + process_mesh, output_dims_mapping + ) has_used_var = var_name + key var = dist_op.get_serial_output(var_name) # Not used @@ -390,7 +424,8 @@ class CostEstimator: if process not in can_free_memories: can_free_memories[process] = 0 can_free_memories[process] += var_info[ - var_name][key]["memory"] + var_name + ][key]["memory"] # Calc peak memory for process in memories: @@ -414,8 +449,12 @@ class CostEstimator: def estimate(self, dist_context, resharder=None): self.prepare() from ..reshard import Resharder - resharder = Resharder(self.program, None, self.rank, dist_context, - []) if resharder is None else resharder + + resharder = ( + Resharder(self.program, None, self.rank, dist_context, []) + if resharder is None + else resharder + ) block = self.program.global_block() self._estimate_core(dist_context, resharder, block) @@ -447,7 +486,7 @@ class CostEstimator: memories = [ int(item // 1e6) for item in list(self.max_memories.values()) ] - for memory in (memories + header): + for memory in memories + header: if len(str(memory)) > max_len: max_len = len(str(memory)) max_len += 4 # for pretty print of center @@ -477,7 +516,7 @@ class CostEstimator: max_len = 0 header = ["Execution Time(ms)", "Max Memory(MiB)"] vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)] - for memory in (vals + header): + for memory in vals + header: if len(str(memory)) > max_len: max_len = len(str(memory)) max_len += 4 # for pretty print of 
center @@ -507,50 +546,73 @@ class CostEstimator: def get_cost_from_engine(engine, mode): from ..utils import to_list - # Construct cost estimator by original main program - serial_main_prog = engine._serial_main_progs[mode].clone( - ) if mode in engine._serial_main_progs else engine._orig_main_prog.clone() + import copy - serial_startup_prog = engine._serial_startup_progs[mode].clone( - ) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone( + # Construct cost estimator by original main program + serial_main_prog = ( + engine._fwd_main_progs[mode].clone() + if mode in engine._fwd_main_progs + else engine._orig_main_prog.clone() ) - losses = to_list( - engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer) - and not callable(engine._loss)) else engine._losses - if mode in engine._dist_contexts: - dist_context = engine._dist_contexts[mode] - completer = engine._planners[mode].completer + serial_startup_prog = ( + engine._serial_startup_progs[mode].clone() + if mode in engine._serial_startup_progs + else engine._orig_startup_prog.clone() + ) + losses = ( + to_list(engine._loss) + if ( + not isinstance(engine._loss, paddle.nn.Layer) + and not callable(engine._loss) + ) + else engine._losses + ) + serial_optimizer = copy.deepcopy(engine._orig_optimizer) + if mode in engine._fwd_dist_contexts: + dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode]) else: - from ..completion import Completer from ..dist_context import DistributedContext - dist_context = DistributedContext(serial_main_prog, serial_startup_prog, - engine._optimizer, losses, {}, - {"loss": losses}, engine._cluster, - engine._strategy) - completer = Completer(dist_context) - completer.complete_forward_annotation() - dist_context.block_state.parse_forward_blocks( - dist_context.serial_main_program) + + dist_context = DistributedContext( + serial_main_prog, + serial_startup_prog, + serial_optimizer, + losses, + {}, + {"loss": losses}, + engine._cluster, + engine._strategy, + ) + from ..completion import Completer + + completer = Completer(dist_context) + completer.complete_forward_annotation() + dist_context.block_state.parse_forward_blocks( + dist_context.serial_main_program + ) if mode == "eval" or mode == "predict": cost_estimator = CostEstimator(serial_main_prog, engine._cluster) elif mode == "train": from ..parallelizer_v2 import Parallelizer + # Get serial main program with backward - serial_optimizer = engine._optimizer parallelizer = Parallelizer(mode, completer, dist_context) # Generate backward loss_name = dist_context.serial_loss.name serial_loss = serial_main_prog.global_block()._var_recursive(loss_name) - params_grads = parallelizer._generate_backward(serial_main_prog, - serial_startup_prog, - serial_loss) + params_grads = parallelizer._generate_backward( + serial_main_prog, serial_startup_prog, serial_loss + ) # Generate optimizer optimizer_ops = parallelizer._generate_optimizer( - serial_main_prog, serial_startup_prog, serial_optimizer, - params_grads) + serial_main_prog, + serial_startup_prog, + serial_optimizer, + params_grads, + ) cost_estimator = CostEstimator(serial_main_prog, engine._cluster) # Estimate global_cost and max memory diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index a2e1477f8873c9ec69e72cd3cfccede580695386..d3f1b249d12831fdc3f665bbee8424cc32eb6502 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -13,8 +13,10 @@ # 
limitations under the License. import os +import copy import logging import random +import numbers import numpy as np from collections import defaultdict @@ -41,16 +43,20 @@ from .planner_v2 import Planner from .parallelizer_v2 import Parallelizer from .dist_op import DistributedOperator from .dist_saver import DistributedSaver -from .dist_loader import DistributedDataLoaderFromGenerator, DistributedDataLoader -from .utils import to_list, get_dist_attr, get_lr +from .dist_loader import ( + DistributedDataLoaderFromGenerator, + DistributedDataLoader, +) from .process_group import new_process_group, get_all_process_groups from .dist_context import DistributedContext, get_default_distributed_context from .strategy import Strategy from .interface import CollectionNames, get_collection -from ..utils.log_utils import get_logger -from .utils import initialize_pg_in_full_mode +from .utils import to_list, get_dist_attr, get_lr, validate_opt +from .utils import initialize_pg_in_full_mode, get_input_split_info from .cost.estimate_cost import get_cost_from_engine +from ..utils.log_utils import get_logger + class Engine: """ @@ -115,35 +121,55 @@ class Engine: """ - def __init__(self, - model=None, - loss=None, - optimizer=None, - metrics=None, - cluster=None, - strategy=None): - - if model and not isinstance(model, - paddle.nn.Layer) and not callable(model): + def __init__( + self, + model=None, + loss=None, + optimizer=None, + metrics=None, + cluster=None, + strategy=None, + ): + + if ( + model + and not isinstance(model, paddle.nn.Layer) + and not callable(model) + ): raise TypeError( "'model must be sub classes of `paddle.nn.Layer` or any callable function." ) self._model = model + + if ( + loss + and not isinstance(loss, (paddle.nn.Layer, Variable)) + and not callable(loss) + ): + raise TypeError( + "'loss' must be sub classes of `paddle.nn.Layer` or any callable function or a Variable." + ) self._loss = loss if optimizer and not isinstance( - optimizer, - (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): + optimizer, + (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer), + ): raise TypeError( "'optimizer' must be object of class `paddle.optimizer.Optimizer`" - " or `paddle.fluid.optimizer.Optimizer`.") - self._optimizer = self._validate_opt(optimizer) + " or `paddle.fluid.optimizer.Optimizer`." 
+ ) + self._optimizer = validate_opt(optimizer) + self._orig_optimizer = copy.deepcopy(self._optimizer) metrics = metrics or [] for metric in to_list(metrics): - assert isinstance(metric, Metric), \ - "{} is not sub class of Metric".format( - metric.__class__.__name__) + if metric and not isinstance(metric, Metric): + raise TypeError( + "{} is not sub class of Metric".format( + metric.__class__.__name__ + ) + ) self._metrics = to_list(metrics) if cluster and not isinstance(cluster, Cluster): @@ -158,9 +184,11 @@ class Engine: ) self._strategy = strategy or Strategy() + self._logger = get_logger(logging.INFO) if os.getenv("POD_NAME"): - print("Distribute training by paddle.distributed.launch", - flush=True) + self._logger.info( + "Distribute training by paddle.distributed.launch" + ) fleet.init(is_collective=True) self._executor = None @@ -168,12 +196,12 @@ class Engine: self._nranks = paddle.distributed.get_world_size() self._saver = DistributedSaver() - self._logger = get_logger(logging.INFO) - self._orig_main_prog = static.default_main_program() self._orig_startup_prog = static.default_startup_program() self._orig_dist_context = get_default_distributed_context() self._dist_contexts = {} + self._fwd_main_progs = {} + self._fwd_dist_contexts = {} self._serial_main_progs = {} self._serial_startup_progs = {} self._dist_main_progs = defaultdict(dict) # dist main programs @@ -185,19 +213,20 @@ class Engine: self._has_prepared_reader = { "train": False, "eval": False, - "predict": False + "predict": False, } self._inputs_spec = [] self._labels_spec = [] self._inputs = [] self._labels = [] + self._losses = [] + self._mode = None self._skip_build = False self._outside_dataloader = False self._planned_mode = None self._dygraph_mode = False self._tuning = self._strategy.tuning - self._losses = None self.history = None @@ -219,9 +248,11 @@ class Engine: inputs = sample[:split] labels = sample[split:] else: - raise ValueError( - "Data should be a Dataset or IterableDatset, but received {}.". 
- format(type(data).__name__)) + raise TypeError( + "Data should be a Dataset or IterableDatset, but received {}.".format( + type(data).__name__ + ) + ) inputs = to_list(inputs) labels = to_list(labels) @@ -240,14 +271,20 @@ class Engine: else: specs.append(spec.batch(batch_size)) elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)): - _adjust_item_spec(num_shards, spec) spec = InputSpec.from_tensor(item, name) + _adjust_item_spec(num_shards, spec) if batch_size is None: specs.append(spec) else: specs.append(spec.batch(batch_size)) - else: + elif isinstance(item, numbers.Number): specs.append(InputSpec([batch_size], type(item), name)) + else: + raise TypeError( + "The sample's dtype returned of dataset should be number, np.ndarray or Tensor, but got {}".format( + type(item).__name__ + ) + ) if inputs is not None: for i, item in enumerate(inputs): @@ -264,37 +301,41 @@ class Engine: labels_spec = self._validate_spec(labels_spec) return inputs_spec, labels_spec - def _prepare_data_tensor(self, - inputs_spec, - labels_spec, - inputs=None, - labels=None): + def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels): if _non_static_mode() or self._dygraph_mode: - return None, None - inputs_spec = inputs_spec if inputs_spec else [] - labels_spec = labels_spec if labels_spec else [] + raise ValueError("Only support static graph mode.") + if inputs_spec: - assert isinstance(inputs_spec, list), \ - "inputs should be list, but received {}".format(type(inputs_spec)) - if inputs is None: - inputs = [s._create_feed_layer() for s in inputs_spec] - else: - assert isinstance(inputs, list), \ - "inputs should be list, but received {}".format(type(inputs)) - for input_spec, input in zip(inputs_spec, inputs): - if input_spec.shape != input.shape: - input.desc.set_shape(input_spec.shape) + assert isinstance( + inputs_spec, list + ), "inputs should be list, but received {}".format( + type(inputs_spec) + ) + assert isinstance( + inputs, list + ), "inputs should be list, but received {}".format(type(inputs)) + assert len(inputs_spec) == len( + inputs + ), "the number of `inputs_spec` should be equal to `inputs`'s." + for input_spec, input in zip(inputs_spec, inputs): + if input_spec.shape != input.shape: + input.desc.set_shape(input_spec.shape) if labels_spec: - assert isinstance(labels_spec, list), \ - "labels should be list, but received {}".format(type(labels_spec)) - if labels is None: - labels = [s._create_feed_layer() for s in labels_spec] - else: - assert isinstance(labels, list), \ - "labels should be list, but received {}".format(type(labels)) - for label_spec, label in zip(labels_spec, labels): - if label_spec.shape != label.shape: - label.desc.set_shape(label_spec.shape) + assert isinstance( + labels_spec, list + ), "labels should be list, but received {}".format( + type(labels_spec) + ) + assert isinstance( + labels, list + ), "labels should be list, but received {}".format(type(labels)) + assert len(labels_spec) == len( + labels + ), "the number of `labels_spec` should be equal to `labels`'s." + for label_spec, label in zip(labels_spec, labels): + if label_spec.shape != label.shape: + label.desc.set_shape(label_spec.shape) + return inputs, labels def _prepare_reader(self): @@ -304,7 +345,9 @@ class Engine: # NOTE: this list may be changed if Paddle changes the existing rules. 
related_reader_ops = [ - "create_py_reader", "create_double_buffer_reader", "read" + "create_py_reader", + "create_double_buffer_reader", + "read", ] # remove the first three ops if multiple run fit/evaluate/predict if dist_main_block.ops[0].type == 'create_py_reader': @@ -322,9 +365,9 @@ class Engine: for idx in reversed(reader_op_indices): new_op_desc = dist_main_block.desc._prepend_op() new_op_desc.copy_from(dist_main_block.ops[idx].desc) - new_op = Operator(dist_main_block, - new_op_desc, - type=new_op_desc.type()) + new_op = Operator( + dist_main_block, new_op_desc, type=new_op_desc.type() + ) new_reader_ops.append(new_op) dist_op = DistributedOperator(new_op) dist_context.add_dist_op_for_program(dist_op) @@ -355,16 +398,22 @@ class Engine: else: raise ValueError("Unsupported data {}".format(data)) if user_feeds is not None: - assert isinstance(user_feeds, dict), \ - "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__) + assert isinstance( + user_feeds, dict + ), "user_feeds must be a dict, but receive {}".format( + type(user_feeds).__name__ + ) for name, data in user_feeds.items(): feeds[name] = data return feeds def _prepare_fetch(self, user_fetches, mode): if user_fetches is not None: - assert isinstance(user_fetches, list), \ - "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__) + assert isinstance( + user_fetches, list + ), "user_fetches must be a list, but receive {}".format( + type(user_fetches).__name__ + ) fetch_names = [] fetch_indices = [] @@ -396,14 +445,16 @@ class Engine: _process_fetch_group("fetches", var_list) return fetch_names, fetch_indices - def _prepare_logger(self, - outs, - epoch=None, - step=None, - lr=None, - fetch_names=None, - fetch_indices=None, - mode=None): + def _prepare_logger( + self, + outs, + epoch=None, + step=None, + lr=None, + fetch_names=None, + fetch_indices=None, + mode=None, + ): logs = {} if epoch is not None: logs["epoch"] = epoch @@ -468,11 +519,13 @@ class Engine: self._dygraph_mode = True self._logger.info("Building model with 'to_static' method.") - inputs_spec = self._inputs_spec - labels_spec = self._labels_spec if self._labels_spec else [] - self.program_helper = ProgramHelper(self._model, self._loss, - self._metrics, inputs_spec, - labels_spec) + self.program_helper = ProgramHelper( + self._model, + self._loss, + self._metrics, + self._inputs_spec, + self._labels_spec, + ) # build forward main program self.program_helper.build_program(mode) @@ -480,16 +533,12 @@ class Engine: serial_main_prog = self.program_helper.main_program serial_startup_prog = self.program_helper.startup_program - inputs = self.program_helper.input_vars + self._inputs = self.program_helper.input_vars + self._labels = self.program_helper.label_vars outputs = self.program_helper.output_vars - labels = self.program_helper.label_vars - losses = self.program_helper.loss_vars - self._losses = losses + self._losses = self.program_helper.loss_vars metrics = self.program_helper.metric_vars - self._inputs = inputs - self._labels = labels - paddle.enable_static() else: # build program in static mode @@ -498,27 +547,45 @@ class Engine: return outputs = [] - losses = [] metrics = [] - inputs = self._inputs if self._inputs else [] - labels = self._labels if self._labels else [] + self._losses = [] serial_main_prog = self._orig_main_prog.clone() serial_startup_prog = self._orig_startup_prog.clone() if not self._skip_build: - with static.program_guard(serial_main_prog, serial_startup_prog), \ - utils.unique_name.guard(): - 
outputs = to_list(self._model(*inputs)) - if mode != "predict" and self._loss: - losses = to_list(self._loss(*(outputs + labels))) - self._losses = losses + with static.program_guard( + serial_main_prog, serial_startup_prog + ), utils.unique_name.guard(): + self._inputs = [ + s._create_feed_layer() for s in self._inputs_spec + ] + self._labels = [ + s._create_feed_layer() for s in self._labels_spec + ] + + outputs = to_list(self._model(*self._inputs)) - if mode != "predict" and (outputs or labels): + if mode != "predict" and self._loss: + assert isinstance( + self._loss, paddle.nn.Layer + ) or callable( + self._loss + ), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function." + self._losses = to_list( + self._loss(*(outputs + self._labels)) + ) + + if mode != "predict" and (outputs or self._labels): for metric in self._metrics: metrics.append( - to_list(metric.compute(*(outputs + labels)))) + to_list( + metric.compute(*(outputs + self._labels)) + ) + ) else: - losses = to_list(self._loss) - self.losses = losses + assert isinstance( + self._loss, Variable + ), "the type of `loss` of the Engine arguments should be Variable." + self._losses = to_list(self._loss) default_ctx = get_default_distributed_context() if not default_ctx.has_annotation: @@ -527,12 +594,12 @@ class Engine: new_process_group(list(range(self._nranks))) default_ctx.data_parallel = True - feed_vars = {"inputs": inputs, "labels": labels} + feed_vars = {"inputs": self._inputs, "labels": self._labels} fetch_vars = { "outputs": flatten(outputs), - "loss": losses, - "metrics": metrics + "loss": self._losses, + "metrics": metrics, } if mode != "train": @@ -540,9 +607,27 @@ class Engine: self._set_recompute_ckpts() self._dist_contexts[mode] = DistributedContext( - serial_main_prog, serial_startup_prog, self._optimizer, losses, - feed_vars, fetch_vars, self._cluster, self._strategy) + serial_main_prog, + serial_startup_prog, + self._optimizer, + self._losses, + feed_vars, + fetch_vars, + self._cluster, + self._strategy, + ) + self._fwd_dist_contexts[mode] = DistributedContext( + serial_main_prog, + serial_startup_prog, + self._optimizer, + self._losses, + feed_vars, + fetch_vars, + self._cluster, + self._strategy, + ) self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale + self._fwd_main_progs[mode] = serial_main_prog.clone() def _optimization_tuning(self, mode, dataset, batch_size): if not self._tuning.enable: @@ -558,20 +643,24 @@ class Engine: dataset.dp_rank = self._dp_ranks from .tuner.optimization_tuner import OptimizationTuner - self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(), - self._dist_contexts[mode], - dataset, - self._inputs_spec, - self._labels_spec, - batch_size=batch_size, - rank=self._cur_rank) + + self._optimization_tuner = OptimizationTuner( + self._tuning.to_dict(), + self._dist_contexts[mode], + dataset, + self._inputs_spec, + self._labels_spec, + batch_size=batch_size, + rank=self._cur_rank, + ) self._optimization_tuner.tune() if self._tuning.run_after_tuning: # update the strategy self._dist_contexts[ - mode]._strategy = self._optimization_tuner.get_best_config() + mode + ]._strategy = self._optimization_tuner.get_best_config() def _plan(self, mode): if self._planned_mode is None: @@ -595,8 +684,9 @@ class Engine: self._dp_world_sizes = [] self._dp_ranks = [] for feed_var in feed_list: - dp_world_size, dp_rank = self._get_input_split_info( - feed_var, self._dist_contexts[mode]) + dp_world_size, dp_rank = 
get_input_split_info( + self._cur_rank, feed_var, self._dist_contexts[mode] + ) self._dp_world_sizes.append(dp_world_size) self._dp_ranks.append(dp_rank) @@ -604,8 +694,9 @@ class Engine: # Parallelize program based on the planner's results # For now, the completer has to be passed to the planner, # because we may use it to complete the annotation of the backwarkward and update. - parallelizer = Parallelizer(mode, self._planners[mode].completer, - self._dist_contexts[mode]) + parallelizer = Parallelizer( + mode, self._planners[mode].completer, self._dist_contexts[mode] + ) if not all_ranks: parallelizer.parallel(self._cur_rank) else: @@ -623,22 +714,30 @@ class Engine: for ib, block in enumerate(origin_main_prog.blocks): for iop, op in enumerate(block.ops): ref_op = ref_blocks[ib].ops[iop] - assert op.type == ref_op.type, \ - "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type) - ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program( - ref_op) + assert ( + op.type == ref_op.type + ), "'{}' mode op '{}' is different with '{}' op '{}'. ".format( + mode, op.type, ref_mode, ref_op.type + ) + ref_op_dist_attr = ( + ref_dist_context.get_op_dist_attr_for_program(ref_op) + ) dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) def _initialize(self, mode): # Get the current content from the distributed context self._serial_main_progs[mode] = self._dist_contexts[ - mode].serial_main_program + mode + ].serial_main_program self._serial_startup_progs[mode] = self._dist_contexts[ - mode].serial_startup_program + mode + ].serial_startup_program self._dist_main_progs[mode] = self._dist_contexts[ - mode].dist_main_programs + mode + ].dist_main_programs self._dist_startup_progs[mode] = self._dist_contexts[ - mode].dist_startup_programs + mode + ].dist_startup_programs self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars self._optimizer = self._dist_contexts[mode]._serial_optimizer @@ -684,30 +783,33 @@ class Engine: self._executor.run(prune_startup_prog) if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"): - self._set_state_dict(mode, self._strict, self._state_dict, - self._dist_attr) + self._set_state_dict( + mode, self._strict, self._state_dict, self._dist_attr + ) if self._strategy.reinit: self._logger.info("NOTE: parameters will be re-initialized.") dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] self._executor.run(dist_startup_prog) - def fit(self, - train_data, - train_sample_split=None, - batch_size=1, - epochs=1, - steps_per_epoch=None, - log_freq=10, - save_dir=None, - save_freq=1, - valid_data=None, - valid_sample_split=None, - valid_freq=1, - valid_steps=None, - collate_fn=None, - callbacks=None, - verbose=2): + def fit( + self, + train_data, + train_sample_split=None, + batch_size=1, + epochs=1, + steps_per_epoch=None, + log_freq=10, + save_dir=None, + save_freq=1, + valid_data=None, + valid_sample_split=None, + valid_freq=1, + valid_steps=None, + collate_fn=None, + callbacks=None, + verbose=2, + ): """ Trains the model for a fixed number of epochs. If `valid_data` is set, evaluation will be done at the end of each epoch. 
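The `fit` entry point above keeps the Keras-style training loop. A minimal usage sketch, assuming the `Engine` is constructed from `paddle.distributed.fleet.auto` with a dygraph model, loss, optimizer and metric (the constructor itself is outside this hunk):

.. code-block:: python

    import paddle
    import paddle.vision.transforms as T
    from paddle.distributed.fleet import auto
    from paddle.vision.datasets import MNIST

    transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
    train_dataset = MNIST(mode='train', transform=transform)
    valid_dataset = MNIST(mode='test', transform=transform)

    model = paddle.vision.models.LeNet()
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=model.parameters())
    metrics = paddle.metric.Accuracy(topk=(1, 2))

    # Engine constructed as in the public auto-parallel examples (not part of this hunk).
    engine = auto.Engine(model, loss, optimizer, metrics)
    # Train for two epochs; because `valid_data` is set, evaluation runs
    # at the end of each epoch, as documented above.
    engine.fit(train_data=train_dataset,
               valid_data=valid_dataset,
               epochs=2,
               batch_size=64,
               log_freq=10)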
@@ -776,17 +878,13 @@ class Engine: """ self._mode = 'train' self._inputs_spec, self._labels_spec = self._prepare_data_spec( - train_data, train_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + train_data, train_sample_split, batch_size + ) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: self._switch_mode(self._mode) - assert self._mode in self._dist_main_progs, \ - "train model is not ready, please call `engine._prepare_program('train')` first." - train_dataloader = self._prepare_dataloader_from_generator( dataset=train_data, capacity=70, @@ -794,7 +892,8 @@ class Engine: batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, - collate_fn=collate_fn) + collate_fn=collate_fn, + ) fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) @@ -823,21 +922,35 @@ class Engine: self.main_program, fetch_list=fetch_names, use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break lr = get_lr(self._optimizer) - logs = self._prepare_logger(outs, epoch, step, lr, fetch_names, - fetch_indices, self._mode) + logs = self._prepare_logger( + outs, + epoch, + step, + lr, + fetch_names, + fetch_indices, + self._mode, + ) cbks.on_batch_end('train', step, logs) if valid_data and (epoch + 1) % valid_freq == 0: - val_logs = self.evaluate(valid_data, valid_sample_split, - batch_size, valid_steps, log_freq, - collate_fn, callbacks, verbose) + val_logs = self.evaluate( + valid_data, + valid_sample_split, + batch_size, + valid_steps, + log_freq, + collate_fn, + callbacks, + verbose, + ) val_logs = { - "val_" + name: val - for name, val in val_logs.items() + "val_" + name: val for name, val in val_logs.items() } logs.update(val_logs) self._switch_mode("train") @@ -849,15 +962,17 @@ class Engine: cbks.on_end('train', logs) return self.history - def evaluate(self, - valid_data, - valid_sample_split=None, - batch_size=1, - steps=None, - log_freq=10, - collate_fn=None, - callbacks=None, - verbose=2): + def evaluate( + self, + valid_data, + valid_sample_split=None, + batch_size=1, + steps=None, + log_freq=10, + collate_fn=None, + callbacks=None, + verbose=2, + ): """ Evaluate the loss and metrics of the model on evaluation data. @@ -906,23 +1021,21 @@ class Engine: """ self._mode = 'eval' self._inputs_spec, self._labels_spec = self._prepare_data_spec( - valid_data, valid_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + valid_data, valid_sample_split, batch_size + ) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: self._switch_mode(self._mode) - assert self._mode in self._dist_main_progs, \ - "eval model is not ready, please call `engine._prepare_program('eval')` first." 
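Continuing the hypothetical setup from the `fit` sketch earlier, a standalone evaluation pass would look like the following; the returned dict mirrors the `logs` assembled by `_prepare_logger`:

.. code-block:: python

    # Evaluate on the held-out split; the keys of the returned dict depend on
    # the configured loss and metrics.
    eval_logs = engine.evaluate(valid_data=valid_dataset,
                                batch_size=64,
                                log_freq=10)
    print(eval_logs)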
valid_dataloader = self._prepare_dataloader_from_generator( dataset=valid_data, capacity=70, iterable=False, batch_size=batch_size, steps_per_epoch=steps, - collate_fn=collate_fn) + collate_fn=collate_fn, + ) fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) @@ -936,10 +1049,9 @@ class Engine: ) eval_steps = valid_dataloader._steps - cbks.on_begin('eval', { - 'steps': eval_steps, - 'metrics': self._metrics_name() - }) + cbks.on_begin( + 'eval', {'steps': eval_steps, 'metrics': self._metrics_name()} + ) logs = {} for step, _ in enumerate(valid_dataloader): cbks.on_batch_begin('eval', step, logs) @@ -948,24 +1060,28 @@ class Engine: self.main_program, fetch_list=fetch_names, use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break - logs = self._prepare_logger(outs, None, step, None, fetch_names, - fetch_indices, self._mode) + logs = self._prepare_logger( + outs, None, step, None, fetch_names, fetch_indices, self._mode + ) cbks.on_batch_end('eval', step, logs) cbks.on_end('eval', logs) self._reset_metrics() return logs - def predict(self, - test_data, - test_sample_split=None, - batch_size=1, - steps=None, - collate_fn=None, - callbacks=None, - verbose=2): + def predict( + self, + test_data, + test_sample_split=None, + batch_size=1, + steps=None, + collate_fn=None, + callbacks=None, + verbose=2, + ): """ Compute the output predictions on testing data. @@ -1011,24 +1127,21 @@ class Engine: """ self._mode = 'predict' self._inputs_spec, self._labels_spec = self._prepare_data_spec( - test_data, test_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + test_data, test_sample_split, batch_size + ) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: self._switch_mode(self._mode) - assert self._mode in self._dist_main_progs, \ - "predict model is not ready, please call `engine._prepare_program('predict')` first." 
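`predict` drops labels, loss and metrics and only fetches the model outputs; continuing the same hypothetical setup:

.. code-block:: python

    test_dataset = MNIST(mode='test', transform=transform)
    # Returns a list with one entry per batch holding the fetched output values.
    outputs = engine.predict(test_data=test_dataset, batch_size=64)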
- test_dataloader = self._prepare_dataloader_from_generator( dataset=test_data, capacity=70, iterable=False, batch_size=batch_size, steps_per_epoch=steps, - collate_fn=collate_fn) + collate_fn=collate_fn, + ) fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) @@ -1044,41 +1157,45 @@ class Engine: self.main_program, fetch_list=fetch_names, use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break - logs = self._prepare_logger(outs, None, step, None, fetch_names, - fetch_indices, self._mode) + logs = self._prepare_logger( + outs, None, step, None, fetch_names, fetch_indices, self._mode + ) cbks.on_batch_end('predict', step, logs) outputs.append(list(logs["outputs"].values())) cbks.on_end('predict', logs) return outputs - def dataloader(self, - dataset, - batch_size=1, - shuffle=False, - drop_last=False, - collate_fn=None, - num_workers=0, - use_buffer_reader=True, - use_shared_memory=True, - timeout=0, - worker_init_fn=None, - epochs=1, - steps_per_epoch=None, - sample_split=1, - mode=None): + def dataloader( + self, + dataset, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + use_shared_memory=True, + timeout=0, + worker_init_fn=None, + epochs=1, + steps_per_epoch=None, + sample_split=1, + mode=None, + ): if mode is not None: self.to_mode(mode) self._inputs_spec, self._labels_spec = self._prepare_data_spec( - dataset, sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + dataset, sample_split, batch_size + ) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: self._switch_mode(self._mode) + dataloader = self._prepare_dataloader( dataset, return_list=False, @@ -1092,32 +1209,35 @@ class Engine: timeout=timeout, worker_init_fn=worker_init_fn, epochs=epochs, - steps_per_epoch=steps_per_epoch) + steps_per_epoch=steps_per_epoch, + ) return dataloader - def dataloader_from_generator(self, - dataset, - capacity=70, - use_double_buffer=True, - iterable=True, - use_multiprocess=False, - drop_last=True, - batch_size=1, - epochs=1, - steps_per_epoch=None, - collate_fn=None, - sample_split=1, - mode=None): + def dataloader_from_generator( + self, + dataset, + capacity=70, + use_double_buffer=True, + iterable=True, + use_multiprocess=False, + drop_last=True, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None, + sample_split=1, + mode=None, + ): if mode is not None: self.to_mode(mode) self._inputs_spec, self._labels_spec = self._prepare_data_spec( - dataset, sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + dataset, sample_split, batch_size + ) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: self._switch_mode(self._mode) + dataloader = self._prepare_dataloader_from_generator( dataset=dataset, capacity=capacity, @@ -1129,95 +1249,114 @@ class Engine: batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, - collate_fn=collate_fn) + collate_fn=collate_fn, + ) return dataloader - def prepare(self, - inputs_spec=None, - labels_spec=None, - inputs=None, - labels=None, - main_program=None, - startup_program=None, - mode=None): + def prepare( + self, + inputs_spec=None, + labels_spec=None, + inputs=None, + labels=None, + main_program=None, + startup_program=None, + mode=None, + ): 
if mode is not None: self.to_mode(mode) + + if not self._mode: + raise ValueError( + "Please set mode to be prepared with `prepare(mode=...)`" + ) + + if self._has_prepared[self._mode]: + return + + inputs_spec = self._validate_spec(inputs_spec) + labels_spec = self._validate_spec(labels_spec) + inputs = self._validate_vars(inputs) + labels = self._validate_vars(labels) + + self._orig_main_prog = main_program + self._orig_startup_prog = startup_program if inputs or labels: self._skip_build = True - self._inputs_spec = inputs_spec - self._labels_spec = labels_spec - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec, inputs, labels) - self._orig_main_prog = main_program + inputs, labels = self._prepare_data_tensor( + inputs_spec, labels_spec, inputs, labels + ) if self._orig_main_prog is None: self._orig_main_prog = static.default_main_program() - self._orig_startup_prog = startup_program if self._orig_startup_prog is None: self._orig_startup_prog = static.default_startup_program() - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) elif inputs_spec or labels_spec: - self._inputs_spec = inputs_spec - self._labels_spec = labels_spec self._outside_dataloader = True - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) - self._orig_main_prog = main_program if self._orig_main_prog is None: self._orig_main_prog = static.default_main_program() - self._orig_startup_prog = startup_program if self._orig_startup_prog is None: self._orig_startup_prog = static.default_startup_program() - if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) - else: - self._switch_mode(self._mode) else: - assert self._inputs_spec and self._labels_spec, \ - "Please call the dataloader(...) before calling prepare(...)" + assert ( + self._inputs_spec and self._labels_spec + ), "Please call the dataloader(...) 
before calling prepare(...)" + + self._inputs_spec, self._labels_spec = inputs_spec, labels_spec + self._inputs, self._labels = inputs, labels + if not self._has_prepared[self._mode]: + self._prepare_program(self._mode) + else: + self._switch_mode(self._mode) def run(self, data=None, feed=None, fetch_list=None, mode=None): if mode is not None: self.to_mode(mode) feed_dict = self._prepare_feed(data, feed, self._mode) fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode) - if self._outside_dataloader and not self._has_prepared_reader[ - self._mode]: + if ( + self._outside_dataloader + and not self._has_prepared_reader[self._mode] + ): self._prepare_reader() - outs = self._executor.run(self.main_program, - feed=feed_dict, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy) - logs = self._prepare_logger(outs, None, None, None, fetch_names, - fetch_indices, self._mode) + outs = self._executor.run( + self.main_program, + feed=feed_dict, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy, + ) + logs = self._prepare_logger( + outs, None, None, None, fetch_names, fetch_indices, self._mode + ) return logs - def _prepare_dataloader(self, - dataset, - return_list=True, - batch_size=1, - shuffle=False, - drop_last=False, - collate_fn=None, - num_workers=0, - use_buffer_reader=True, - use_shared_memory=True, - timeout=0, - worker_init_fn=None, - epochs=1, - steps_per_epoch=None): + def _prepare_dataloader( + self, + dataset, + return_list=True, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + use_shared_memory=True, + timeout=0, + worker_init_fn=None, + epochs=1, + steps_per_epoch=None, + ): if self._strategy.gradient_merge and batch_size is not None: - assert batch_size % self._k_steps == 0, \ - "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) + assert ( + batch_size % self._k_steps == 0 + ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format( + batch_size, self._k_steps + ) batch_size //= self._k_steps dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] - dist_context = self._dist_contexts[self._mode] dist_main_block = dist_main_prog.global_block() # NOTE: Get feed_list, then insert dataloader op with sharded var shape. 
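The reworked `prepare`/`run` pair above is the path for driving the engine without its built-in dataloaders: `prepare` builds and parallelizes the program from `InputSpec`s, and `run` feeds data through the `feed` dict checked in `_prepare_feed`. A rough sketch under those assumptions (the model, spec names and shapes are placeholders, not part of this patch):

.. code-block:: python

    import numpy as np
    import paddle
    from paddle.static import InputSpec
    from paddle.distributed.fleet import auto

    # Placeholder model and specs for illustration only.
    model = paddle.nn.Sequential(paddle.nn.Linear(784, 10))
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=model.parameters())

    input_spec = InputSpec(shape=[None, 784], dtype='float32', name='x')
    label_spec = InputSpec(shape=[None, 1], dtype='int64', name='y')

    engine = auto.Engine(model, loss, optimizer)
    engine.prepare(inputs_spec=[input_spec],
                   labels_spec=[label_spec],
                   mode="train")

    # Feed keys follow the InputSpec names; values are plain numpy arrays.
    batch = {'x': np.random.rand(8, 784).astype('float32'),
             'y': np.random.randint(0, 10, size=(8, 1)).astype('int64')}
    logs = engine.run(feed=batch, mode="train")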
@@ -1256,31 +1395,36 @@ class Engine: steps_per_epoch=steps_per_epoch, split_data=self._strategy.split_data, data_parallel_world_size=self._dp_world_sizes, - data_parallel_rank=self._dp_ranks) + data_parallel_rank=self._dp_ranks, + ) return dataloader - def _prepare_dataloader_from_generator(self, - dataset, - capacity=None, - use_double_buffer=True, - iterable=True, - return_list=False, - use_multiprocess=False, - drop_last=True, - batch_size=1, - epochs=1, - steps_per_epoch=None, - collate_fn=None): + def _prepare_dataloader_from_generator( + self, + dataset, + capacity=None, + use_double_buffer=True, + iterable=True, + return_list=False, + use_multiprocess=False, + drop_last=True, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None, + ): if self._strategy.gradient_merge and batch_size is not None: - assert batch_size % self._k_steps == 0, \ - "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) + assert ( + batch_size % self._k_steps == 0 + ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format( + batch_size, self._k_steps + ) batch_size //= self._k_steps dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] - dist_context = self._dist_contexts[self._mode] dist_main_block = dist_main_prog.global_block() # NOTE: Get feed_list, then insert dataloader op with sharded var shape. @@ -1316,16 +1460,16 @@ class Engine: collate_fn=collate_fn, split_data=self._strategy.split_data, data_parallel_world_size=self._dp_world_sizes, - data_parallel_rank=self._dp_ranks) + data_parallel_rank=self._dp_ranks, + ) self._prepare_reader() return dataloader def _tune(self, tune_data, tune_sample_split=None, batch_size=1): self._mode = 'train' self._inputs_spec, self._labels_spec = self._prepare_data_spec( - tune_data, tune_sample_split, batch_size) - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + tune_data, tune_sample_split, batch_size + ) self._optimization_tuning(self._mode, tune_data, batch_size) def _validate_spec(self, specs): @@ -1333,46 +1477,39 @@ class Engine: self._k_steps = self._strategy.gradient_merge.k_steps if specs is not None: for i, spec in enumerate(specs): - assert isinstance(spec, InputSpec) + if not isinstance(spec, InputSpec): + raise TypeError( + "'spec' must be object of class `paddle.static.InputSpec`." + ) if spec.name is None: raise ValueError( - "Requires Input[{}].name != None, but receive `None` with {}." 
- .format(i, spec)) + "Requires Input[{}].name != None, but receive `None` with {}.".format( + i, spec + ) + ) if self._k_steps > 1: shape = list(spec.shape) - assert shape[0] % self._k_steps == 0, \ - "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps) + assert ( + shape[0] % self._k_steps == 0 + ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format( + spec.shape[0], self._k_steps + ) shape[0] //= self._k_steps spec.shape = shape - return specs + return specs or [] + + def _validate_vars(self, vars): + vars = to_list(vars) + if vars is not None: + for i, var in enumerate(vars): + if not isinstance(var, Variable): + raise TypeError("'var' must be a `Variable`.") + return vars or [] def _is_local_var(self, var): var_name = _to_name_str(var) return var_name in self.main_program.global_block().vars - def _get_input_split_info(self, var, dist_context): - # deduce how the input data is split among the cluster - from .utils import _get_comm_group, _get_corresponding_rank - - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) - process_mesh = tensor_dist_attr.process_mesh - dims_mapping = tensor_dist_attr.dims_mapping - - if self._cur_rank not in process_mesh.processes: - rank_id = _get_corresponding_rank(dist_context, process_mesh, - self._cur_rank) - else: - rank_id = self._cur_rank - - batch_size_axis = dims_mapping[0] - if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, - batch_size_axis, rank_id) - return len(group_ranks), group_ranks.index(rank_id) - - return 1, 0 - def _set_recompute_ckpts(self): # NOTE hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here @@ -1381,10 +1518,12 @@ class Engine: # extract ckpts by specific model if isinstance(self._model, paddle.nn.Layer): - if hasattr(self._model, - "gpt") and self._model.__class__.__name__ in [ - 'GPTForPretraining', 'GPTForPretrainingAuto' - ]: + if hasattr( + self._model, "gpt" + ) and self._model.__class__.__name__ in [ + 'GPTForPretraining', + 'GPTForPretrainingAuto', + ]: exact_ckpts = self._model.gpt.checkpoints else: exact_ckpts = recompute.checkpoints @@ -1396,16 +1535,10 @@ class Engine: recompute.checkpoints = exact_ckpts[:] logs = { 'Model Class': self._model.__class__.__name__, - 'Applied Recompute ckpts': exact_ckpts + 'Applied Recompute ckpts': exact_ckpts, } self._logger.info(logs) - def _validate_opt(self, optimizer): - if optimizer is not None: - optimizer._parameter_list = None - optimizer._param_groups = None - return optimizer - def _reset_metrics(self): for metric in self._metrics: metric.reset() @@ -1417,12 +1550,18 @@ class Engine: return metrics_name def _switch_mode(self, mode): + assert ( + mode in self._dist_main_progs + ), "{} model is not ready, please call `prepare()` first.".format(mode) self.to_mode(mode) self._optimizer = self._dist_contexts[mode]._serial_optimizer def to_mode(self, mode): - assert mode in ["train", "eval", "predict"], \ - "mode {} should be one of ['train', 'eval', 'predict']".format(mode) + assert mode in [ + "train", + "eval", + "predict", + ], "mode {} should be one of ['train', 'eval', 'predict']".format(mode) self._mode = mode def _set_state_dict(self, mode, strict, state_dict, dist_attr): @@ -1483,20 +1622,24 @@ class Engine: serial_program = self._serial_main_progs[self._mode] dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] dist_context = 
self._dist_contexts[self._mode] - self._saver.save(path, - serial_program=serial_program, - dist_main_program=dist_main_prog, - dist_context=dist_context) + self._saver.save( + path, + serial_program=serial_program, + dist_main_program=dist_main_prog, + dist_context=dist_context, + ) else: assert "predict" in self._dist_main_progs feed_vars = self._feed_vars["predict"]['inputs'] fetch_vars = self._fetch_vars["predict"]['outputs'] dist_main_prog = self._dist_main_progs["predict"][self._cur_rank] - self._saver.save_inference_model(path, - feed_vars, - fetch_vars, - self._executor, - program=dist_main_prog) + self._saver.save_inference_model( + path, + feed_vars, + fetch_vars, + self._executor, + program=dist_main_prog, + ) def load(self, path, strict=True, load_optimizer=True): """ @@ -1508,10 +1651,10 @@ class Engine: strict (bool, optional): Whether to skip the loading of mismatch parameter or raise an error when mismatch happens (not found the parameter in file storing model states of or receives a - mismatch shape). Default: False. + mismatch shape). Default: True. load_optimizer (bool, optional): If True, the stored optimizer states is restored. Otherwise, the optimizer states is initialized - from scratch. Default: False. + from scratch. Default: True. Returns: None @@ -1546,10 +1689,11 @@ class Engine: """ self._strict = strict self._state_dict, self._dist_attr = self._saver.load( - path, load_optimizer) + path, load_optimizer + ) return self._state_dict, self._dist_attr - def cost(self, inputs_spec=None, labels_spec=None, mode="train"): + def cost(self, inputs_spec=None, labels_spec=None, mode=None): """ Get and Print cost, including memory of every rank, max memory among all ranks, and the global cost of one step based on @@ -1560,7 +1704,7 @@ class Engine: Args: inputs_spec(InputSpec): The specification of inputs. Default: None. labels_spec(InputSpec): The specification of labels. Default: None. - mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: "train". + mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None. Returns: Return the global execution time (ms) and max memory (B). @@ -1568,29 +1712,44 @@ class Engine: """ # Check parallel mode if self._strategy.auto_mode == "full": - print( + self._logger.info( "The cost will be calcudated in the search process when the auto mode is full." ) return # Check mode - accepted_modes = ["train", "predict", "eval"] - if mode not in accepted_modes: - raise ValueError("The mode {} is not in accepted modes {}".format( - mode, accepted_modes)) + mode = mode if mode is not None else self._mode + assert mode is not None, "Please set mode." + if mode not in self._has_prepared: + raise ValueError( + "The mode {} is not in accepted modes {}".format( + mode, list(self._has_prepared.keys()) + ) + ) self.to_mode(mode) - if inputs_spec is not None: - self._inputs_spec, self._labels_spec = inputs_spec, labels_spec - self._inputs, self._labels = self._prepare_data_tensor( - self._inputs_spec, self._labels_spec) + if inputs_spec is not None and not self._has_prepared[mode]: + self._inputs_spec = self._validate_spec(inputs_spec) + self._labels_spec = self._validate_spec(labels_spec) self._build(mode) self._plan(mode) else: if _non_static_mode() or self._dygraph_mode: raise ValueError( - "Please call `engine._prepare_program('mode')` firstly when in the static graph mode." + "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`." 
) + else: + self._logger.info( + "The program whose cost is to be estimated must be the static default program. Otherwise, please call `prepare()` before calling `cost()`." + ) + program = paddle.static.default_main_program() + if ( + not program.global_block().ops + or not program.global_block().ops + ) and not self._has_prepared[mode]: + raise ValueError( + "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`." + ) # Estimate the exec cost and max memory global_cost, max_memory = get_cost_from_engine(self, mode) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index cccff89b3d991c45c59f7d732b6f71b05d041527..6dc722b53eadee098046dfad364f723dd5662587 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -23,12 +23,18 @@ from functools import reduce import paddle.fluid.core as core from paddle.distributed.fleet.meta_optimizers.common import OpRole -from paddle.distributed.auto_parallel.process_group import get_all_process_groups +from paddle.distributed.auto_parallel.process_group import ( + get_all_process_groups, +) from paddle.fluid.io import is_parameter, is_belong_to_optimizer -from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute +from paddle.distributed.auto_parallel.dist_attribute import ( + TensorDistributedAttribute, + OperatorDistributedAttribute, +) __not_shape_var_type__ = [ - core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES + core.VarDesc.VarType.READER, + core.VarDesc.VarType.STEP_SCOPES, ] @@ -39,7 +45,8 @@ def get_logger(log_level, name="auto_parallel"): logger.setLevel(log_level) log_handler = logging.StreamHandler() log_format = logging.Formatter( - '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') + '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' + ) log_handler.setFormatter(log_format) logger.addHandler(log_handler) return logger @@ -114,8 +121,11 @@ def verify_shard_spec(shard_spec, tensor_shape, process_mesh): if not verify_dims_mapping(dims_mapping, process_mesh): return False for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0: + if ( + dims_mapping[i] != -1 + and tensor_shape[i] > 0 + and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0 + ): return False return True @@ -141,14 +151,17 @@ def compute_compatible_dims_mapping(dims_mapping_list): return None length = len(dims_mapping_list[0]) for dims_mapping in dims_mapping_list: - assert dims_mapping is not None, \ - "Dims mapping must not be None for compatible computation" - assert len(dims_mapping) == length, \ - "The length of dims_mapping in list must be same for compatible computation." + assert ( + dims_mapping is not None + ), "Dims mapping must not be None for compatible computation" + assert ( + len(dims_mapping) == length + ), "The length of dims_mapping in list must be same for compatible computation."
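The `verify_shard_spec` check reformatted a few hunks above only rejects a mapping when a positively-sized tensor dimension does not divide evenly over the mesh dimension it is mapped to. A small self-contained illustration of that rule (plain Python, independent of the auto-parallel classes):

.. code-block:: python

    def dim_is_shardable(tensor_dim, mesh_dim_size):
        # Mirrors the condition in verify_shard_spec: only a concrete (> 0)
        # dimension that is actually mapped must be divisible.
        return tensor_dim <= 0 or tensor_dim % mesh_dim_size == 0

    # A [8, 1024] tensor with dims_mapping [0, -1] on a mesh of shape [4, 2]:
    # dim 0 (size 8) is sharded over mesh axis 0 (size 4) -> 8 % 4 == 0, ok;
    # dim 1 is replicated (-1), so no divisibility check applies.
    print(dim_is_shardable(8, 4))     # True
    print(dim_is_shardable(1023, 2))  # False -> verify_shard_spec returns False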
compatible_result = [] for dim_mappings in zip(*dims_mapping_list): compatible_dim_mapping = compute_compatible_dim_mapping( - list(dim_mappings)) + list(dim_mappings) + ) if compatible_dim_mapping is None: return None compatible_result.append(compatible_dim_mapping) @@ -161,7 +174,10 @@ def compute_compatible_process_mesh(process_mesh_list): return compatible_process_mesh for process_mesh in process_mesh_list: if process_mesh is not None: - if compatible_process_mesh is None or compatible_process_mesh == process_mesh: + if ( + compatible_process_mesh is None + or compatible_process_mesh == process_mesh + ): compatible_process_mesh = process_mesh else: return None @@ -201,15 +217,18 @@ def remove_distributed_attr_suffix(name): def check_distributed_attr_for_program(program, dist_context=None): from .dist_context import get_default_distributed_context + if dist_context is None: dist_context = get_default_distributed_context() - assert dist_context.is_initialized_for_program(), \ - "Distributed attributes must be initialized before check." + assert ( + dist_context.is_initialized_for_program() + ), "Distributed attributes must be initialized before check." for block in program.blocks: for tensor in block.vars.values(): dist_tensor = dist_context.get_dist_tensor_for_graph(tensor) tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - tensor) + tensor + ) if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()): return False for op in block.ops: @@ -229,6 +248,7 @@ def print_program_with_dist_attr(program, dist_context=None): lock.acquire() from .dist_context import get_default_distributed_context from .dist_context import set_default_distributed_context + if dist_context is None: dist_context = get_default_distributed_context() print(program, flush=True) @@ -242,7 +262,7 @@ def print_program_with_dist_attr(program, dist_context=None): def _get_comm_group(processes, shape, axis, rank): """ - Given a rank and the processes mesh the rank belongs to, + Given a rank and the processes mesh the rank belongs to, compute the communication peers of the rank based on the give axis in the mesh. Example: 16 processes managed in a 4-Dimensinal mesh with shape of [2, 2, 2, 2]. @@ -256,7 +276,8 @@ def _get_comm_group(processes, shape, axis, rank): # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous # tricks to support processes mesh when it is not start with 0 or continuous assert rank in processes, "rank [{}] is NOT in processes group {}".format( - rank, processes) + rank, processes + ) rank_relatvie = processes.index(rank) coordinate = _linear_idx2coordinate(shape, rank_relatvie) coordinates_in_group = [coordinate[:] for i in range(shape[axis])] @@ -276,7 +297,7 @@ def _get_comm_group(processes, shape, axis, rank): def _get_idx_in_axis(processes, shape, axis, rank): """ - Given a rank and the processes mesh the rank belongs to, + Given a rank and the processes mesh the rank belongs to, compute the index of the rank in given axis. Example: 27 processes managed in a 3-Dimensinal mesh with shape of [3, 3, 3]. @@ -297,20 +318,20 @@ def _coordinate2linear_idx(mesh_shape, coordinate): """ convert a coordinate in multidimensional mesh space into a scala idx in linear space. - it use Row-major order for dimension conversion. + it use Row-major order for dimension conversion. 
so it has: [most_significant_dim, ..., least_significant_dim] - assume: + assume: the size of i-th dimension to be: S[i] the index of j-th dimension is: I[j] - linear_idx of a n dimensional coordinate is: + linear_idx of a n dimensional coordinate is: I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) + - I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) + - I[n-3] * ( S[n-4] * .... S[0]) + + I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) + + I[n-3] * ( S[n-4] * .... S[0]) + ... - I[1] * ( S[0]) + + I[1] * ( S[0]) + I[0] """ @@ -325,14 +346,19 @@ def _coordinate2linear_idx(mesh_shape, coordinate): assert len(mesh_shape) == len( coordinate ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format( - mesh_shape, coordinate) + mesh_shape, coordinate + ) for i in range(len(mesh_shape)): - assert coordinate[ - i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format( - i, coordinate) - assert coordinate[i] < mesh_shape[ - i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format( - i, mesh_shape, coordinate) + assert ( + coordinate[i] >= 0 + ), "index in dimension [{}] is least than zero. coordinate: {}".format( + i, coordinate + ) + assert ( + coordinate[i] < mesh_shape[i] + ), "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format( + i, mesh_shape, coordinate + ) base = mesh_shape[-1] linear_idx = coordinate[-1] @@ -350,7 +376,7 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): mapping a linear scala into multidimensional mesh space, return it coordinate in that space. it is the inverse function of _coordinate2linear_idx. - assume: + assume: the size of i-th dimension to be: S[i] the index of j-th dimension is: I[j] @@ -365,11 +391,13 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): """ assert linear_idx >= 0, "linear index [{}] is least than zero".format( - linear_idx) + linear_idx + ) assert linear_idx < np.prod( mesh_shape ), "linear index beyond the extent of mesh shape. 
shape: {}, linear index: {}".format( - mesh_shape, linear_idx) + mesh_shape, linear_idx + ) base = 1 coordinate = [-1] * len(mesh_shape) @@ -392,15 +420,17 @@ def _get_corresponding_rank(dist_context, target_mesh, rank): coordinate = None for mesh in dist_context.process_meshes: if rank in mesh.processes and mesh.topology == target_mesh.topology: - coordinate = _linear_idx2coordinate(mesh.topology, - mesh.processes.index(rank)) + coordinate = _linear_idx2coordinate( + mesh.topology, mesh.processes.index(rank) + ) break # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( # rank) if coordinate is not None: - return target_mesh.processes[_coordinate2linear_idx( - mesh.topology, coordinate)] + return target_mesh.processes[ + _coordinate2linear_idx(mesh.topology, coordinate) + ] else: return target_mesh.processes[0] @@ -412,7 +442,8 @@ def _get_unshard_dist_shape(var, dist_attr): assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( - var_shape, mapping) + var_shape, mapping + ) new_shape = [] for idx in range(len(var_shape)): if var_shape[idx] == -1 or mapping[idx] == -1: @@ -425,13 +456,15 @@ def _get_unshard_dist_shape(var, dist_attr): def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None): from .dist_context import get_default_distributed_context + if dist_context is None: dist_context = get_default_distributed_context() for var in dist_main_prog.list_vars(): if var.is_data: tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - var) + var + ) inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr) var.desc.set_shape(inverse_shape) dim_mapping = tensor_dist_attr.dims_mapping @@ -441,62 +474,76 @@ def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None): def _update_addition_info(addition_info): - """ Update default addition_info with inputs """ + """Update default addition_info with inputs""" add_info = {"epoch": 0, "batch": 0, "batch_size": 0} if not addition_info: return add_info elif not isinstance(addition_info, dict): - raise TypeError("The type of 'addition_info' should be 'dict', " - "but got '{}'.".format(str(type(addition_info)))) + raise TypeError( + "The type of 'addition_info' should be 'dict', " + "but got '{}'.".format(str(type(addition_info))) + ) else: for item, value in addition_info.items(): if item not in ["epoch", "batch", "batch_size"]: raise ValueError( "The key of 'addition_info' should be one of the " "['epoch', 'batch', 'batch_size'], but got '{}'.".format( - str(item))) + str(item) + ) + ) if not isinstance(value, int): raise ValueError( "The value of 'addition_info' should be 'int', " - "but got '{}'.".format(str(type(value)))) + "but got '{}'.".format(str(type(value))) + ) add_info[item] = value return add_info def _check_valid_path(file_path): - """ Validity check of input file path """ + """Validity check of input file path""" if not file_path: return file_path elif isinstance(file_path, list): for file in file_path: if not isinstance(file, str): - raise TypeError("The type of file path should be 'str', " - "but got '{}'.".format(str(type(file)))) + raise TypeError( + "The type of file path should be 'str', " + "but got '{}'.".format(str(type(file))) + ) if not os.path.exists(file): raise ValueError( - "The file path '{}' does not exist.".format(file)) + "The file path '{}' does not exist.".format(file) + ) return file_path else: - raise TypeError("The type of file path should be 'list', " - "but got 
'{}'.".format(str(type(file_path)))) + raise TypeError( + "The type of file path should be 'list', " + "but got '{}'.".format(str(type(file_path))) + ) def _check_param_dict(param_dict): if not param_dict: raise ValueError("'param_dict' cannot be None.") elif not isinstance(param_dict, dict): - raise TypeError("The type of 'param_dict' should be 'dict', " - "but got '{}'.".format(str(type(param_dict)))) + raise TypeError( + "The type of 'param_dict' should be 'dict', " + "but got '{}'.".format(str(type(param_dict))) + ) else: for name, value in param_dict.items(): if not isinstance(name, str): raise TypeError( "The type of key of 'param_dict' should be 'str', " - "but got '{}'.".format(str(type(name)))) + "but got '{}'.".format(str(type(name))) + ) if not isinstance(value, paddle.fluid.LoDTensor): raise TypeError( "The type of value of 'param_dict' should be 'LoDTensor', " - "but got '{}'.".format(str(type(value)))) + "but got '{}'.".format(str(type(value))) + ) return param_dict @@ -504,35 +551,42 @@ def _check_dist_attr(dist_attr): if not dist_attr: return dist_attr elif not isinstance(dist_attr, dict): - raise TypeError("The type of 'dist_attr' should be 'dict', " - "but got '{}'.".format(str(type(dist_attr)))) + raise TypeError( + "The type of 'dist_attr' should be 'dict', " + "but got '{}'.".format(str(type(dist_attr))) + ) else: for name, value in dist_attr.items(): if not isinstance(name, str): raise TypeError( "The type of param name of 'dist_attr' should be 'str', " - "but got '{}'.".format(str(type(name)))) + "but got '{}'.".format(str(type(name))) + ) if not isinstance(value, dict): raise TypeError( "The type of distributed attribute should be 'dict', " - "but got '{}'".format(str(type(value)))) + "but got '{}'".format(str(type(value))) + ) attr = ['process_shape', 'process_group', 'dims_mapping'] if list(value.keys()) != attr: raise ValueError( "The key of distributed attribute should be " "'['process_shape', 'process_group', 'dims_mapping']', " - "but got {}.".format(str(value.keys()))) + "but got {}.".format(str(value.keys())) + ) return dist_attr -def save_distributed_checkpoint(program, - checkpoint_path, - dist_attr_path, - addition_info=None, - is_integrated=False, - dist_context=None): - """ - Save model parameter state, optimzer state, distributed attribute and +def save_distributed_checkpoint( + program, + checkpoint_path, + dist_attr_path, + addition_info=None, + is_integrated=False, + dist_context=None, +): + """ + Save model parameter state, optimzer state, distributed attribute and additional information of each rank. Args: @@ -569,11 +623,12 @@ def save_distributed_checkpoint(program, else: # TODO: integrate param before save raise NotImplementedError( - "Integrating parameter has not been implemented.") + "Integrating parameter has not been implemented." + ) def load_distributed_checkpoint(checkpoint_path, dist_attr_path): - """ + """ Load parameter, optimizer, distributed attribute and addition_info. Args: @@ -583,7 +638,7 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path): Returns: param_dict(dict): parameters' value of all ranks. dist_attr(dict): parameters' distributed attribute. - addition_info(dict): additional information user saved in last training. + addition_info(dict): additional information user saved in last training. Notes: The return, 'addition_info', is belonging to the first file of checkpoint_path by default. @@ -591,16 +646,16 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path): Examples: .. 
code-block:: python - ckpt_path = ['./model_state_rank0.pdmodel', + ckpt_path = ['./model_state_rank0.pdmodel', './model_state_rank1.pdmodel'] - dist_attr_path = ['./dist_attr_rank0.pdattr', + dist_attr_path = ['./dist_attr_rank0.pdattr', './dist_attr_rank1.pdattr'] param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path) """ - assert _check_valid_path(checkpoint_path), \ - "'checkpoint_path' cannot be None." - assert _check_valid_path(dist_attr_path), \ - "'dist_attr_path' cannot be None." + assert _check_valid_path( + checkpoint_path + ), "'checkpoint_path' cannot be None." + assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None." state_dict_info = _load_distributed_state_dict(checkpoint_path) dist_attr = _load_distributed_attribute(dist_attr_path) @@ -609,11 +664,10 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path): return param_dict, dist_attr, addition_info -def load_checkpoint_into_program(checkpoint_path, - dist_attr_path, - program, - dist_context=None): - """ +def load_checkpoint_into_program( + checkpoint_path, dist_attr_path, program, dist_context=None +): + """ Load parameter, optimizer, distributed attribute and addition_info into model. Args: @@ -624,7 +678,7 @@ def load_checkpoint_into_program(checkpoint_path, Returns: addition_info(dict): user saved in last train. - + Notes: The return, 'addition_info', is belonging to the first file of checkpoint_path by default. @@ -632,19 +686,19 @@ def load_checkpoint_into_program(checkpoint_path, .. code-block:: python exe.run(startup_program) - ckpt_path = ['./model_state_rank0.pdmodel', + ckpt_path = ['./model_state_rank0.pdmodel', './model_state_rank1.pdmodel'] - dist_attr_path = ['./dist_attr_rank0.pdattr', + dist_attr_path = ['./dist_attr_rank0.pdattr', './dist_attr_rank1.pdattr'] load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program) """ from .dist_context import get_default_distributed_context assert isinstance(program, paddle.fluid.framework.Program) - assert _check_valid_path(checkpoint_path), \ - "'checkpoint_path' cannot be None." - assert _check_valid_path(dist_attr_path), \ - "'dist_attr_path' cannot be None." + assert _check_valid_path( + checkpoint_path + ), "'checkpoint_path' cannot be None." + assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None." if dist_context is None: dist_context = get_default_distributed_context() all_state_dict_info = _load_distributed_state_dict(checkpoint_path) @@ -652,16 +706,16 @@ def load_checkpoint_into_program(checkpoint_path, all_cur_dist_attr = get_dist_attr(program, dist_context) all_param_dict = all_state_dict_info["model"] addition_info = all_state_dict_info["addition_info"] - sliced_param_dict = merge_and_slice_parameter(all_param_dict, - all_pre_dist_attr, - all_cur_dist_attr) + sliced_param_dict = merge_and_slice_parameter( + all_param_dict, all_pre_dist_attr, all_cur_dist_attr + ) load_parameter_into_program(sliced_param_dict, program) return addition_info def load_parameter_into_program(param_dict, program): - """ + """ Load parameters into program. 
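Taken together, the checkpoint helpers in this hunk form a save/restore round trip; a rough sketch (the directory layout and two-rank file lists are illustrative, and `dist_main_prog` stands for the rank's distributed main program):

.. code-block:: python

    import os
    from paddle.distributed.auto_parallel.utils import (
        save_distributed_checkpoint,
        load_checkpoint_into_program,
    )

    # dist_main_prog: the rank's distributed main program (placeholder here).
    # Each rank writes model_state_rank{rank}.pdmodel and
    # dist_attr_rank{rank}.pdattr into the given directories.
    ckpt_dir = "./output/step_20"
    os.makedirs(ckpt_dir, exist_ok=True)
    save_distributed_checkpoint(
        dist_main_prog, ckpt_dir, ckpt_dir,
        addition_info={"epoch": 1, "batch": 20, "batch_size": 8})

    # Later, possibly under a different parallel strategy, collect every
    # rank's files and load them back into the current rank's program.
    ckpt_files = [os.path.join(ckpt_dir, "model_state_rank{}.pdmodel".format(r))
                  for r in range(2)]
    attr_files = [os.path.join(ckpt_dir, "dist_attr_rank{}.pdattr".format(r))
                  for r in range(2)]
    add_info = load_checkpoint_into_program(ckpt_files, attr_files, dist_main_prog)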
Args: @@ -676,28 +730,31 @@ def load_parameter_into_program(param_dict, program): def _save_distributed_attribute(program, dist_attr_path, dist_context): - """ Save distributed attribute of all parameters """ + """Save distributed attribute of all parameters""" # TODO: just save a complete distributed attribute file rank_id = paddle.distributed.get_rank() - dist_attr_name = os.path.join(dist_attr_path, - "dist_attr_rank{}.pdattr".format(rank_id)) + dist_attr_name = os.path.join( + dist_attr_path, "dist_attr_rank{}.pdattr".format(rank_id) + ) dist_attr_dict = { "model": get_dist_attr(program, dist_context), - "world_size": paddle.distributed.get_world_size() + "world_size": paddle.distributed.get_world_size(), } paddle.save(dist_attr_dict, dist_attr_name) logging.info( - "Already saved distributed attribute to '{}'.".format(dist_attr_path)) + "Already saved distributed attribute to '{}'.".format(dist_attr_path) + ) def _load_distributed_attribute(dist_attr_path): - """ Load parameters' distributed attribute from dist_attr_path """ + """Load parameters' distributed attribute from dist_attr_path""" total_dist_attr = {} for dist_attr_file in dist_attr_path: dist_attr = paddle.load(dist_attr_file) pre_world_size = dist_attr["world_size"] - assert pre_world_size == len(dist_attr_path), \ - "The number of 'dist_attr_path' must be equal to the last training world size." + assert pre_world_size == len( + dist_attr_path + ), "The number of 'dist_attr_path' must be equal to the last training world size." for name, attr in dist_attr["model"].items(): if name not in total_dist_attr: total_dist_attr[name] = attr @@ -706,27 +763,29 @@ def _load_distributed_attribute(dist_attr_path): def _save_distributed_state_dict(program, addition_info, checkpoint_path): - """ Save parameters' state_dict """ + """Save parameters' state_dict""" rank = paddle.distributed.get_rank() - ckpt_file_name = os.path.join(checkpoint_path, - "model_state_rank{}.pdmodel".format(rank)) + ckpt_file_name = os.path.join( + checkpoint_path, "model_state_rank{}.pdmodel".format(rank) + ) state_dict = { "model": program.state_dict(), "world_size": paddle.distributed.get_world_size(), - "addition_info": addition_info + "addition_info": addition_info, } paddle.save(state_dict, ckpt_file_name) logging.info("Already saved model to '{}'.".format(checkpoint_path)) def _load_distributed_state_dict(checkpoint_path): - """ Load parameters' state_dict from checkpoint_path """ + """Load parameters' state_dict from checkpoint_path""" all_state_dict = {} for idx, ckpt_file in enumerate(checkpoint_path): state_dict_info = paddle.load(ckpt_file, return_numpy=True) pre_world_size = state_dict_info["world_size"] - assert pre_world_size == len(checkpoint_path), \ - "The number of 'checkpoint_path' must be equal to the last training world size." + assert pre_world_size == len( + checkpoint_path + ), "The number of 'checkpoint_path' must be equal to the last training world size." if idx == 0: addition_info = state_dict_info["addition_info"] for name, value in state_dict_info["model"].items(): @@ -737,13 +796,13 @@ def _load_distributed_state_dict(checkpoint_path): all_state_dict_info = { "model": all_state_dict, - "addition_info": addition_info + "addition_info": addition_info, } return all_state_dict_info def get_dist_attr(program, dist_context=None): - """ + """ Get distributed attribute of current rank. 
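For reference, the mapping assembled by `get_dist_attr` (and validated by `_check_dist_attr` above) keys each parameter or optimizer variable name to its mesh layout; an illustrative entry with made-up values:

.. code-block:: python

    dist_attr = {
        "linear_0.w_0": {                   # hypothetical parameter name
            "process_shape": [2, 2],        # topology of the process mesh
            "process_group": [0, 1, 2, 3],  # ranks that make up the mesh
            "dims_mapping": [0, -1],        # dim 0 sharded on mesh axis 0,
                                            # dim 1 replicated
        },
    }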
Args: @@ -758,13 +817,14 @@ def get_dist_attr(program, dist_context=None): for var in program.list_vars(): if is_parameter(var) or is_belong_to_optimizer(var): tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - var) + var + ) process_mesh = tensor_dist_attr.process_mesh dims_mapping = tensor_dist_attr.dims_mapping dist_attr[var.name] = { "process_shape": process_mesh.topology, "process_group": process_mesh.processes, - "dims_mapping": dims_mapping + "dims_mapping": dims_mapping, } return dist_attr @@ -782,19 +842,26 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): dist_param_dict(dict): parameters' value of current rank. """ assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None." - assert isinstance(dist_param_dict, dict), \ - "The type of 'dist_param_dict' should be 'dict', but got {}.".format( - str(type(dist_param_dict))) + assert isinstance( + dist_param_dict, dict + ), "The type of 'dist_param_dict' should be 'dict', but got {}.".format( + str(type(dist_param_dict)) + ) for name, value in dist_param_dict.items(): if not isinstance(name, str): - raise TypeError("The key of 'dist_param_dict' is parameter's name, " - "and its type should be 'str', but got {}.".format( - str(type(name)))) + raise TypeError( + "The key of 'dist_param_dict' is parameter's name, " + "and its type should be 'str', but got {}.".format( + str(type(name)) + ) + ) if not isinstance(value, list) or not all( - isinstance(v, np.ndarray) for v in value): + isinstance(v, np.ndarray) for v in value + ): raise TypeError( "The value of 'dist_param_dict' is parameter's value of all ranks, " - "and its type should be 'list(numpy.ndarray)'.") + "and its type should be 'list(numpy.ndarray)'." + ) if cur_dist_attr is None: return {} @@ -822,7 +889,8 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): cur_dims_mapping = cur_attr["dims_mapping"] if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping: complete_param = _merge_parameter_with_dist_attr( - pre_param, pre_attr) + pre_param, pre_attr + ) dist_param_dict[var_name] = complete_param else: complete_param = pre_param[0] @@ -830,7 +898,8 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping: sliced_param = _slice_parameter_with_dist_attr( - complete_param, cur_attr) + complete_param, cur_attr + ) dist_param_dict[var_name] = sliced_param for var_name in pre_dist_attr: @@ -841,67 +910,81 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): if param_not_in_pre: warnings.warn( "Parameters '{}' are not found in last training process.".format( - str(param_not_in_pre))) + str(param_not_in_pre) + ) + ) if param_not_in_cur: warnings.warn( "Parameters '{}' are not found in current training process.".format( - str(param_not_in_cur))) + str(param_not_in_cur) + ) + ) return dist_param_dict def _merge_parameter_with_dist_attr(param_list, dist_attr): - """ Merge parameter with distributed attribute """ + """Merge parameter with distributed attribute""" from .reshard import Resharder dims_mapping = dist_attr["dims_mapping"] process_shape = dist_attr["process_shape"] process_group = dist_attr["process_group"] # get the complete shape of the parameter - complete_shape = Resharder.compute_complete_shape(param_list[0].shape, - process_shape, - dims_mapping) + complete_shape = Resharder.compute_complete_shape( + param_list[0].shape, process_shape, dims_mapping + ) # merge 
the parameter with dist_attr partition_param_list = [] merged_partiton = [] for process in process_group: partition_index = Resharder.compute_partition_index( - process, complete_shape, dims_mapping, process_shape, process_group) + process, complete_shape, dims_mapping, process_shape, process_group + ) index = process_group.index(process) if partition_index not in merged_partiton: merged_partiton.append(partition_index) - _merge_parameter(partition_param_list, param_list[index], - partition_index, complete_shape) + _merge_parameter( + partition_param_list, + param_list[index], + partition_index, + complete_shape, + ) - assert len(partition_param_list) == 1 or not partition_param_list, \ - "Fail to merge parameter" + assert ( + len(partition_param_list) == 1 or not partition_param_list + ), "Fail to merge parameter" complete_param = partition_param_list[0][0] return complete_param def _slice_parameter_with_dist_attr(param, dist_attr): - """ Slice parameter with distributed attribute """ - param = np.array(param) if isinstance(param, - paddle.fluid.LoDTensor) else param + """Slice parameter with distributed attribute""" + param = ( + np.array(param) if isinstance(param, paddle.fluid.LoDTensor) else param + ) dims_mapping = dist_attr["dims_mapping"] process_shape = dist_attr["process_shape"] process_group = dist_attr["process_group"] # slice the parameter with dist_attr - partition_index_list = _get_split_indices(param.shape, dims_mapping, - process_shape, process_group) - sliced_param_list = _slice_parameter(param, partition_index_list, - len(partition_index_list)) + partition_index_list = _get_split_indices( + param.shape, dims_mapping, process_shape, process_group + ) + sliced_param_list = _slice_parameter( + param, partition_index_list, len(partition_index_list) + ) # get the current parameter's index in sliced_param_list rank_id = paddle.distributed.get_rank() - sliced_param_index = _get_sliced_param_index(rank_id, param.shape, - dims_mapping, process_shape, - process_group) + sliced_param_index = _get_sliced_param_index( + rank_id, param.shape, dims_mapping, process_shape, process_group + ) sliced_param = sliced_param_list[sliced_param_index] return sliced_param -def _merge_parameter(partition_param_list, param, partition_index, - complete_shape): +def _merge_parameter( + partition_param_list, param, partition_index, complete_shape +): """ Merge partitial parameters to a complete one. 
@@ -935,19 +1018,30 @@ def _merge_parameter(partition_param_list, param, partition_index, else: i = 0 while i < len(partition_param_list): - concat_axis, first_order, new_partition = Resharder.compute_concat_info( - partition_param_list[i][1], partition_index) + ( + concat_axis, + first_order, + new_partition, + ) = Resharder.compute_concat_info( + partition_param_list[i][1], partition_index + ) if concat_axis != -1: if first_order == 0: new_param = np.concatenate( - (partition_param_list[i][0], param), axis=concat_axis) + (partition_param_list[i][0], param), axis=concat_axis + ) else: new_param = np.concatenate( - (param, partition_param_list[i][0]), axis=concat_axis) + (param, partition_param_list[i][0]), axis=concat_axis + ) partition_param_list.pop(i) - _merge_parameter(partition_param_list, new_param, new_partition, - complete_shape) + _merge_parameter( + partition_param_list, + new_param, + new_partition, + complete_shape, + ) break i += 1 @@ -975,19 +1069,21 @@ def _slice_parameter(complete_param, partition_index_list, length): """ sliced_param_list = [] axis = len(complete_param.shape) - length - sliced_param = np.split(complete_param, - partition_index_list[axis], - axis=axis) + sliced_param = np.split( + complete_param, partition_index_list[axis], axis=axis + ) if length == 1: return sliced_param for param in sliced_param: sliced_param_list.extend( - _slice_parameter(param, partition_index_list, length - 1)) + _slice_parameter(param, partition_index_list, length - 1) + ) return sliced_param_list -def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, - process_group): +def _get_sliced_param_index( + rank, complete_shape, dims_mapping, process_shape, process_group +): """ Get sliced_param's index of current rank in all sliced parameters list. @@ -1006,7 +1102,7 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, process_group = [0, 1, 2] slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3) - # slice_param: + # slice_param: # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] index = _get_sliced_param_index(rank, complete_shape, dims_mapping @@ -1015,10 +1111,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, """ from .reshard import Resharder - partition_index = Resharder.compute_partition_index(rank, complete_shape, - dims_mapping, - process_shape, - process_group) + partition_index = Resharder.compute_partition_index( + rank, complete_shape, dims_mapping, process_shape, process_group + ) sliced_param_index = 0 for i, shape in enumerate(complete_shape): if dims_mapping[i] == -1: @@ -1033,8 +1128,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, return sliced_param_index -def _get_split_indices(complete_shape, dims_mapping, process_shape, - process_group): +def _get_split_indices( + complete_shape, dims_mapping, process_shape, process_group +): """ Get split indices of every dimension. 
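The parameter slicing and merging helpers in this section ultimately reduce to `numpy.split` and `numpy.concatenate` along the sharded axes; a tiny round-trip sketch of that idea (not the Resharder index arithmetic itself):

.. code-block:: python

    import numpy as np

    complete = np.arange(16).reshape(4, 4)

    # Slice along axis 0 at index 2, mimicking a parameter sharded over two ranks.
    rank0_part, rank1_part = np.split(complete, [2], axis=0)

    # Concatenating the per-rank pieces in order restores the complete parameter.
    merged = np.concatenate((rank0_part, rank1_part), axis=0)
    assert (merged == complete).all()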
@@ -1059,15 +1155,20 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape, split_indices_list = [] for process in process_group: partition_index = Resharder.compute_partition_index( - process, complete_shape, dims_mapping, process_shape, process_group) + process, complete_shape, dims_mapping, process_shape, process_group + ) if split_indices_list: for dim in range(len(partition_index)): split_indices_list[dim].extend(partition_index[dim]) else: split_indices_list = partition_index split_indices_list = list( - map(lambda x, y: list(set(x) - set([y]) - set([0])), split_indices_list, - complete_shape)) + map( + lambda x, y: list(set(x) - set([y]) - set([0])), + split_indices_list, + complete_shape, + ) + ) split_indices_list = [sorted(x) for x in split_indices_list] return split_indices_list @@ -1086,8 +1187,10 @@ def set_grad_var_shape(program, dist_context): if int(op.attr('op_role')) != int(OpRole.Backward): continue - if int(block.ops[idx-1].attr('op_role')) == int(OpRole.Forward) or \ - int(block.ops[idx-1].attr('op_role')) == 257: + if ( + int(block.ops[idx - 1].attr('op_role')) == int(OpRole.Forward) + or int(block.ops[idx - 1].attr('op_role')) == 257 + ): appended_grad_times += 1 if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: @@ -1105,61 +1208,102 @@ def set_grad_var_shape(program, dist_context): continue if var_name in grad_var_to_var[appended_grad_times]: forward_var_name = grad_var_to_var[appended_grad_times][ - var_name] + var_name + ] else: - forward_var_name = var_name[:var_name.find("@GRAD")] + forward_var_name = var_name[: var_name.find("@GRAD")] if op.type in [ - "c_allreduce_sum", "c_identity", "scale", "cast", - "fill_any_like" + "c_allreduce_sum", + "c_identity", + "scale", + "cast", + "fill_any_like", ]: forward_var_name = op.input_arg_names[0] - elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": + elif ( + op.type == "matmul_v2_grad" + or op.type == "matmul_grad" + or op.type == "mul_grad" + ): forward_var_name = None for output_name in op.output_names: if var_name in op.output(output_name): assert "@GRAD" in output_name - input_name = output_name[:output_name.find("@GRAD")] + input_name = output_name[: output_name.find("@GRAD")] assert len(op.input(input_name)) == 1 forward_var_name = op.input(input_name)[0] assert forward_var_name is not None need_set_shape_list = [ - "reshape2_grad", "softmax_with_cross_entropy_grad", - "transpose2_grad", "softmax_grad", "cross_entropy_grad2", - "dropout_grad", "tanh_grad", "slice", "assign", - "matmul_v2_triple_grad", "elementwise_add_triple_grad", - "fill_constant", "sqrt_grad", + "reshape2_grad", + "softmax_with_cross_entropy_grad", + "transpose2_grad", + "softmax_grad", + "cross_entropy_grad2", + "dropout_grad", + "tanh_grad", + "slice", + "assign", + "matmul_v2_triple_grad", + "elementwise_add_triple_grad", + "fill_constant", + "sqrt_grad", "fused_softmax_mask_upper_triangle_grad", - "flatten_contiguous_range_grad", "relu_grad" + "flatten_contiguous_range_grad", + "relu_grad", ] forward_list = [ - "reshape2", "softmax_with_cross_entropy", "transpose2", - "softmax", "cross_entropy2", "dropout", "tanh", - ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", - "elementwise_add_grad_grad", "shape", "sqrt", - "fused_softmax_mask_upper_triangle", "flatten_contiguous_range", - "relu" + "reshape2", + "softmax_with_cross_entropy", + "transpose2", + "softmax", + "cross_entropy2", + "dropout", + "tanh", + ["slice_grad", "c_allgather"], + "assign", + 
"matmul_v2_grad_grad", + "elementwise_add_grad_grad", + "shape", + "sqrt", + "fused_softmax_mask_upper_triangle", + "flatten_contiguous_range", + "relu", ] if op.type in need_set_shape_list: for forward_op in block.ops: idx = need_set_shape_list.index(op.type) forward_op_name = forward_list[idx] - if forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names: - op_dist_attr = dist_context.get_op_dist_attr_for_program( - forward_op) + if ( + forward_op.type in forward_op_name + and forward_var_name in forward_op.input_arg_names + ): + op_dist_attr = ( + dist_context.get_op_dist_attr_for_program( + forward_op + ) + ) break forward_input_dist_attr = op_dist_attr.get_input_dist_attr( - forward_var_name) - assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}" + forward_var_name + ) + assert ( + forward_input_dist_attr is not None + ), f"{forward_var_name, str(op)}" forward_var = vars[forward_var_name] - forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program( - forward_var) + forward_var_dist_attr = ( + dist_context.get_tensor_dist_attr_for_program(forward_var) + ) assert forward_var_dist_attr is not None grad_var = vars[var_name] - ref_shape = infer_shape(block, forward_var, forward_var_dist_attr, - forward_input_dist_attr) + ref_shape = infer_shape( + block, + forward_var, + forward_var_dist_attr, + forward_input_dist_attr, + ) if list(grad_var.shape) != ref_shape: grad_var.desc.set_shape(ref_shape) @@ -1171,28 +1315,33 @@ OpRole = core.op_proto_and_checker_maker.OpRole def is_forward_op(op): op_role = int(op.attr('op_role')) - return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward) - or op_role == int(OpRole.Loss)) + return OP_ROLE_KEY in op.attr_names and ( + op_role == int(OpRole.Forward) or op_role == int(OpRole.Loss) + ) def is_backward_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward) + return OP_ROLE_KEY in op.attr_names and int( + op.all_attrs()[OP_ROLE_KEY] + ) & int(OpRole.Backward) def is_optimize_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) + return OP_ROLE_KEY in op.attr_names and int( + op.all_attrs()[OP_ROLE_KEY] + ) & int(OpRole.Optimize) def is_lr_sched_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched) + return OP_ROLE_KEY in op.attr_names and int( + op.all_attrs()[OP_ROLE_KEY] + ) & int(OpRole.Optimize.LRSched) def is_loss_op(op): - return OP_ROLE_KEY in op.attr_names and \ - int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss)) + return OP_ROLE_KEY in op.attr_names and int( + op.all_attrs()[OP_ROLE_KEY] + ) == (int(OpRole.Forward) | int(OpRole.Loss)) def is_loss_grad_op(op): @@ -1203,8 +1352,9 @@ def is_loss_grad_op(op): def is_gradient_clip_op(op): - return op.desc.has_attr("op_namescope") \ - and op.desc.attr("op_namescope").startswith("/gradient_clip") + return op.desc.has_attr("op_namescope") and op.desc.attr( + "op_namescope" + ).startswith("/gradient_clip") def is_prim_op(op): @@ -1215,8 +1365,9 @@ def get_loss_op(block): loss_ops = [] for op in block.ops: if is_loss_op(op): - assert len(op.desc.output_arg_names() - ) == 1, "loss op should only output loss var" + assert ( + len(op.desc.output_arg_names()) == 1 + ), "loss op should only output loss var" loss_ops.append(op) assert len(loss_ops) == 1, "num of loss op is not equal to one" @@ -1236,7 +1387,8 @@ def 
set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - new_op, process_mesh, ref_mapping, ctx): + new_op, process_mesh, ref_mapping, ctx +): assert process_mesh is not None assert ref_mapping is not None @@ -1270,9 +1422,11 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op): dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if len(dims_mapping) > 1: for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) + assert ( + mapping == -1 + ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format( + op_desc.type(), idx, mapping + ) batch_dim_mappings.append(dims_mapping[0]) for arg_name in op_desc.output_arg_names(): serial_tensor = dist_op.get_serial_output(arg_name) @@ -1282,23 +1436,31 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op): if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for idx, mapping in enumerate(dims_mapping[1:]): - assert mapping == -1, \ - "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) + assert ( + mapping == -1 + ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format( + op_desc.type(), idx, mapping + ) batch_dim_mappings.append(dims_mapping[0]) else: - assert dims_mapping[0] == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ - .format(op_desc.type(), mapping) + assert ( + dims_mapping[0] == -1 + ), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part.".format( + op_desc.type(), mapping + ) if len(dims_mapping) > 2: for idx, mapping in enumerate(dims_mapping[2:]): - assert mapping == -1, \ - "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ - .format(op_desc.type(), idx, mapping) + assert ( + mapping == -1 + ), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part.".format( + op_desc.type(), idx, mapping + ) batch_dim_mappings.append(dims_mapping[1]) compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) - assert compatible_dim_mapping is not None, "There is no compatible dim mapping." + assert ( + compatible_dim_mapping is not None + ), "There is no compatible dim mapping." 
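The default distributed implementation above only reconciles the batch dimension: it collects dim 0's mapping from every input and output and asks `compute_compatible_dim_mapping` for one value they can all share. The sketch below captures the contract that code appears to rely on; it is an assumption about that helper, not its actual body: -1 (replicated) is compatible with anything, while two different mesh axes cannot be reconciled.

```python
def compatible_dim_mapping(dim_mappings):
    """Return the single mesh axis every mapping can agree on, or None.

    Assumed convention: -1 means replicated and is compatible with any
    value; two different non-negative mesh axes are incompatible, which is
    what would trigger "There is no compatible dim mapping." above.
    """
    compatible = -1
    for mapping in dim_mappings:
        if mapping == -1:
            continue
        if compatible == -1:
            compatible = mapping
        elif compatible != mapping:
            return None
    return compatible


assert compatible_dim_mapping([-1, 0, -1, 0]) == 0   # agree on mesh axis 0
assert compatible_dim_mapping([-1, -1]) == -1        # everything replicated
assert compatible_dim_mapping([0, 1]) is None        # conflicting shardings
```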
for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: @@ -1344,8 +1506,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i + new_idx = ( + max_dims_mapping_len - input_dims_mapping_lens[arg_name] + ) + i new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] dims_mapping_list.append(new_dims_mapping) else: @@ -1357,7 +1520,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): dims_mapping_list.append(dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) - assert compatible_dims_mapping is not None, "There is no compatible dim mapping." + assert ( + compatible_dims_mapping is not None + ), "There is no compatible dim mapping." for arg_name in input_arg_names: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: @@ -1365,55 +1530,64 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): -1 for _ in range(input_dims_mapping_lens[arg_name]) ] for i in range(input_dims_mapping_lens[arg_name]): - new_idx = (max_dims_mapping_len - - input_dims_mapping_lens[arg_name]) + i + new_idx = ( + max_dims_mapping_len - input_dims_mapping_lens[arg_name] + ) + i new_dims_mapping[i] = compatible_dims_mapping[new_idx] if new_dims_mapping != input_dims_mapping_dict[arg_name]: op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) changed = True else: if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: - op_dist_attr.set_input_dims_mapping(arg_name, - compatible_dims_mapping) + op_dist_attr.set_input_dims_mapping( + arg_name, compatible_dims_mapping + ) changed = True for arg_name in output_arg_names: dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if compatible_dims_mapping != dims_mapping: - op_dist_attr.set_output_dims_mapping(arg_name, - compatible_dims_mapping) + op_dist_attr.set_output_dims_mapping( + arg_name, compatible_dims_mapping + ) changed = True return changed -def get_all_distributed_main_program(serial_program_info, dist_context, - parallelizer): +def get_all_distributed_main_program( + serial_program_info, dist_context, parallelizer +): "Get all distributed main programs by dist_context." 
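For elementwise-like ops the pass right-aligns every input's dims_mapping before reconciling them, padding shorter mappings with -1 on the left exactly as the `new_idx = (max_dims_mapping_len - input_dims_mapping_lens[arg_name]) + i` shift above does (the same alignment rule NumPy broadcasting uses). A tiny sketch of just that padding step; the helper name is made up:

```python
def right_align_dims_mapping(dims_mapping, max_len):
    """Pad a dims_mapping with -1 on the left so trailing dims line up."""
    padded = [-1] * max_len
    offset = max_len - len(dims_mapping)
    for i, mapping in enumerate(dims_mapping):
        padded[offset + i] = mapping
    return padded


# a rank-2 input aligned against a rank-4 peer: the sharded axis keeps its
# mesh dimension, the newly introduced leading dims stay replicated
assert right_align_dims_mapping([0, -1], 4) == [-1, -1, 0, -1]
```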
from .dist_context import DistributedOperatorContext, DistributedContext + cluster = serial_program_info.cluster copied_parallelizer = copy.deepcopy(parallelizer) all_dist_main_program = [] - ranks = paddle.distributed.get_world_size() if cluster is None else len( - cluster.get_all_devices("GPU")) + ranks = ( + paddle.distributed.get_world_size() + if cluster is None + else len(cluster.get_all_devices("GPU")) + ) for rank_id in range(ranks): used_dist_context = copy.deepcopy(dist_context) used_dist_context._dist_op_context = DistributedOperatorContext() - _, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program( - rank_id, used_dist_context) + ( + _, + _, + dist_startup_program, + dist_main_program, + _, + ) = copied_parallelizer._get_dist_program(rank_id, used_dist_context) all_dist_main_program.append(dist_main_program) return all_dist_main_program class SerialProgramInfo: - - def __init__(self, - train_program, - satrtup_program, - loss, - optimizer, - cluster=None): + def __init__( + self, train_program, satrtup_program, loss, optimizer, cluster=None + ): self._train_program = train_program self._startup_program = satrtup_program self._loss = loss @@ -1442,7 +1616,6 @@ class SerialProgramInfo: def get_standalone_cost_data(distributed_programs): - def _compute_runtime(op_cost, op, vars): runtime = 0 try: @@ -1455,32 +1628,47 @@ def get_standalone_cost_data(distributed_programs): parsed_info = op_config.split("\n") variable = "(Variable)" for info in parsed_info: - variable = "(Variable)" if "(Variable)" in info else "(list" + variable = ( + "(Variable)" if "(Variable)" in info else "(list" + ) if variable in info: - arg_name_lower = info[:info.find(variable) - 1] + arg_name_lower = info[: info.find(variable) - 1] shape_left_boundary = info.find("[") shape_right_boundary = info.find("]") - assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed." - shape = info[shape_left_boundary + - 1:shape_right_boundary].split(",") + assert ( + shape_left_boundary > 0 + and shape_right_boundary > 0 + and shape_right_boundary > shape_left_boundary + ), "Get shape failed." + shape = info[ + shape_left_boundary + 1 : shape_right_boundary + ].split(",") shape = list(map(lambda x: int(x.strip()), shape)) dtype_factor = 1 total_static_input_size += reduce(lambda x, y: x * y, shape) if op.type == "c_embedding": - arg_name_lower = "w" if arg_name_lower == "weight" else "ids" + arg_name_lower = ( + "w" if arg_name_lower == "weight" else "ids" + ) for arg_name in op.input_names: if arg_name.lower() == arg_name_lower: for var_name in op.input(arg_name): var = vars[var_name] total_actual_input_size += reduce( - lambda x, y: x * y, var.shape) + lambda x, y: x * y, var.shape + ) break - assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed." + assert ( + total_static_input_size > 0 and total_actual_input_size > 0 + ), "Get input size failed." 
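`_compute_runtime` scales a profiled op time by the ratio of the actual input volume to the input volume recorded in the static cost data, which is what the shape parsing above is collecting. A standalone sketch of that scaling, with the shapes passed in directly instead of being parsed out of the cost model's config string:

```python
from functools import reduce


def scale_runtime(static_runtime, static_shapes, actual_shapes):
    """actual_runtime = total_actual_input_size / total_static_input_size * runtime"""
    numel = lambda shape: reduce(lambda x, y: x * y, shape)
    total_static = sum(numel(s) for s in static_shapes)
    total_actual = sum(numel(s) for s in actual_shapes)
    assert total_static > 0 and total_actual > 0, "Get input size failed."
    return total_actual / total_static * static_runtime


# profiled with a [1024, 1024] input, executed here with a [2048, 1024] input
assert scale_runtime(3.0, [[1024, 1024]], [[2048, 1024]]) == 6.0
```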
- actual_runtime = total_actual_input_size / total_static_input_size * runtime + actual_runtime = ( + total_actual_input_size / total_static_input_size * runtime + ) return actual_runtime import paddle.cost_model as cm + cost_model = cm.CostModel() cost_model.static_cost_data() DEFAULT_MULTIPLE = 2 @@ -1491,13 +1679,16 @@ def get_standalone_cost_data(distributed_programs): "reshape2": "reshape", "unsqueeze2": "unsqueeze", "reduce_sum": "sum", - "elementwise_div": "divide" + "elementwise_div": "divide", } standalone_cost_data = [] # skip ops not_enum_ops = [ - "create_py_reader", "create_double_buffer_reader", "read", "assign" + "create_py_reader", + "create_double_buffer_reader", + "read", + "assign", ] for distributed_program in distributed_programs: cost_data = {} @@ -1507,26 +1698,33 @@ def get_standalone_cost_data(distributed_programs): if op.type in not_enum_ops: cost_data[op.desc.id()] = runtime continue - dtype = str(vars[op.input_arg_names[0]].dtype - ) if op.input_arg_names else "float32" + dtype = ( + str(vars[op.input_arg_names[0]].dtype) + if op.input_arg_names + else "float32" + ) if int(op.attr('op_role')) == int(OpRole.Backward): if "_grad" in op.type: forward_op_name = op.type[:-5] if forward_op_name in OP_NAME_MAPPING.keys(): forward_op_name = OP_NAME_MAPPING[forward_op_name] - op_cost = cost_model.get_static_op_time(forward_op_name, - forward=False, - dtype=dtype) + op_cost = cost_model.get_static_op_time( + forward_op_name, forward=False, dtype=dtype + ) if op_cost: runtime = _compute_runtime(op_cost, op, vars) else: - op_cost = cost_model.get_static_op_time(forward_op_name, - dtype=dtype) + op_cost = cost_model.get_static_op_time( + forward_op_name, dtype=dtype + ) if op_cost: runtime = 2 * _compute_runtime(op_cost, op, vars) elif int(op.attr('op_role')) == int(OpRole.Forward): - op_name = OP_NAME_MAPPING[ - op.type] if op.type in OP_NAME_MAPPING.keys() else op.type + op_name = ( + OP_NAME_MAPPING[op.type] + if op.type in OP_NAME_MAPPING.keys() + else op.type + ) op_cost = cost_model.get_static_op_time(op_name) if op_cost: runtime = _compute_runtime(op_cost, op, vars) @@ -1565,7 +1763,8 @@ def to_list(value): def debug_program(program, path, name): filename = os.path.join( - path, name + '_program' + ".%d" % (paddle.distributed.get_rank())) + path, name + '_program' + ".%d" % (paddle.distributed.get_rank()) + ) with open(filename, 'w') as f: f.write(str(program)) @@ -1599,9 +1798,11 @@ def get_lr(optimizer): return optimizer._learning_rate() else: raise TypeError( - "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ - " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer)) + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" + " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format( + type(optimizer) ) + ) def initialize_pg_in_full_mode(all_process_groups, cur_rank): @@ -1631,21 +1832,28 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): if is_send: recv_rank = process_group.ranks[1] recv_rank_ip, recv_rank_port = genv.trainer_endpoints[ - recv_rank].split(":") + recv_rank + ].split(":") connect_port = int(recv_rank_port) + magic_num - client_socket = socket.socket(socket.AF_INET, - socket.SOCK_STREAM) + client_socket = socket.socket( + socket.AF_INET, socket.SOCK_STREAM + ) client_socket.connect((recv_rank_ip, connect_port)) client_socket.send(str(cur_rank).encode('utf-8')) rank = client_socket.recv(buff_size).decode('utf-8') rank = int(rank) if rank != recv_rank: raise ValueError( - "Please 
check comm pair, the recv rank should be {} but got {}." - .format(recv_rank, rank)) + "Please check comm pair, the recv rank should be {} but got {}.".format( + recv_rank, rank + ) + ) else: - print("It is able to instantiate {} as sender now.".format( - process_group.ranks)) + print( + "It is able to instantiate {} as sender now.".format( + process_group.ranks + ) + ) client_socket.close() else: send_rank = process_group.ranks[0] @@ -1657,10 +1865,45 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): has_recv_by_socket.append(rank) else: client_sockets[send_rank].send( - str(cur_rank).encode("utf-8")) + str(cur_rank).encode("utf-8") + ) client_sockets[send_rank].close() - print("It is able to instantiate {} as recver now.". - format(process_group.ranks)) + print( + "It is able to instantiate {} as recver now.".format( + process_group.ranks + ) + ) break process_group.instantiate() server_socket.close() + + +def get_input_split_info(cur_rank, var, dist_context): + # deduce how the input data is split among the cluster + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) + process_mesh = tensor_dist_attr.process_mesh + dims_mapping = tensor_dist_attr.dims_mapping + + if cur_rank not in process_mesh.processes: + rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank) + else: + rank_id = cur_rank + + batch_size_axis = dims_mapping[0] + if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1: + group_ranks = _get_comm_group( + process_mesh.processes, + process_mesh.topology, + batch_size_axis, + rank_id, + ) + return len(group_ranks), group_ranks.index(rank_id) + + return 1, 0 + + +def validate_opt(optimizer): + if optimizer is not None: + optimizer._parameter_list = None + optimizer._param_groups = None + return optimizer diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 0790b286ce53efd3c94ac01dbb6bf06428b80847..34684c6ca41800b45b99aa48ef8e716df0a6fd5e 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -20,12 +20,31 @@ from paddle.fluid.framework import default_main_program, default_startup_program from paddle.fluid import unique_name from .pass_base import register_pass from paddle.fluid.data_feeder import check_variable_and_dtype, check_type -from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping -from paddle.distributed.auto_parallel.process_group import get_world_process_group -from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists -from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str -from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute -from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op, OP_ROLE_KEY, OpRole +from paddle.distributed.auto_parallel.utils import ( + set_var_dist_attr, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, +) +from paddle.distributed.auto_parallel.process_group import ( + get_world_process_group, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + AutoMixedPrecisionLists, +) +from paddle.fluid.contrib.mixed_precision.fp16_utils import ( + _keep_layer_norm_scale_bias_to_fp32, + _need_keep_fp32, + _valid_types, + _dtype_to_str, +) +from 
paddle.distributed.auto_parallel.dist_attribute import ( + OperatorDistributedAttribute, +) +from paddle.distributed.auto_parallel.utils import ( + is_forward_op, + is_backward_op, + OP_ROLE_KEY, + OpRole, +) from .auto_parallel_amp import AMPPass world_process_group = get_world_process_group() @@ -39,11 +58,15 @@ __amp_skip_ops__ = [ def set_op_dtype_to_fp16(op): - if op.has_attr('in_dtype') and op.attr( - 'in_dtype') == core.VarDesc.VarType.FP32: + if ( + op.has_attr('in_dtype') + and op.attr('in_dtype') == core.VarDesc.VarType.FP32 + ): op._set_attr('in_dtype', core.VarDesc.VarType.FP16) - if op.has_attr('out_dtype') and op.attr( - 'out_dtype') == core.VarDesc.VarType.FP32: + if ( + op.has_attr('out_dtype') + and op.attr('out_dtype') == core.VarDesc.VarType.FP32 + ): op._set_attr('out_dtype', core.VarDesc.VarType.FP16) if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32: op._set_attr('dtype', core.VarDesc.VarType.FP16) @@ -63,7 +86,12 @@ def _keep_fp32_input(op, in_name): return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'} if op_type in ['fused_attention', 'fused_feedforward']: return in_name in { - 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" + 'LnScale', + 'LnBias', + 'Ln2Scale', + 'Ln2Bias', + "Ln1Scale", + "Ln1Bias", } # backward if op_type in ['batch_norm_grad']: @@ -83,8 +111,12 @@ def _keep_fp32_output(op, out_name): return out_name not in {'Y', 'ConvX', 'ConvZ'} if op_type in ['fused_attention', 'fused_feedforward']: return out_name in { - 'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean', - 'Ln1Variance' + 'LnMean', + 'LnVariance', + 'Ln2Mean', + 'Ln2Variance', + 'Ln1Mean', + 'Ln1Variance', } # backward if op_type in ['layer_norm_grad']: @@ -95,24 +127,28 @@ def _keep_fp32_output(op, out_name): class FP16State(object): - - def __init__(self, - program, - amp_list, - dist_context, - use_fp16_guard, - input_data_var_names=None): + def __init__( + self, + program, + amp_list, + dist_context, + use_fp16_guard, + input_data_var_names=None, + ): self.program = program self.amp_list = amp_list self.use_fp16_guard = use_fp16_guard self.dist_context = dist_context - self.grad_op_to_op_map = self.dist_context.dist_op_context.grad_op_id_to_op_id + self.grad_op_to_op_map = ( + self.dist_context.dist_op_context.grad_op_id_to_op_id + ) if input_data_var_names: self.input_data_var_names = input_data_var_names else: self.input_data_var_names = [] - self._op_fp16_dict = { - } # op_id --> True/False. 'True' means that the op is should run in fp16 mode. + self._op_fp16_dict = ( + {} + ) # op_id --> True/False. 'True' means that the op is should run in fp16 mode. 
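The FP16 pass keeps a per-op decision table (`_op_fp16_dict`, op id -> run in fp16) and flips any fp32 dtype attribute on the ops it marks, as `set_op_dtype_to_fp16` does above. The following framework-free sketch mimics that bookkeeping with plain dicts; the `in_dtype`/`out_dtype`/`dtype` attribute names come from the code above, while the example ops and the unsupported set are made up for illustration:

```python
FP32, FP16 = "float32", "float16"


def set_op_dtype_to_fp16(op_attrs):
    """Flip every fp32 dtype-like attribute to fp16, as the pass does."""
    for key in ("in_dtype", "out_dtype", "dtype"):
        if op_attrs.get(key) == FP32:
            op_attrs[key] = FP16


# op_id -> True means "this op should run in fp16"
op_fp16_dict = {}

ops = [
    {"id": 0, "type": "cast", "attrs": {"in_dtype": FP32, "out_dtype": FP32}},
    {"id": 1, "type": "softmax", "attrs": {"dtype": FP32}},
]
unsupported = {"softmax"}  # illustrative stand-in for the AMP black list

for op in ops:
    run_fp16 = op["type"] not in unsupported
    op_fp16_dict[op["id"]] = run_fp16
    if run_fp16:
        set_op_dtype_to_fp16(op["attrs"])

assert ops[0]["attrs"] == {"in_dtype": FP16, "out_dtype": FP16}
assert op_fp16_dict == {0: True, 1: False}
```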
# a trick to determine leaf tensor node in program {varname: generator_op_id} self.forward_non_leaf_tensors = {} # record the cast ops that are inserted for a forward @@ -126,7 +162,7 @@ class FP16State(object): def _build_state(self): """ - mark the execution mode (fp16 or fp32) for ops in all blocks + mark the execution mode (fp16 or fp32) for ops in all blocks include forward ops & backward ops """ # mark op dtype @@ -156,8 +192,9 @@ class FP16State(object): if op.type == "assign" and "array_" in op.input_arg_names[0]: self._op_fp16_dict[op.desc.original_id()] = False return - if _need_keep_fp32(op, self.amp_list.unsupported_list, - self.use_fp16_guard): + if _need_keep_fp32( + op, self.amp_list.unsupported_list, self.use_fp16_guard + ): self._op_fp16_dict[op.desc.original_id()] = False else: self._op_fp16_dict[op.desc.original_id()] = True @@ -170,8 +207,9 @@ class FP16State(object): if op.desc.original_id() in self.grad_op_to_op_map: fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()] assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op)) - self._op_fp16_dict[ - op.desc.original_id()] = self._op_fp16_dict[fwd_op_id] + self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[ + fwd_op_id + ] if int(op.attr('op_role')) == 257: self.is_train = True @@ -181,7 +219,8 @@ class FP16State(object): try: var = block.var(var_name) except ValueError as e: - var = self.program.global_block().var(var_name) + var = block._var_recursive(var_name) + # var = self.program.global_block().var(var_name) # NOTE(JZ-LIANG) "array_" is a hack to adopt for ernie3.0 inference, since there is # a trick which make the LOD_TENSOR_ARRAY to the float32 in while block to reset the LOD_TENSOR_ARRAY @@ -196,13 +235,18 @@ class FP16State(object): for op in block.ops: if is_forward_op(op): # NOTE (JZ-LIANG) un-expected cast op when user call "+, -, *, /" in python - if self._is_fp16_op(op.desc.original_id()) == True \ - or op.type == "cast": + if ( + self._is_fp16_op(op.desc.original_id()) == True + or op.type == "cast" + ): for in_name in op.input_names: if _keep_fp32_input(op, in_name): continue for in_var_name in op.input(in_name): - if in_var_name not in self.forward_non_leaf_tensors and in_var_name not in self.input_data_var_names: + if ( + in_var_name not in self.forward_non_leaf_tensors + and in_var_name not in self.input_data_var_names + ): self.set_var_to_fp16(in_var_name, block) for out_name in op.output_names: if _keep_fp32_output(op, out_name): @@ -248,22 +292,42 @@ class FP16State(object): elif is_forward_op(op): if self._is_fp16_op(op.desc.original_id()) == False: num_cast_ops = self._insert_forward_cast_ops( - op, idx, block, core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32, self.dist_context) + op, + idx, + block, + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32, + self.dist_context, + ) elif self._is_fp16_op(op.desc.original_id()) == True: num_cast_ops = self._insert_forward_cast_ops( - op, idx, block, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, self.dist_context) + op, + idx, + block, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, + self.dist_context, + ) elif is_backward_op(op): if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: if self._is_fp16_op(op.desc.original_id()) == False: num_cast_ops = self._insert_backward_cast_ops( - op, idx, block, core.VarDesc.VarType.FP16, - core.VarDesc.VarType.FP32, self.dist_context) + op, + idx, + block, + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32, + self.dist_context, + ) elif 
self._is_fp16_op(op.desc.original_id()) == True: num_cast_ops = self._insert_backward_cast_ops( - op, idx, block, core.VarDesc.VarType.FP32, - core.VarDesc.VarType.FP16, self.dist_context) + op, + idx, + block, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, + self.dist_context, + ) elif op.type == "sum": # all inputs dtype of sum should be equal and output dtype should follow input out_var_name = op.output_arg_names[0] @@ -271,41 +335,51 @@ class FP16State(object): out_var = block.var(out_var_name) in_var = block._find_var_recursive(in_var_name) for in_var_name in op.input_arg_names: - assert in_var.dtype == block.var( - in_var_name).dtype, "{}, {}, {}".format( - in_var, block.var(in_var_name), str(op)) + assert ( + in_var.dtype == block.var(in_var_name).dtype + ), "{}, {}, {}".format( + in_var, block.var(in_var_name), str(op) + ) out_var.desc.set_dtype(in_var.dtype) idx += num_cast_ops + 1 block._sync_with_cpp() - def _insert_forward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, - dist_context): + def _insert_forward_cast_ops( + self, op, idx, block, src_dtype, dst_dtype, dist_context + ): num_cast_ops = 0 for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( - op, in_name): + op, in_name + ): continue consume_op_attr = dist_context.get_op_dist_attr_for_program(op) assert consume_op_attr is not None for in_var_name in op.input(in_name): in_var = block._find_var_recursive(in_var_name) - if in_var is None or in_var.type not in _valid_types or in_var.dtype == dst_dtype: + if ( + in_var is None + or in_var.type not in _valid_types + or in_var.dtype == dst_dtype + ): continue if in_var.dtype == src_dtype: - cast_name = in_var.name + '.cast_' + _dtype_to_str( - dst_dtype) + cast_name = ( + in_var.name + '.cast_' + _dtype_to_str(dst_dtype) + ) cast_var = block.vars.get(cast_name) self.forward_input_cast_ops[op.desc.original_id()] += [ (cast_name, in_var.name, dst_dtype, src_dtype, in_name) ] in_var_dist_attr = consume_op_attr.get_input_dist_attr( - in_var.name) + in_var.name + ) assert in_var_dist_attr is not None # truly insert cast op if cast_var is None or cast_var.dtype != dst_dtype: @@ -319,9 +393,11 @@ class FP16State(object): name=cast_name, dtype=dst_dtype, persistable=False, - stop_gradient=in_var.stop_gradient) - set_var_dist_attr(dist_context, cast_var, ref_mapping, - ref_mesh) + stop_gradient=in_var.stop_gradient, + ) + set_var_dist_attr( + dist_context, cast_var, ref_mapping, ref_mesh + ) cast_op = block._insert_op_without_sync( idx, @@ -331,23 +407,27 @@ class FP16State(object): attrs={ "in_dtype": in_var.dtype, "out_dtype": cast_var.dtype, - OP_ROLE_KEY: OpRole.Forward - }) + OP_ROLE_KEY: OpRole.Forward, + }, + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, ref_mapping, dist_context) + cast_op, ref_mesh, ref_mapping, dist_context + ) num_cast_ops += 1 op._rename_input(in_var.name, cast_name) - consume_op_attr.set_input_dist_attr(cast_name, - in_var_dist_attr) + consume_op_attr.set_input_dist_attr( + cast_name, in_var_dist_attr + ) if op.has_attr('out_dtype') and op.attr('out_dtype') != -1: assert op.attr('out_dtype') == dst_dtype return num_cast_ops - def _insert_backward_cast_ops(self, op, idx, block, src_dtype, dst_dtype, - dist_context): + def _insert_backward_cast_ops( + self, op, idx, block, src_dtype, dst_dtype, dist_context + ): num_cast_ops = 0 op_id = op.desc.id() @@ -363,15 +443,21 @@ class FP16State(object): if _keep_fp32_output(op, out_var.name): continue assert out_var.dtype == 
dst_dtype, "{}, {}".format( - str(out_var), dst_dtype) + str(out_var), dst_dtype + ) - for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[ - forward_op_id]: + for ( + cast_name, + src_name, + dst_dtype, + src_dtype, + slot_name, + ) in self.forward_input_cast_ops[forward_op_id]: # rename input assert src_name in op.input( - slot_name), "var: {} not in op's {}. {}".format( - src_name, slot_name, str(op)) + slot_name + ), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op)) src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name) assert src_var_dist_attr is not None op._rename_input(src_name, cast_name) @@ -393,15 +479,18 @@ class FP16State(object): ref_mapping = grad_dist_attr.dims_mapping cast_grad = block.create_var( - name=unique_name.generate_with_ignorable_key("".join( - [cast_name, '@GRAD'])), + name=unique_name.generate_with_ignorable_key( + "".join([cast_name, '@GRAD']) + ), dtype=dst_dtype, shape=grad.shape, type=grad.type, persistable=grad.persistable, - stop_gradient=grad.stop_gradient) + stop_gradient=grad.stop_gradient, + ) dist_context.set_tensor_dist_attr_for_program( - cast_grad, grad_dist_attr) + cast_grad, grad_dist_attr + ) op._rename_output(grad_name, cast_grad.name) grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr) @@ -414,12 +503,14 @@ class FP16State(object): attrs={ "in_dtype": dst_dtype, "out_dtype": src_dtype, - OP_ROLE_KEY: OpRole.Backward - }) + OP_ROLE_KEY: OpRole.Backward, + }, + ) grad.desc.set_dtype(src_dtype) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - cast_op, ref_mesh, ref_mapping, dist_context) + cast_op, ref_mesh, ref_mapping, dist_context + ) num_cast_ops += 1 return num_cast_ops @@ -432,26 +523,34 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale') for e in grads: - check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], - 'check_finite_and_unscale') + check_variable_and_dtype( + e, + "x", + ['float16', 'float32', 'float64'], + 'check_finite_and_unscale', + ) found_inf = main_block.create_var( - name=unique_name.generate_with_ignorable_key(".".join( - ['find_infinite_scale', name])), + name=unique_name.generate_with_ignorable_key( + ".".join(['find_infinite_scale', name]) + ), shape=[1], dtype='bool', type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks) inputs = {'X': grads, 'Scale': loss_scaling} outputs = {'Out': grads, 'FoundInfinite': found_inf} attrs = {'op_role': OpRole.Optimize} - new_op = main_block.append_op(type='check_finite_and_unscale', - inputs=inputs, - outputs=outputs, - attrs=attrs) + new_op = main_block.append_op( + type='check_finite_and_unscale', + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr.process_mesh = world_process_group.ranks @@ -461,10 +560,12 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context): for g in grads: g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g) assert g_dist_attr is not None - new_op_dist_attr.set_input_dims_mapping(g.name, - g_dist_attr.dims_mapping) - new_op_dist_attr.set_output_dims_mapping(g.name, - g_dist_attr.dims_mapping) + new_op_dist_attr.set_input_dims_mapping( + g.name, g_dist_attr.dims_mapping + ) + new_op_dist_attr.set_output_dims_mapping( + g.name, g_dist_attr.dims_mapping + 
) dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) return grads, found_inf @@ -473,8 +574,9 @@ def _split_grads(params_grads): grads = [g for _, g in params_grads] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] - assert len(fp32_grads) + len(fp16_grads) == len(grads), \ - "Data types of all grads must be either fp16 or fp32." + assert len(fp32_grads) + len(fp16_grads) == len( + grads + ), "Data types of all grads must be either fp16 or fp32." return grads, fp32_grads, fp16_grads @@ -486,37 +588,45 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context): var = block.var(var_name) var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) assert var_dist_attr is not None - new_op_dist_attr.set_input_dims_mapping(var_name, - var_dist_attr.dims_mapping) + new_op_dist_attr.set_input_dims_mapping( + var_name, var_dist_attr.dims_mapping + ) for var_name in new_op.output_arg_names: var = block.var(var_name) var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) assert var_dist_attr is not None - new_op_dist_attr.set_output_dims_mapping(var_name, - var_dist_attr.dims_mapping) + new_op_dist_attr.set_output_dims_mapping( + var_name, var_dist_attr.dims_mapping + ) dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr) def _get_memcopy_idx(block, found_inf_var): # use reduce_any op for check_nan_inf as the anchor for now for idx, op in enumerate(block.ops): - if op.type == 'reduce_any' and op.output_arg_names[ - 0] == found_inf_var.name: + if ( + op.type == 'reduce_any' + and op.output_arg_names[0] == found_inf_var.name + ): return idx + 1 raise RuntimeError( - "not found the correct location for memcopy for found_inf_var.") + "not found the correct location for memcopy for found_inf_var." 
+ ) def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): src_name = src_var.name - output_var = block.create_var(name=unique_name.generate_with_ignorable_key( - src_name.join(['memcopy_'])), - dtype=src_var.dtype, - shape=src_var.shape, - type=core.VarDesc.VarType.LOD_TENSOR, - persistable=False, - stop_gradient=src_var.stop_gradient) + output_var = block.create_var( + name=unique_name.generate_with_ignorable_key( + src_name.join(['memcopy_']) + ), + dtype=src_var.dtype, + shape=src_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=src_var.stop_gradient, + ) set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks) @@ -527,16 +637,20 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"): dst_place_type = 1 else: raise NotImplementedError( - "direction [{}] is not supported yet.".format(direction)) + "direction [{}] is not supported yet.".format(direction) + ) attrs = {'dst_place_type': dst_place_type} - new_op = block._insert_op_without_sync(index=idx, - type='memcpy', - inputs={'X': [src_var]}, - outputs={'Out': [output_var]}, - attrs=attrs) - _set_op_dist_attr_with_ranks(new_op, world_process_group.ranks, block, - dist_context) + new_op = block._insert_op_without_sync( + index=idx, + type='memcpy', + inputs={'X': [src_var]}, + outputs={'Out': [output_var]}, + attrs=attrs, + ) + _set_op_dist_attr_with_ranks( + new_op, world_process_group.ranks, block, dist_context + ) block._sync_with_cpp() return output_var @@ -564,19 +678,21 @@ def cast_startup_program(): for op in startup_program.global_block().ops: if is_initialization_op(op): output_name = op.output_arg_names[0] - if param_to_dtype.get(output_name, - None) == core.VarDesc.VarType.FP16: + if ( + param_to_dtype.get(output_name, None) + == core.VarDesc.VarType.FP16 + ): assert op.has_attr( 'dtype' ), "initialization op is supported to has dtype attribute but got {}.".format( - str(op)) + str(op) + ) if op.attr('dtype') == core.VarDesc.VarType.FP32: op._set_attr('dtype', core.VarDesc.VarType.FP16) @register_pass("auto_parallel_fp16") class FP16Pass(AMPPass): - def __init__(self): super(FP16Pass, self).__init__() @@ -589,16 +705,22 @@ class FP16Pass(AMPPass): amp_list = AutoMixedPrecisionLists( set(self.get_attr("custom_white_list")), - set(self.get_attr("custom_black_list")), None) + set(self.get_attr("custom_black_list")), + None, + ) # NOTE don't not change input data dtype, since it is controled by dataloader # and which is out of control of FP16 Pass input_data_var_names = [var.name for var in self.get_attr("input_data")] with paddle.static.program_guard(main_program, startup_program): - fp16_state = FP16State(main_program, amp_list, self.dist_context, - self.get_attr("use_fp16_guard"), - input_data_var_names) + fp16_state = FP16State( + main_program, + amp_list, + self.dist_context, + self.get_attr("use_fp16_guard"), + input_data_var_names, + ) is_train = fp16_state._build_state() cast_startup_program() @@ -611,41 +733,63 @@ class FP16Pass(AMPPass): grads, fp32_grads, fp16_grads = _split_grads(params_grads) - if self.get_attr("use_dynamic_loss_scaling" - ) or self.get_attr("init_loss_scaling") != 1.0: + if ( + self.get_attr("use_dynamic_loss_scaling") + or self.get_attr("init_loss_scaling") != 1.0 + ): found_infs = [] if fp32_grads: with main_program._optimized_guard([]): _, found_inf_fp32 = _check_and_update_gradient( - fp32_grads, self._loss_scaling, "@fp32", - self.dist_context) + fp32_grads, + self._loss_scaling, + "@fp32", + 
self.dist_context, + ) found_infs.append(found_inf_fp32) if fp16_grads: with main_program._optimized_guard([]): _, found_inf_fp16 = _check_and_update_gradient( - fp16_grads, self._loss_scaling, "@fp16", - self.dist_context) + fp16_grads, + self._loss_scaling, + "@fp16", + self.dist_context, + ) found_infs.append(found_inf_fp16) with main_program._optimized_guard([]): block = main_program.global_block() all_infs = paddle.fluid.layers.concat(found_infs) - set_var_dist_attr(self.dist_context, all_infs, [-1], - world_process_group.ranks) + set_var_dist_attr( + self.dist_context, + all_infs, + [-1], + world_process_group.ranks, + ) new_op = block.ops[-1] assert new_op.type == "concat" - _set_op_dist_attr_with_ranks(new_op, - world_process_group.ranks, - block, self.dist_context) + _set_op_dist_attr_with_ranks( + new_op, + world_process_group.ranks, + block, + self.dist_context, + ) found_inf = paddle.fluid.layers.reduce_any(all_infs) - set_var_dist_attr(self.dist_context, found_inf, [-1], - world_process_group.ranks) + set_var_dist_attr( + self.dist_context, + found_inf, + [-1], + world_process_group.ranks, + ) new_op = block.ops[-1] assert new_op.type == "reduce_any" - _set_op_dist_attr_with_ranks(new_op, - world_process_group.ranks, - block, self.dist_context) + _set_op_dist_attr_with_ranks( + new_op, + world_process_group.ranks, + block, + self.dist_context, + ) if self.get_attr("use_dynamic_loss_scaling"): with main_program._optimized_guard([]): @@ -660,14 +804,15 @@ class FP16Pass(AMPPass): if self.get_attr("use_optimizer_fp16"): base_opt._multi_precision = False if isinstance( - base_opt, - (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)): + base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW) + ): with main_program._optimized_guard([]): # found_inf = paddle.tensor.creation._memcpy( # found_inf, paddle.CPUPlace()) insert_idx = _get_memcopy_idx(block, found_inf) - found_inf = _insert_memcopy(block, insert_idx, found_inf, - self.dist_context) + found_inf = _insert_memcopy( + block, insert_idx, found_inf, self.dist_context + ) base_opt._set_auxiliary_var('found_inf', found_inf.name) elif hasattr(base_opt, "_set_auxiliary_var"): base_opt._set_auxiliary_var('found_inf', found_inf.name) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 3d34ed4fcdbe7cd7538b9801bb96eba2df70afa7..bd6ccfd3922c84b94d3eaf72a175f8bf113a9f49 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -63,6 +63,15 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_engine_callbacks MODULES test_engine_callbacks) set_tests_properties(test_engine_callbacks PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS + ${dist_ENVS}) + set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120) + py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full + ENVS ${dist_ENVS}) + set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120) + py_test_modules(test_parallel_tuner_predict MODULES + test_parallel_tuner_predict ENVS ${dist_ENVS}) + set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) @@ -90,6 +99,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS 
${dist_ENVS}) py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS}) py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS}) + py_test_modules(test_cluster_v2 MODULES test_cluster_v2) py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2) py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2) @@ -99,20 +109,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_interface MODULES test_interface) py_test_modules(test_strategy MODULES test_strategy) py_test_modules(test_pass_quantization MODULES test_pass_quantization) - py_test_modules(test_dist_shape MODULES test_dist_shape) py_test_modules(test_dist_assign MODULES test_dist_assign) py_test_modules(test_conditional_block_reshard MODULES test_conditional_block_reshard) - - py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS - ${dist_ENVS}) - set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120) - py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full - ENVS ${dist_ENVS}) - set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120) - py_test_modules(test_parallel_tuner_predict MODULES - test_parallel_tuner_predict ENVS ${dist_ENVS}) - set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) + py_test_modules(test_engine_api_error MODULES test_engine_api_error) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index 1cc68393b64268ffec092666d0feb1ba3968029f..0a50c6d3000a0a617b15c5779a6fd6d290fbeaf7 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -31,7 +31,10 @@ from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader from paddle.distributed.fleet import auto -from paddle.distributed.auto_parallel.interface import get_collection, CollectionNames +from paddle.distributed.auto_parallel.interface import ( + get_collection, + CollectionNames, +) from paddle.optimizer.lr import CosineAnnealingDecay from paddle.fluid.dataloader.collate import default_collate_fn @@ -56,7 +59,6 @@ my_feed_vars = [] class MyDataset(Dataset): - def __init__(self, num_samples): super(MyDataset, self).__init__() self.num_samples = num_samples @@ -77,38 +79,38 @@ def get_random_inputs_and_labels(image_shape, label_shape): def batch_generator_creator(): - def __reader__(): for _ in range(batch_num): batch_input, batch_label = get_random_inputs_and_labels( - [batch_size, image_size], [batch_size, 1]) + [batch_size, image_size], [batch_size, 1] + ) yield batch_input, batch_label return __reader__ class MLPLayer(nn.Layer): - - def __init__(self, - hidden_size=1024, - intermediate_size=4 * 1024, - dropout_ratio=0.1, - initializer_range=0.02): + def __init__( + self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02, + ): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size weight_attr = paddle.ParamAttr( - initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range) + ) bias_attr = None - self.linear0 = nn.Linear(d_model, - dim_feedforward, - weight_attr, - bias_attr=bias_attr) - self.linear1 = nn.Linear(dim_feedforward, - d_model, - weight_attr, - bias_attr=bias_attr) + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, 
bias_attr=bias_attr + ) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr + ) self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") @@ -132,16 +134,20 @@ class MLPLayer(nn.Layer): def train_high_level(fetch): global is_fetch is_fetch = fetch - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) metric = paddle.metric.Accuracy() strategy = auto.Strategy() @@ -153,11 +159,13 @@ def train_high_level(fetch): train_dataset = MyDataset(batch_num * batch_size) eval_dataset1 = MyDataset(5 * batch_size) - history = engine.fit(train_data=train_dataset, - epochs=2, - batch_size=batch_size, - valid_data=eval_dataset1, - log_freq=1) + history = engine.fit( + train_data=train_dataset, + epochs=2, + batch_size=batch_size, + valid_data=eval_dataset1, + log_freq=1, + ) # eval eval_dataset2 = MyDataset(batch_size) @@ -176,16 +184,20 @@ def train_high_level(fetch): def train_low_level(): - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) metric = paddle.metric.Accuracy() strategy = auto.Strategy() @@ -200,18 +212,18 @@ def train_low_level(): # Build normal normal dataloader # train train_dataset = MyDataset(batch_num * batch_size) - train_dataloader = engine.dataloader(train_dataset, - batch_size=batch_size, - mode="train") + train_dataloader = engine.dataloader( + train_dataset, batch_size=batch_size, mode="train" + ) engine.prepare(mode="train") for data in train_dataloader: outs = engine.run(data, feed=feed_dict, mode="train") # eval eval_dataset2 = MyDataset(batch_size) - eval_dataloader = engine.dataloader(eval_dataset2, - batch_size=batch_size, - mode="eval") + eval_dataloader = engine.dataloader( + eval_dataset2, batch_size=batch_size, mode="eval" + ) engine.prepare(mode="eval") for data in eval_dataloader: outs = engine.run(data, feed=feed_dict, mode="eval") @@ -234,9 +246,9 @@ def train_low_level(): # Build dataloader from generator # train train_dataset = MyDataset(batch_num * batch_size) - train_dataloader = engine.dataloader_from_generator(train_dataset, - batch_size=batch_size, - mode="train") + train_dataloader = engine.dataloader_from_generator( + train_dataset, batch_size=batch_size, mode="train" + ) engine.prepare(mode="train") for data in train_dataloader: outs = engine.run(data, feed=feed_dict, mode="train") @@ -244,17 +256,18 @@ def train_low_level(): # eval engine.to_mode("eval") eval_dataset2 = MyDataset(batch_size) - eval_dataloader = 
engine.dataloader_from_generator(eval_dataset2, - batch_size=batch_size) + eval_dataloader = engine.dataloader_from_generator( + eval_dataset2, batch_size=batch_size + ) engine.prepare() for data in eval_dataloader: outs = engine.run(data, feed=feed_dict) # predict test_dataset = MyDataset(batch_size) - predict_dataloader = engine.dataloader_from_generator(test_dataset, - batch_size=batch_size, - mode="predict") + predict_dataloader = engine.dataloader_from_generator( + test_dataset, batch_size=batch_size, mode="predict" + ) engine.prepare(mode="predict") for data in predict_dataloader: outs = engine.run(data, feed=feed_dict, mode="predict") @@ -268,16 +281,20 @@ def train_low_level(): def train_builtin_data_vars(): - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) metric = paddle.metric.Accuracy() strategy = auto.Strategy() @@ -295,9 +312,9 @@ def train_builtin_data_vars(): with static.program_guard(engine.main_program, engine.startup_program): feed_list = engine.inputs + engine.labels print(feed_list) - loader = paddle.io.DataLoader.from_generator(feed_list=feed_list, - capacity=4 * batch_size, - iterable=False) + loader = paddle.io.DataLoader.from_generator( + feed_list=feed_list, capacity=4 * batch_size, iterable=False + ) places = static.cuda_places() loader.set_batch_generator(batch_generator_creator(), places=places) @@ -308,36 +325,40 @@ def train_builtin_data_vars(): while True: engine.run() except paddle.fluid.core.EOFException: - loader.reset( - ) # call DataLoader.reset() after catching EOFException + loader.reset() # call DataLoader.reset() after catching EOFException def train_non_builtin_data_vars(): main_program = static.Program() startup_program = static.Program() - with static.program_guard(main_program, - startup_program), utils.unique_name.guard(): - input = static.data(name="input", - shape=[batch_size, image_size], - dtype='float32') + with static.program_guard( + main_program, startup_program + ), utils.unique_name.guard(): + input = static.data( + name="input", shape=[batch_size, image_size], dtype='float32' + ) label = static.data(name="label", shape=[batch_size, 1], dtype='int64') - loader = paddle.io.DataLoader.from_generator(feed_list=[input, label], - capacity=4 * batch_size, - iterable=False) + loader = paddle.io.DataLoader.from_generator( + feed_list=[input, label], capacity=4 * batch_size, iterable=False + ) places = static.cuda_places() loader.set_batch_generator(batch_generator_creator(), places=places) - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) metric = 
paddle.metric.Accuracy() predict = mlp(input) loss_var = loss(predict, label) @@ -345,53 +366,109 @@ def train_non_builtin_data_vars(): strategy = auto.Strategy() strategy.auto_mode = "semi" - engine = auto.Engine(loss=loss_var, - optimizer=optimizer, - metrics=metric, - strategy=strategy) + engine = auto.Engine( + loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy + ) # train engine.to_mode("train") - engine.prepare(inputs=[input], - labels=[label], - main_program=main_program, - startup_program=startup_program) + engine.prepare( + inputs=[input], + labels=[label], + main_program=main_program, + startup_program=startup_program, + ) for _ in range(epoch_num): loader.start() # call DataLoader.start() before each epoch starts try: while True: engine.run() except paddle.fluid.core.EOFException: - loader.reset( - ) # call DataLoader.reset() after catching EOFException + loader.reset() # call DataLoader.reset() after catching EOFException def get_cost(): + main_program = static.Program() + startup_program = static.Program() + with static.program_guard( + main_program, startup_program + ), utils.unique_name.guard(): + input = static.data( + name="input", shape=[batch_size, image_size], dtype='float32' + ) + label = static.data(name="label", shape=[batch_size, 1], dtype='int64') + + loader = paddle.io.DataLoader.from_generator( + feed_list=[input, label], capacity=4 * batch_size, iterable=False + ) + places = static.cuda_places() + loader.set_batch_generator(batch_generator_creator(), places=places) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) + metric = paddle.metric.Accuracy() + predict = mlp(input) + loss_var = loss(predict, label) + + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + engine = auto.Engine( + loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy + ) + engine.prepare( + main_program=main_program, + startup_program=startup_program, + inputs=[input], + labels=[label], + mode="train", + ) + engine.cost() + + +def get_cost_by_default_program(): main_program = static.default_main_program() startup_program = static.default_startup_program() - with static.program_guard(main_program, - startup_program), utils.unique_name.guard(): - input = static.data(name="input", - shape=[batch_size, image_size], - dtype='float32') + with static.program_guard( + main_program, startup_program + ), utils.unique_name.guard(): + input = static.data( + name="input", shape=[batch_size, image_size], dtype='float32' + ) label = static.data(name="label", shape=[batch_size, 1], dtype='int64') - loader = paddle.io.DataLoader.from_generator(feed_list=[input, label], - capacity=4 * batch_size, - iterable=False) + loader = paddle.io.DataLoader.from_generator( + feed_list=[input, label], capacity=4 * batch_size, iterable=False + ) places = static.cuda_places() loader.set_batch_generator(batch_generator_creator(), places=places) - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.optimizer.Adam(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, 
-                                          grad_clip=None)
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=0.00001,
+            beta1=0.9,
+            beta2=0.999,
+            epsilon=1e-08,
+            grad_clip=None,
+        )
         metric = paddle.metric.Accuracy()
         predict = mlp(input)
         loss_var = loss(predict, label)
@@ -399,24 +476,27 @@ def get_cost():
         strategy = auto.Strategy()
         strategy.auto_mode = "semi"
 
-        engine = auto.Engine(loss=loss_var,
-                             optimizer=optimizer,
-                             metrics=metric,
-                             strategy=strategy)
-        engine.cost()
+        engine = auto.Engine(
+            loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
+        )
+        engine.cost(mode="train")
 
 
 def get_cost_by_spec():
-    mlp = MLPLayer(hidden_size=hidden_size,
-                   intermediate_size=4 * hidden_size,
-                   dropout_ratio=0.1,
-                   initializer_range=0.02)
+    mlp = MLPLayer(
+        hidden_size=hidden_size,
+        intermediate_size=4 * hidden_size,
+        dropout_ratio=0.1,
+        initializer_range=0.02,
+    )
     loss = paddle.nn.CrossEntropyLoss()
-    optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
-                                      beta1=0.9,
-                                      beta2=0.999,
-                                      epsilon=1e-08,
-                                      grad_clip=None)
+    optimizer = paddle.optimizer.Adam(
+        learning_rate=0.00001,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-08,
+        grad_clip=None,
+    )
     metric = paddle.metric.Accuracy()
 
     strategy = auto.Strategy()
@@ -436,4 +516,5 @@ if __name__ == "__main__":
     train_builtin_data_vars()
     train_non_builtin_data_vars()
     get_cost()
+    get_cost_by_default_program()
     get_cost_by_spec()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api_error.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api_error.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd825524b8ae30131ec3f779506959b166a05add
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api_error.py
@@ -0,0 +1,311 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle
+import paddle.static as static
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.io import Dataset
+
+from paddle.distributed.fleet import auto
+
+paddle.enable_static()
+
+
+epoch_num = 1
+batch_size = 2
+batch_num = 10
+hidden_size = 1024
+sequence_len = 512
+image_size = hidden_size
+class_num = 10
+
+is_fetch = True
+is_feed = True
+my_feed_vars = []
+
+
+class TrainDataset(Dataset):
+    def __init__(self, num_samples):
+        super(TrainDataset, self).__init__()
+        self.num_samples = num_samples
+
+    def __getitem__(self, index):
+        input = np.random.uniform(size=image_size).astype("float32")
+        label = np.random.randint(0, class_num - 1, dtype="int64")
+        return input, label
+
+    def __len__(self):
+        return self.num_samples
+
+
+class TestDataset(Dataset):
+    def __init__(self, num_samples):
+        super(TestDataset, self).__init__()
+        self.num_samples = num_samples
+
+    def __getitem__(self, index):
+        input = np.random.uniform(size=image_size).astype("float32")
+        return input
+
+    def __len__(self):
+        return self.num_samples
+
+
+class MLPLayer(nn.Layer):
+    def __init__(
+        self,
+        hidden_size=1024,
+        intermediate_size=4 * 1024,
+        dropout_ratio=0.1,
+        initializer_range=0.02,
+    ):
+        super(MLPLayer, self).__init__()
+        d_model = hidden_size
+        dim_feedforward = intermediate_size
+        weight_attr = paddle.ParamAttr(
+            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
+        )
+        bias_attr = None
+
+        self.linear0 = nn.Linear(
+            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
+        )
+        self.linear1 = nn.Linear(
+            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
+        )
+        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
+        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
+        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
+
+    def forward(self, input):
+        out = self.norm(input)
+        out = self.linear0(out)
+
+        if is_feed:
+            my_feed_vars.append((out, out.shape))
+
+        out = F.gelu(out, approximate=True)
+        out = self.linear1(out)
+        out = self.dropout(out)
+        out = self.linear2(out)
+
+        if is_feed:
+            my_feed_vars.append((out, out.shape))
+        if is_fetch:
+            auto.fetch(out, "my_fetch", logging=True)
+        return out
+
+
+class TestEngineErrorRaise(unittest.TestCase):
+    def setUp(self):
+        class NoSupportData1:
+            def __getitem__(self, index):
+                input = np.random.uniform(size=image_size).astype("float32")
+                label = np.random.randint(0, class_num - 1, dtype="int64")
+                return input, label
+
+        class NoSupportData2(TrainDataset):
+            def __getitem__(self, index):
+                input = [
+                    list(np.random.uniform(size=image_size).astype("float32"))
+                ]
+                label = [np.random.randint(0, class_num - 1, dtype="int64")]
+                return input, label
+
+        class NoSupportData3:
+            def __getitem__(self, index):
+                input = np.random.uniform(size=image_size).astype("float32")
+                return input
+
+        class NoSupportData4(TestDataset):
+            def __getitem__(self, index):
+                input = [
+                    list(np.random.uniform(size=image_size).astype("float32"))
+                ]
+                return input
+
+        self.no_support_data_1 = NoSupportData1()
+        self.no_support_data_2 = NoSupportData2(10)
+        self.no_support_data_3 = NoSupportData3()
+        self.no_support_data_4 = NoSupportData4(10)
+
+    def test_Engine(self):
+        with self.assertRaises(TypeError):
+            auto.Engine(model=paddle.static.Program())
+        with self.assertRaises(TypeError):
+            auto.Engine(loss="CrossEntropyLoss")
+        with self.assertRaises(TypeError):
+            auto.Engine(optimizer="adam")
+        with self.assertRaises(TypeError):
+            auto.Engine(metrics=["acc"])
+        with self.assertRaises(TypeError):
+            auto.Engine(cluster="cluster")
+        with self.assertRaises(TypeError):
+            auto.Engine(strategy="strategy")
+
+    def test_fit(self):
+
+        with self.assertRaises(TypeError):
+
+            engine = auto.Engine(
+                model=MLPLayer(),
+                loss=paddle.nn.CrossEntropyLoss(),
+                optimizer=paddle.optimizer.AdamW(0.00001),
+            )
+            engine.fit(train_data=self.no_support_data_1)
+
+        with self.assertRaises(TypeError):
+
+            engine = auto.Engine(
+                model=MLPLayer(),
+                loss=paddle.nn.CrossEntropyLoss(),
+                optimizer=paddle.optimizer.AdamW(0.00001),
+            )
+            engine.fit(train_data=self.no_support_data_2)
+
+    def test_evaluate(self):
+        with self.assertRaises(TypeError):
+
+            engine = auto.Engine(
+                model=MLPLayer(),
+                loss=paddle.nn.CrossEntropyLoss(),
+                metrics=paddle.metric.Accuracy(),
+            )
+            engine.evaluate(valid_data=self.no_support_data_3)
+
+        with self.assertRaises(TypeError):
+
+            engine = auto.Engine(
+                model=MLPLayer(),
+                loss=paddle.nn.CrossEntropyLoss(),
+                metrics=paddle.metric.Accuracy(),
+            )
+            engine.evaluate(
+                valid_data=self.no_support_data_4, valid_sample_split=1
+            )
+
+    def test_predict(self):
+        with self.assertRaises(TypeError):
+
+            engine = auto.Engine(model=MLPLayer())
+            engine.predict(
+                test_data=self.no_support_data_3, test_sample_split=1
+            )
+
+        with self.assertRaises(TypeError):
+
+            engine = auto.Engine(model=MLPLayer())
+            engine.predict(
+                test_data=self.no_support_data_4, test_sample_split=1
+            )
+
+    def build_program(self):
+        main_prog = static.Program()
+        startup_prog = static.Program()
+        with static.program_guard(main_prog, startup_prog):
+            input = static.data(
+                name="input",
+                shape=[batch_size // 2, image_size],
+                dtype='float32',
+            )
+            label = static.data(
+                name="label", shape=[batch_size // 2, 1], dtype='int64'
+            )
+            mlp = MLPLayer()
+            loss = paddle.nn.CrossEntropyLoss()
+            predict = mlp(input)
+            loss_var = loss(predict, label)
+        return main_prog, startup_prog, input, label, loss_var
+
+    def test_prepare(self):
+        with self.assertRaises(ValueError):
+            engine = auto.Engine(model=MLPLayer())
+            engine.prepare()
+
+        with self.assertRaises(AssertionError):
+            engine = auto.Engine(model=MLPLayer())
+            engine.prepare(mode="train")
+
+        with self.assertRaises(TypeError):
+            input = static.data(
+                name="input",
+                shape=[batch_size / 2, image_size],
+                dtype='float32',
+            )
+            label = static.data(
+                name="label", shape=[batch_size / 2, 1], dtype='int64'
+            )
+            engine = auto.Engine(model=MLPLayer())
+            engine.prepare(inputs_spec=input, labels_spec=label, mode="eval")
+
+        input_spec = static.InputSpec(
+            shape=[batch_size, image_size], dtype="float32", name="input"
+        )
+        label_spec = static.InputSpec(
+            shape=[batch_size, image_size], dtype="float32", name="input"
+        )
+        (
+            main_prog,
+            startup_prog,
+            input_var,
+            label_var,
+            loss_var,
+        ) = self.build_program()
+
+        with self.assertRaises(TypeError):
+            engine = auto.Engine(loss=loss_var)
+            engine.prepare(
+                inputs=input_spec,
+                labels=label_spec,
+                main_program=main_prog,
+                startup_program=startup_prog,
+                mode="eval",
+            )
+
+        with self.assertRaises(AssertionError):
+            engine = auto.Engine(loss=loss_var)
+            engine.prepare(
+                inputs_spec=[input_spec, input_spec],
+                labels_spec=[label_spec, label_spec],
+                inputs=input_var,
+                labels=label_var,
+                main_program=main_prog,
+                startup_program=startup_prog,
+                mode="predict",
+            )
+
+    def test_cost(self):
+        with self.assertRaises(ValueError):
+            engine = auto.Engine(model=MLPLayer())
+            engine.cost(mode="predict")
+
+
+class TestEngineDynamicErrorRaise(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def tearDown(self):
+        paddle.enable_static()
+
+    def test_cost(self):
+        with self.assertRaises(ValueError):
+            engine = auto.Engine(model=MLPLayer())
+            engine.cost(mode="predict")
+
+
+if __name__ == "__main__":
+    unittest.main()
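
For reference, a minimal sketch of the happy-path call order that the new error tests guard: build a loss from a static program, construct an auto.Engine, prepare a mode, then query engine.cost(). It mirrors get_cost() in test_engine_api.py above; the MLPLayer class and the batch_size/image_size/hidden_size constants are assumed to come from that test module, and the snippet is illustrative rather than part of the patch.

# Sketch only: assumes MLPLayer, batch_size, image_size and hidden_size
# from test_engine_api.py are in scope.
import paddle
import paddle.static as static
from paddle.distributed.fleet import auto

paddle.enable_static()

main_program = static.Program()
startup_program = static.Program()
with static.program_guard(main_program, startup_program):
    input = static.data(
        name="input", shape=[batch_size, image_size], dtype="float32"
    )
    label = static.data(name="label", shape=[batch_size, 1], dtype="int64")
    mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size)
    loss_var = paddle.nn.CrossEntropyLoss()(mlp(input), label)

    engine = auto.Engine(
        loss=loss_var, optimizer=paddle.optimizer.Adam(learning_rate=0.00001)
    )
    # A mode has to be prepared before its cost can be queried.
    engine.prepare(
        main_program=main_program,
        startup_program=startup_program,
        inputs=[input],
        labels=[label],
        mode="train",
    )
    engine.cost()  # fine: the "train" mode was prepared above
    # engine.cost(mode="predict") would be expected to raise ValueError, the
    # behaviour asserted (on an un-prepared engine) in
    # TestEngineErrorRaise.test_cost().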