Unverified commit 4b3589fb, authored by zhaoyingli, committed by GitHub

2.4/fix engine build (#47462)

* update codestyle

* [AutoParallel] fix fp16 for subblock (#47189)

* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment

* [AutoParallel] fix engine _build and cost method (#47263)

* fix engine build method

* fix import

* update engine cost

* update raise error

* update cmakelist

* revert optimizer

* revert optimizer

* fix unittest

* fix unittest
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Parent f93e9a58
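The "fix engine _build and cost method" change below culminates in `get_cost_from_engine` (see the estimate_cost.py diff): it clones the forward main program, completes its annotation, and hands the result to a `CostEstimator`. A minimal sketch of driving that entry point follows; the import path and the shape of the result are assumptions, since neither the call site nor the function's return statement is visible in this diff.

```python
# Sketch only: `get_cost_from_engine` is defined in the estimate_cost.py diff
# below; the import path and the returned value are assumptions.
from paddle.distributed.auto_parallel.cost.estimate_cost import (
    get_cost_from_engine,
)

# `engine` is an auto-parallel Engine already constructed for the given mode,
# e.g. "train", "eval" or "predict".
cost = get_cost_from_engine(engine, mode="train")
```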
......@@ -20,9 +20,28 @@ class AdamOpCost(CompOpCost):
OP_TYPE = "adam"
def __init__(self, op=None, op_desc=None, cluster=None):
super(AdamOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(AdamOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ArgsortOpCost(CompOpCost):
OP_TYPE = "argsort"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ArgsortOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -39,9 +58,9 @@ class AssignOpCost(CompOpCost):
OP_TYPE = "assign"
def __init__(self, op=None, op_desc=None, cluster=None):
super(AssignOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(AssignOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -58,9 +77,9 @@ class AssignValueOpCost(CompOpCost):
OP_TYPE = "assign_value"
def __init__(self, op=None, op_desc=None, cluster=None):
super(AssignValueOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(AssignValueOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -77,9 +96,9 @@ class BeamSearchOpCost(CompOpCost):
OP_TYPE = "beam_search"
def __init__(self, op=None, op_desc=None, cluster=None):
super(BeamSearchOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(BeamSearchOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -96,9 +115,9 @@ class BeamSearchDecodeOpCost(CompOpCost):
OP_TYPE = "beam_search_decode"
def __init__(self, op=None, op_desc=None, cluster=None):
super(BeamSearchDecodeOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(BeamSearchDecodeOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -115,9 +134,9 @@ class CastOpCost(CompOpCost):
OP_TYPE = "cast"
def __init__(self, op=None, op_desc=None, cluster=None):
super(CastOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(CastOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -134,9 +153,9 @@ class ConcatOpCost(CompOpCost):
OP_TYPE = "concat"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ConcatOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ConcatOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -153,9 +172,9 @@ class DropoutOpCost(CompOpCost):
OP_TYPE = "dropout"
def __init__(self, op=None, op_desc=None, cluster=None):
super(DropoutOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(DropoutOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -172,9 +191,9 @@ class DropoutGradOpCost(CompOpCost):
OP_TYPE = "dropout_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(DropoutGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(DropoutGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -191,9 +210,9 @@ class ElementwiseAddOpCost(CompOpCost):
OP_TYPE = "elementwise_add"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseAddOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseAddOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -210,9 +229,9 @@ class ElementwiseAddGradOpCost(CompOpCost):
OP_TYPE = "elementwise_add_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseAddGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseAddGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -229,9 +248,9 @@ class ElementwiseDivOpCost(CompOpCost):
OP_TYPE = "elementwise_div"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseDivOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseDivOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -248,9 +267,9 @@ class ElementwiseDivGradOpCost(CompOpCost):
OP_TYPE = "elementwise_div_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseDivGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseDivGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -267,9 +286,9 @@ class ElementwiseMulOpCost(CompOpCost):
OP_TYPE = "elementwise_mul"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseMulOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseMulOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -286,9 +305,9 @@ class ElementwiseMulGradOpCost(CompOpCost):
OP_TYPE = "elementwise_mul_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseMulGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseMulGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -305,9 +324,9 @@ class ElementwiseSubOpCost(CompOpCost):
OP_TYPE = "elementwise_sub"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseSubOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseSubOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -324,9 +343,28 @@ class ElementwiseSubGradOpCost(CompOpCost):
OP_TYPE = "elementwise_sub_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ElementwiseSubGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ElementwiseSubGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class EqualOpCost(CompOpCost):
OP_TYPE = "equal"
def __init__(self, op=None, op_desc=None, cluster=None):
super(EqualOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -343,9 +381,9 @@ class EmbeddingOpCost(CompOpCost):
OP_TYPE = "c_embedding"
def __init__(self, op=None, op_desc=None, cluster=None):
super(EmbeddingOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(EmbeddingOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -362,9 +400,9 @@ class EmbeddingGradOpCost(CompOpCost):
OP_TYPE = "c_embedding_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(EmbeddingGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(EmbeddingGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -381,9 +419,9 @@ class FillConstantOpCost(CompOpCost):
OP_TYPE = "fill_constant"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FillConstantOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(FillConstantOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -400,9 +438,9 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
OP_TYPE = "fill_constant_batch_size_like"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FillConstantBatchSizeLikeOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(FillConstantBatchSizeLikeOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -419,8 +457,9 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)
super(FusedSoftmaxMaskUpperTriangleOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -437,8 +476,9 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleGradOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)
super(FusedSoftmaxMaskUpperTriangleGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -455,9 +495,9 @@ class GatherOpCost(CompOpCost):
OP_TYPE = "gather"
def __init__(self, op=None, op_desc=None, cluster=None):
super(GatherOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(GatherOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -474,9 +514,9 @@ class GeluOpCost(CompOpCost):
OP_TYPE = "gelu"
def __init__(self, op=None, op_desc=None, cluster=None):
super(GeluOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(GeluOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -493,9 +533,9 @@ class GeluGradOpCost(CompOpCost):
OP_TYPE = "gelu_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(GeluGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(GeluGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -512,9 +552,9 @@ class GreaterEqualOpCost(CompOpCost):
OP_TYPE = "greater_equal"
def __init__(self, op=None, op_desc=None, cluster=None):
super(GreaterEqualOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(GreaterEqualOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -531,9 +571,9 @@ class IncrementOpCost(CompOpCost):
OP_TYPE = "increment"
def __init__(self, op=None, op_desc=None, cluster=None):
super(IncrementOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(IncrementOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -546,9 +586,9 @@ class IsEmptyOpCost(CompOpCost):
OP_TYPE = "is_empty"
def __init__(self, op=None, op_desc=None, cluster=None):
super(IsEmptyOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(IsEmptyOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -561,9 +601,9 @@ class LayerNormOpCost(CompOpCost):
OP_TYPE = "layer_norm"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LayerNormOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LayerNormOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -580,9 +620,9 @@ class LayerNormGradOpCost(CompOpCost):
OP_TYPE = "layer_norm_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LayerNormGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LayerNormGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -599,9 +639,9 @@ class LessThanOpCost(CompOpCost):
OP_TYPE = "less_than"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LessThanOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LessThanOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -618,9 +658,9 @@ class LogicalNotOpCost(CompOpCost):
OP_TYPE = "logical_not"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LogicalNotOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LogicalNotOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -637,9 +677,9 @@ class LogicalAndOpCost(CompOpCost):
OP_TYPE = "logical_and"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LogicalAndOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LogicalAndOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -656,9 +696,9 @@ class LodResetOpCost(CompOpCost):
OP_TYPE = "lod_reset"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LodResetOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LodResetOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -692,9 +732,9 @@ class LookupTableV2OpCost(CompOpCost):
OP_TYPE = "lookup_table_v2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LookupTableV2OpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LookupTableV2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -711,9 +751,9 @@ class LookupTableV2GradOpCost(CompOpCost):
OP_TYPE = "lookup_table_v2_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(LookupTableV2GradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(LookupTableV2GradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -730,9 +770,9 @@ class MatmulOpCost(CompOpCost):
OP_TYPE = "matmul"
def __init__(self, op=None, op_desc=None, cluster=None):
super(MatmulOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(MatmulOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -749,9 +789,9 @@ class MatmulGradOpCost(CompOpCost):
OP_TYPE = "matmul_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(MatmulGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(MatmulGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -768,9 +808,9 @@ class MatmulV2OpCost(CompOpCost):
OP_TYPE = "matmul_v2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(MatmulV2OpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(MatmulV2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -787,9 +827,9 @@ class MatmulV2GradOpCost(CompOpCost):
OP_TYPE = "matmul_v2_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(MatmulV2GradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(MatmulV2GradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -806,9 +846,9 @@ class MemcpyOpCost(CompOpCost):
OP_TYPE = "memcpy"
def __init__(self, op=None, op_desc=None, cluster=None):
super(MemcpyOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(MemcpyOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -842,9 +882,9 @@ class MulGradOpCost(CompOpCost):
OP_TYPE = "mul_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(MulGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(MulGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -861,9 +901,9 @@ class OneHotOpCost(CompOpCost):
OP_TYPE = "one_hot"
def __init__(self, op=None, op_desc=None, cluster=None):
super(OneHotOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(OneHotOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -880,9 +920,9 @@ class ReadFromArrayOpCost(CompOpCost):
OP_TYPE = "read_from_array"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ReadFromArrayOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ReadFromArrayOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -899,9 +939,9 @@ class ReduceSumOpCost(CompOpCost):
OP_TYPE = "reduce_sum"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ReduceSumOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ReduceSumOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -918,9 +958,9 @@ class ReduceSumGradOpCost(CompOpCost):
OP_TYPE = "reduce_sum_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ReduceSumGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ReduceSumGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -937,9 +977,9 @@ class Reshape2OpCost(CompOpCost):
OP_TYPE = "reshape2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(Reshape2OpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(Reshape2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -956,9 +996,9 @@ class Reshape2GradOpCost(CompOpCost):
OP_TYPE = "reshape2_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(Reshape2GradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(Reshape2GradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -975,9 +1015,9 @@ class ReduceMeanOpCost(CompOpCost):
OP_TYPE = "reduce_mean"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ReduceMeanOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ReduceMeanOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -994,9 +1034,9 @@ class ReduceMeanGradOpCost(CompOpCost):
OP_TYPE = "reduce_mean_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ReduceMeanGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ReduceMeanGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1013,9 +1053,9 @@ class SamplingIdOpCost(CompOpCost):
OP_TYPE = "sampling_id"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SamplingIdOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SamplingIdOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1032,9 +1072,9 @@ class ScaleOpCost(CompOpCost):
OP_TYPE = "scale"
def __init__(self, op=None, op_desc=None, cluster=None):
super(ScaleOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(ScaleOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1051,9 +1091,9 @@ class SliceOpCost(CompOpCost):
OP_TYPE = "slice"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SliceOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SliceOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1070,9 +1110,9 @@ class SoftmaxOpCost(CompOpCost):
OP_TYPE = "softmax"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SoftmaxOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SoftmaxOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1089,9 +1129,9 @@ class SoftmaxGradOpCost(CompOpCost):
OP_TYPE = "softmax_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SoftmaxGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SoftmaxGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1108,9 +1148,9 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost):
OP_TYPE = "softmax_with_cross_entropy"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SoftmaxWithCrossEntropyOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SoftmaxWithCrossEntropyOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1127,9 +1167,9 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
OP_TYPE = "softmax_with_cross_entropy_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SoftmaxWithCrossEntropyGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SoftmaxWithCrossEntropyGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1146,9 +1186,9 @@ class SplitOpCost(CompOpCost):
OP_TYPE = "split"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SplitOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SplitOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1165,9 +1205,9 @@ class Squeeze2OpCost(CompOpCost):
OP_TYPE = "squeeze2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(Squeeze2OpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(Squeeze2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1184,9 +1224,9 @@ class SquareOpCost(CompOpCost):
OP_TYPE = "square"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SquareOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SquareOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1203,9 +1243,9 @@ class SquareGradOpCost(CompOpCost):
OP_TYPE = "square_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(SquareGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(SquareGradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1239,9 +1279,9 @@ class TopKOpCost(CompOpCost):
OP_TYPE = "top_k"
def __init__(self, op=None, op_desc=None, cluster=None):
super(TopKOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(TopKOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1258,9 +1298,9 @@ class Transpose2OpCost(CompOpCost):
OP_TYPE = "transpose2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(Transpose2OpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(Transpose2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1277,9 +1317,9 @@ class Transpose2GradOpCost(CompOpCost):
OP_TYPE = "transpose2_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(Transpose2GradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(Transpose2GradOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1296,9 +1336,9 @@ class Unsqueeze2OpCost(CompOpCost):
OP_TYPE = "unsqueeze2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(Unsqueeze2OpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(Unsqueeze2OpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......@@ -1315,9 +1355,9 @@ class WriteToArrayOpCost(CompOpCost):
OP_TYPE = "write_to_array"
def __init__(self, op=None, op_desc=None, cluster=None):
super(WriteToArrayOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
super(WriteToArrayOpCost, self).__init__(
op=op, op_desc=op_desc, cluster=cluster
)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
......
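All of the comp.py hunks above are the same mechanical reformatting of the `super().__init__` calls; the class pattern itself is unchanged. For reference, here is a minimal sketch of that pattern with a hypothetical op: `ReluOpCost` is not part of this commit, and the `base_cost` import path is an assumption.

```python
# Sketch only: a new COMP op cost class following the pattern shown above.
# `register_op_cost` and `CompOpCost` are the names used in this file;
# `ReluOpCost` itself is illustrative and not added by this commit.
from paddle.distributed.auto_parallel.cost.base_cost import (
    CompOpCost,
    register_op_cost,
)


@register_op_cost
class ReluOpCost(CompOpCost):
    OP_TYPE = "relu"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(ReluOpCost, self).__init__(
            op=op, op_desc=op_desc, cluster=cluster
        )

    # Placeholder formulas, mirroring the classes above; a concrete COMP op
    # is expected to override both methods with real estimates.
    def calc_flops(self):
        return 0

    def calc_time(self):
        return 0
```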
......@@ -27,12 +27,9 @@ from ..dist_tensor import DistributedTensor
class CostEstimator:
_sepical_op_type = ["fused_attention", "fused_feedforward"]
def __init__(self,
program,
cluster,
mode="modeling",
rank=None,
loop_count=10):
def __init__(
self, program, cluster, mode="modeling", rank=None, loop_count=10
):
self._program = program
self._cluster = cluster
self._check_mode(mode)
......@@ -41,7 +38,8 @@ class CostEstimator:
self._loop_count = loop_count
self._global_cost = Cost()
self._local_cost_mapping = {}
self._detailed_cost = OrderedDict(
self._detailed_cost = (
OrderedDict()
) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
self._bubble_time_mapping = {}
self._ordered_ops = []
......@@ -106,7 +104,8 @@ class CostEstimator:
def _check_mode(self, mode):
if mode not in ["modeling", "profiling"]:
raise ValueError(
"Just support modeling and profiling, but got {}".format(mode))
"Just support modeling and profiling, but got {}".format(mode)
)
def _is_special_var_name(self, var_name):
special_var_name = ["lod_tensor_blocking_queue_0"]
......@@ -116,6 +115,7 @@ class CostEstimator:
def _estimate_core(self, dist_context, resharder, block):
from ..reshard import get_var_with_recursion
ops = block.ops
loop_count = None
if block.desc.id != self.program.global_block().desc.id:
......@@ -132,8 +132,9 @@ class CostEstimator:
if int(op.attr('op_role')) == int(OpRole.Optimize):
continue
if op.type in [
"create_py_reader", "create_double_buffer_reader",
"read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]:
continue
......@@ -172,14 +173,16 @@ class CostEstimator:
max_time = rank_cost.time
for rank in group_ranks:
self.local_cost(
rank).time = max_time + cost.time
self.local_cost(rank).time = (
max_time + cost.time
)
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
max_time - cost_time[rank]
)
for rank in local_comp_cost:
for comp_cost in local_comp_cost[rank]:
......@@ -191,15 +194,19 @@ class CostEstimator:
processes = op_dist_attr.process_mesh.processes
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
op_dist_attr.impl_type
)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, self.cluster)
dist_op_cost = dist_impl.calc_cost(
op.attr('op_role'), dist_op, dist_context, self.cluster
)
detail["dist_op_cost"] = dist_op_cost
if dist_op_cost is None:
assert dist_op.serial_op.type in CostEstimator._sepical_op_type
assert (
dist_op.serial_op.type in CostEstimator._sepical_op_type
)
continue
for item in dist_op_cost:
if isinstance(item, list):
......@@ -217,12 +224,14 @@ class CostEstimator:
if max_time < rank_cost.time:
max_time = rank_cost.time
for rank in group_ranks:
self.local_cost(
rank).time = max_time + comm_op_cost.time
self.local_cost(rank).time = (
max_time + comm_op_cost.time
)
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
max_time - cost_time[rank]
)
elif isinstance(item, dict):
# Op just one
for rank in processes:
......@@ -247,8 +256,11 @@ class CostEstimator:
dtype_factor = 8
elif dtype == paddle.float32 or dtype == paddle.int32:
dtype_factor = 4
elif dtype == paddle.float16 or dtype == paddle.bfloat16 \
or dtype == paddle.int16:
elif (
dtype == paddle.float16
or dtype == paddle.bfloat16
or dtype == paddle.int16
):
dtype_factor = 2
elif dtype == paddle.int8 or dtype == paddle.uint8:
dtype_factor = 1
......@@ -270,8 +282,9 @@ class CostEstimator:
memories = {}
self.max_memories = {}
var_info = {
} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
var_info = (
{}
) # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
for block in self.program.blocks:
for op in block.ops:
......@@ -280,18 +293,22 @@ class CostEstimator:
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]:
continue
dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name)
var_name
)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(process_mesh,
input_dims_mapping)
key = _convert_pm_and_dm_to_str(
process_mesh, input_dims_mapping
)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
# It is even partition now
......@@ -300,21 +317,27 @@ class CostEstimator:
global_sizes = var.shape
dtype = var.dtype
sizes = DistributedTensor.get_local_sizes(
global_sizes, input_dims_mapping, process_mesh.topology,
process_mesh.processes)
global_sizes,
input_dims_mapping,
process_mesh.topology,
process_mesh.processes,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype)
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name)
var_name
)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(process_mesh,
output_dims_mapping)
key = _convert_pm_and_dm_to_str(
process_mesh, output_dims_mapping
)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
if "memory" not in var_info[var_name][key]:
......@@ -322,10 +345,14 @@ class CostEstimator:
global_sizes = var.shape
dtype = var.dtype
sizes = DistributedTensor.get_local_sizes(
global_sizes, output_dims_mapping,
process_mesh.topology, process_mesh.processes)
global_sizes,
output_dims_mapping,
process_mesh.topology,
process_mesh.processes,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype)
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
......@@ -333,7 +360,9 @@ class CostEstimator:
has_used_vars = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]:
continue
can_free_memories = {}
......@@ -342,9 +371,11 @@ class CostEstimator:
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name)
key = _convert_pm_and_dm_to_str(process_mesh,
input_dims_mapping)
var_name
)
key = _convert_pm_and_dm_to_str(
process_mesh, input_dims_mapping
)
has_used_var = var_name + key
var = dist_op.get_serial_input(var_name)
# Not used
......@@ -364,13 +395,16 @@ class CostEstimator:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name][key]["memory"]
var_name
][key]["memory"]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name)
key = _convert_pm_and_dm_to_str(process_mesh,
output_dims_mapping)
var_name
)
key = _convert_pm_and_dm_to_str(
process_mesh, output_dims_mapping
)
has_used_var = var_name + key
var = dist_op.get_serial_output(var_name)
# Not used
......@@ -390,7 +424,8 @@ class CostEstimator:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name][key]["memory"]
var_name
][key]["memory"]
# Calc peak memory
for process in memories:
......@@ -414,8 +449,12 @@ class CostEstimator:
def estimate(self, dist_context, resharder=None):
self.prepare()
from ..reshard import Resharder
resharder = Resharder(self.program, None, self.rank, dist_context,
[]) if resharder is None else resharder
resharder = (
Resharder(self.program, None, self.rank, dist_context, [])
if resharder is None
else resharder
)
block = self.program.global_block()
self._estimate_core(dist_context, resharder, block)
......@@ -447,7 +486,7 @@ class CostEstimator:
memories = [
int(item // 1e6) for item in list(self.max_memories.values())
]
for memory in (memories + header):
for memory in memories + header:
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
......@@ -477,7 +516,7 @@ class CostEstimator:
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in (vals + header):
for memory in vals + header:
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
......@@ -507,50 +546,73 @@ class CostEstimator:
def get_cost_from_engine(engine, mode):
from ..utils import to_list
# Construct cost estimator by original main program
serial_main_prog = engine._serial_main_progs[mode].clone(
) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
import copy
serial_startup_prog = engine._serial_startup_progs[mode].clone(
) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone(
# Construct cost estimator by original main program
serial_main_prog = (
engine._fwd_main_progs[mode].clone()
if mode in engine._fwd_main_progs
else engine._orig_main_prog.clone()
)
losses = to_list(
engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)) else engine._losses
if mode in engine._dist_contexts:
dist_context = engine._dist_contexts[mode]
completer = engine._planners[mode].completer
serial_startup_prog = (
engine._serial_startup_progs[mode].clone()
if mode in engine._serial_startup_progs
else engine._orig_startup_prog.clone()
)
losses = (
to_list(engine._loss)
if (
not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)
)
else engine._losses
)
serial_optimizer = copy.deepcopy(engine._orig_optimizer)
if mode in engine._fwd_dist_contexts:
dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
else:
from ..completion import Completer
from ..dist_context import DistributedContext
dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
engine._optimizer, losses, {},
{"loss": losses}, engine._cluster,
engine._strategy)
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program)
dist_context = DistributedContext(
serial_main_prog,
serial_startup_prog,
serial_optimizer,
losses,
{},
{"loss": losses},
engine._cluster,
engine._strategy,
)
from ..completion import Completer
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program
)
if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train":
from ..parallelizer_v2 import Parallelizer
# Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward
loss_name = dist_context.serial_loss.name
serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
params_grads = parallelizer._generate_backward(serial_main_prog,
serial_startup_prog,
serial_loss)
params_grads = parallelizer._generate_backward(
serial_main_prog, serial_startup_prog, serial_loss
)
# Generate optimizer
optimizer_ops = parallelizer._generate_optimizer(
serial_main_prog, serial_startup_prog, serial_optimizer,
params_grads)
serial_main_prog,
serial_startup_prog,
serial_optimizer,
params_grads,
)
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
# Estimate global_cost and max memory
......
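The `CostEstimator` diff above mostly reflows long calls, but it also shows the surface that `get_cost_from_engine` relies on. A minimal sketch of using the estimator directly, assuming a built serial program, a `Cluster`, and a completed `DistributedContext` already exist; the attribute names follow the printing helpers above.

```python
# Sketch only: `serial_main_prog`, `cluster` and `dist_context` are assumed
# to exist; the import path mirrors the file patched in this diff.
from paddle.distributed.auto_parallel.cost.estimate_cost import CostEstimator

cost_estimator = CostEstimator(serial_main_prog, cluster)
cost_estimator.estimate(dist_context)  # builds a Resharder internally when none is passed

# Results are kept on the attributes that _print_cost reads above:
exec_time_ms = round(cost_estimator.global_cost.time, 3)
max_memory_mib = int(cost_estimator.max_memory // 1e6)
```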
......@@ -13,8 +13,10 @@
# limitations under the License.
import os
import copy
import logging
import random
import numbers
import numpy as np
from collections import defaultdict
......@@ -41,16 +43,20 @@ from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import DistributedDataLoaderFromGenerator, DistributedDataLoader
from .utils import to_list, get_dist_attr, get_lr
from .dist_loader import (
DistributedDataLoaderFromGenerator,
DistributedDataLoader,
)
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import CollectionNames, get_collection
from ..utils.log_utils import get_logger
from .utils import initialize_pg_in_full_mode
from .utils import to_list, get_dist_attr, get_lr, validate_opt
from .utils import initialize_pg_in_full_mode, get_input_split_info
from .cost.estimate_cost import get_cost_from_engine
from ..utils.log_utils import get_logger
class Engine:
"""
......@@ -115,35 +121,55 @@ class Engine:
"""
def __init__(self,
model=None,
loss=None,
optimizer=None,
metrics=None,
cluster=None,
strategy=None):
if model and not isinstance(model,
paddle.nn.Layer) and not callable(model):
def __init__(
self,
model=None,
loss=None,
optimizer=None,
metrics=None,
cluster=None,
strategy=None,
):
if (
model
and not isinstance(model, paddle.nn.Layer)
and not callable(model)
):
raise TypeError(
"'model must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._model = model
if (
loss
and not isinstance(loss, (paddle.nn.Layer, Variable))
and not callable(loss)
):
raise TypeError(
"'loss' must be sub classes of `paddle.nn.Layer` or any callable function or a Variable."
)
self._loss = loss
if optimizer and not isinstance(
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
):
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`.")
self._optimizer = self._validate_opt(optimizer)
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = validate_opt(optimizer)
self._orig_optimizer = copy.deepcopy(self._optimizer)
metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(
metric.__class__.__name__)
if metric and not isinstance(metric, Metric):
raise TypeError(
"{} is not sub class of Metric".format(
metric.__class__.__name__
)
)
self._metrics = to_list(metrics)
if cluster and not isinstance(cluster, Cluster):
......@@ -158,9 +184,11 @@ class Engine:
)
self._strategy = strategy or Strategy()
self._logger = get_logger(logging.INFO)
if os.getenv("POD_NAME"):
print("Distribute training by paddle.distributed.launch",
flush=True)
self._logger.info(
"Distribute training by paddle.distributed.launch"
)
fleet.init(is_collective=True)
self._executor = None
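The `__init__` hunks above tighten the type checks on `model`, `loss`, `optimizer` and `metrics`. A hypothetical construction call that satisfies those checks is sketched here; the model, the `auto.Engine` alias and `auto.Strategy` are assumptions, not taken from this diff.

```python
# Sketch only: `MLP` is a placeholder model; `auto.Engine` / `auto.Strategy`
# are assumed to be the public aliases of the classes patched above.
import paddle
from paddle.distributed.fleet import auto

model = MLP()                                       # any paddle.nn.Layer or callable
loss = paddle.nn.CrossEntropyLoss()                 # Layer, Variable or callable
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
metric = paddle.metric.Accuracy()                   # must subclass paddle.metric.Metric
engine = auto.Engine(model, loss, optimizer, metric, strategy=auto.Strategy())
```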
......@@ -168,12 +196,12 @@ class Engine:
self._nranks = paddle.distributed.get_world_size()
self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO)
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._fwd_main_progs = {}
self._fwd_dist_contexts = {}
self._serial_main_progs = {}
self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs
......@@ -185,19 +213,20 @@ class Engine:
self._has_prepared_reader = {
"train": False,
"eval": False,
"predict": False
"predict": False,
}
self._inputs_spec = []
self._labels_spec = []
self._inputs = []
self._labels = []
self._losses = []
self._mode = None
self._skip_build = False
self._outside_dataloader = False
self._planned_mode = None
self._dygraph_mode = False
self._tuning = self._strategy.tuning
self._losses = None
self.history = None
......@@ -219,9 +248,11 @@ class Engine:
inputs = sample[:split]
labels = sample[split:]
else:
raise ValueError(
"Data should be a Dataset or IterableDatset, but received {}.".
format(type(data).__name__))
raise TypeError(
"Data should be a Dataset or IterableDatset, but received {}.".format(
type(data).__name__
)
)
inputs = to_list(inputs)
labels = to_list(labels)
......@@ -240,14 +271,20 @@ class Engine:
else:
specs.append(spec.batch(batch_size))
elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
_adjust_item_spec(num_shards, spec)
spec = InputSpec.from_tensor(item, name)
_adjust_item_spec(num_shards, spec)
if batch_size is None:
specs.append(spec)
else:
specs.append(spec.batch(batch_size))
else:
elif isinstance(item, numbers.Number):
specs.append(InputSpec([batch_size], type(item), name))
else:
raise TypeError(
"The sample's dtype returned of dataset should be number, np.ndarray or Tensor, but got {}".format(
type(item).__name__
)
)
if inputs is not None:
for i, item in enumerate(inputs):
......@@ -264,37 +301,41 @@ class Engine:
labels_spec = self._validate_spec(labels_spec)
return inputs_spec, labels_spec
def _prepare_data_tensor(self,
inputs_spec,
labels_spec,
inputs=None,
labels=None):
def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
if _non_static_mode() or self._dygraph_mode:
return None, None
inputs_spec = inputs_spec if inputs_spec else []
labels_spec = labels_spec if labels_spec else []
raise ValueError("Only support static graph mode.")
if inputs_spec:
assert isinstance(inputs_spec, list), \
"inputs should be list, but received {}".format(type(inputs_spec))
if inputs is None:
inputs = [s._create_feed_layer() for s in inputs_spec]
else:
assert isinstance(inputs, list), \
"inputs should be list, but received {}".format(type(inputs))
for input_spec, input in zip(inputs_spec, inputs):
if input_spec.shape != input.shape:
input.desc.set_shape(input_spec.shape)
assert isinstance(
inputs_spec, list
), "inputs should be list, but received {}".format(
type(inputs_spec)
)
assert isinstance(
inputs, list
), "inputs should be list, but received {}".format(type(inputs))
assert len(inputs_spec) == len(
inputs
), "the number of `inputs_spec` should be equal to `inputs`'s."
for input_spec, input in zip(inputs_spec, inputs):
if input_spec.shape != input.shape:
input.desc.set_shape(input_spec.shape)
if labels_spec:
assert isinstance(labels_spec, list), \
"labels should be list, but received {}".format(type(labels_spec))
if labels is None:
labels = [s._create_feed_layer() for s in labels_spec]
else:
assert isinstance(labels, list), \
"labels should be list, but received {}".format(type(labels))
for label_spec, label in zip(labels_spec, labels):
if label_spec.shape != label.shape:
label.desc.set_shape(label_spec.shape)
assert isinstance(
labels_spec, list
), "labels should be list, but received {}".format(
type(labels_spec)
)
assert isinstance(
labels, list
), "labels should be list, but received {}".format(type(labels))
assert len(labels_spec) == len(
labels
), "the number of `labels_spec` should be equal to `labels`'s."
for label_spec, label in zip(labels_spec, labels):
if label_spec.shape != label.shape:
label.desc.set_shape(label_spec.shape)
return inputs, labels
def _prepare_reader(self):
......@@ -304,7 +345,9 @@ class Engine:
# NOTE: this list may be changed if Paddle changes the existing rules.
related_reader_ops = [
"create_py_reader", "create_double_buffer_reader", "read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]
# remove the first three ops if multiple run fit/evaluate/predict
if dist_main_block.ops[0].type == 'create_py_reader':
......@@ -322,9 +365,9 @@ class Engine:
for idx in reversed(reader_op_indices):
new_op_desc = dist_main_block.desc._prepend_op()
new_op_desc.copy_from(dist_main_block.ops[idx].desc)
new_op = Operator(dist_main_block,
new_op_desc,
type=new_op_desc.type())
new_op = Operator(
dist_main_block, new_op_desc, type=new_op_desc.type()
)
new_reader_ops.append(new_op)
dist_op = DistributedOperator(new_op)
dist_context.add_dist_op_for_program(dist_op)
......@@ -355,16 +398,22 @@ class Engine:
else:
raise ValueError("Unsupported data {}".format(data))
if user_feeds is not None:
assert isinstance(user_feeds, dict), \
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
assert isinstance(
user_feeds, dict
), "user_feeds must be a dict, but receive {}".format(
type(user_feeds).__name__
)
for name, data in user_feeds.items():
feeds[name] = data
return feeds
def _prepare_fetch(self, user_fetches, mode):
if user_fetches is not None:
assert isinstance(user_fetches, list), \
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
assert isinstance(
user_fetches, list
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
fetch_names = []
fetch_indices = []
......@@ -396,14 +445,16 @@ class Engine:
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
def _prepare_logger(self,
outs,
epoch=None,
step=None,
lr=None,
fetch_names=None,
fetch_indices=None,
mode=None):
def _prepare_logger(
self,
outs,
epoch=None,
step=None,
lr=None,
fetch_names=None,
fetch_indices=None,
mode=None,
):
logs = {}
if epoch is not None:
logs["epoch"] = epoch
......@@ -468,11 +519,13 @@ class Engine:
self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.")
inputs_spec = self._inputs_spec
labels_spec = self._labels_spec if self._labels_spec else []
self.program_helper = ProgramHelper(self._model, self._loss,
self._metrics, inputs_spec,
labels_spec)
self.program_helper = ProgramHelper(
self._model,
self._loss,
self._metrics,
self._inputs_spec,
self._labels_spec,
)
# build forward main program
self.program_helper.build_program(mode)
......@@ -480,16 +533,12 @@ class Engine:
serial_main_prog = self.program_helper.main_program
serial_startup_prog = self.program_helper.startup_program
inputs = self.program_helper.input_vars
self._inputs = self.program_helper.input_vars
self._labels = self.program_helper.label_vars
outputs = self.program_helper.output_vars
labels = self.program_helper.label_vars
losses = self.program_helper.loss_vars
self._losses = losses
self._losses = self.program_helper.loss_vars
metrics = self.program_helper.metric_vars
self._inputs = inputs
self._labels = labels
paddle.enable_static()
else:
# build program in static mode
......@@ -498,27 +547,45 @@ class Engine:
return
outputs = []
losses = []
metrics = []
inputs = self._inputs if self._inputs else []
labels = self._labels if self._labels else []
self._losses = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
if not self._skip_build:
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
outputs = to_list(self._model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
self._losses = losses
with static.program_guard(
serial_main_prog, serial_startup_prog
), utils.unique_name.guard():
self._inputs = [
s._create_feed_layer() for s in self._inputs_spec
]
self._labels = [
s._create_feed_layer() for s in self._labels_spec
]
outputs = to_list(self._model(*self._inputs))
if mode != "predict" and (outputs or labels):
if mode != "predict" and self._loss:
assert isinstance(
self._loss, paddle.nn.Layer
) or callable(
self._loss
), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
self._losses = to_list(
self._loss(*(outputs + self._labels))
)
if mode != "predict" and (outputs or self._labels):
for metric in self._metrics:
metrics.append(
to_list(metric.compute(*(outputs + labels))))
to_list(
metric.compute(*(outputs + self._labels))
)
)
else:
losses = to_list(self._loss)
self.losses = losses
assert isinstance(
self._loss, Variable
), "the type of `loss` of the Engine arguments should be Variable."
self._losses = to_list(self._loss)
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation:
......@@ -527,12 +594,12 @@ class Engine:
new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True
feed_vars = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": self._inputs, "labels": self._labels}
fetch_vars = {
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
"loss": self._losses,
"metrics": metrics,
}
if mode != "train":
......@@ -540,9 +607,27 @@ class Engine:
self._set_recompute_ckpts()
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self._cluster, self._strategy)
serial_main_prog,
serial_startup_prog,
self._optimizer,
self._losses,
feed_vars,
fetch_vars,
self._cluster,
self._strategy,
)
self._fwd_dist_contexts[mode] = DistributedContext(
serial_main_prog,
serial_startup_prog,
self._optimizer,
self._losses,
feed_vars,
fetch_vars,
self._cluster,
self._strategy,
)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._fwd_main_progs[mode] = serial_main_prog.clone()
def _optimization_tuning(self, mode, dataset, batch_size):
if not self._tuning.enable:
......@@ -558,20 +643,24 @@ class Engine:
dataset.dp_rank = self._dp_ranks
from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(),
self._dist_contexts[mode],
dataset,
self._inputs_spec,
self._labels_spec,
batch_size=batch_size,
rank=self._cur_rank)
self._optimization_tuner = OptimizationTuner(
self._tuning.to_dict(),
self._dist_contexts[mode],
dataset,
self._inputs_spec,
self._labels_spec,
batch_size=batch_size,
rank=self._cur_rank,
)
self._optimization_tuner.tune()
if self._tuning.run_after_tuning:
# update the strategy
self._dist_contexts[
mode]._strategy = self._optimization_tuner.get_best_config()
mode
]._strategy = self._optimization_tuner.get_best_config()
def _plan(self, mode):
if self._planned_mode is None:
......@@ -595,8 +684,9 @@ class Engine:
self._dp_world_sizes = []
self._dp_ranks = []
for feed_var in feed_list:
dp_world_size, dp_rank = self._get_input_split_info(
feed_var, self._dist_contexts[mode])
dp_world_size, dp_rank = get_input_split_info(
self._cur_rank, feed_var, self._dist_contexts[mode]
)
self._dp_world_sizes.append(dp_world_size)
self._dp_ranks.append(dp_rank)
......@@ -604,8 +694,9 @@ class Engine:
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
# because we may use it to complete the annotation of the backwarkward and update.
parallelizer = Parallelizer(mode, self._planners[mode].completer,
self._dist_contexts[mode])
parallelizer = Parallelizer(
mode, self._planners[mode].completer, self._dist_contexts[mode]
)
if not all_ranks:
parallelizer.parallel(self._cur_rank)
else:
......@@ -623,22 +714,30 @@ class Engine:
for ib, block in enumerate(origin_main_prog.blocks):
for iop, op in enumerate(block.ops):
ref_op = ref_blocks[ib].ops[iop]
assert op.type == ref_op.type, \
"'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type)
ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
ref_op)
assert (
op.type == ref_op.type
), "'{}' mode op '{}' is different with '{}' op '{}'. ".format(
mode, op.type, ref_mode, ref_op.type
)
ref_op_dist_attr = (
ref_dist_context.get_op_dist_attr_for_program(ref_op)
)
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode):
# Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program
mode
].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode].serial_startup_program
mode
].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[
mode].dist_main_programs
mode
].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[
mode].dist_startup_programs
mode
].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode]._serial_optimizer
......@@ -684,30 +783,33 @@ class Engine:
self._executor.run(prune_startup_prog)
if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
self._set_state_dict(mode, self._strict, self._state_dict,
self._dist_attr)
self._set_state_dict(
mode, self._strict, self._state_dict, self._dist_attr
)
if self._strategy.reinit:
self._logger.info("NOTE: parameters will be re-initialized.")
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
self._executor.run(dist_startup_prog)
def fit(self,
train_data,
train_sample_split=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
log_freq=10,
save_dir=None,
save_freq=1,
valid_data=None,
valid_sample_split=None,
valid_freq=1,
valid_steps=None,
collate_fn=None,
callbacks=None,
verbose=2):
def fit(
self,
train_data,
train_sample_split=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
log_freq=10,
save_dir=None,
save_freq=1,
valid_data=None,
valid_sample_split=None,
valid_freq=1,
valid_steps=None,
collate_fn=None,
callbacks=None,
verbose=2,
):
"""
Trains the model for a fixed number of epochs. If `valid_data` is set,
evaluation will be done at the end of each epoch.
......@@ -776,17 +878,13 @@ class Engine:
"""
self._mode = 'train'
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
train_data, train_sample_split, batch_size)
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
train_data, train_sample_split, batch_size
)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
assert self._mode in self._dist_main_progs, \
"train model is not ready, please call `engine._prepare_program('train')` first."
train_dataloader = self._prepare_dataloader_from_generator(
dataset=train_data,
capacity=70,
......@@ -794,7 +892,8 @@ class Engine:
batch_size=batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
collate_fn=collate_fn)
collate_fn=collate_fn,
)
fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
......@@ -823,21 +922,35 @@ class Engine:
self.main_program,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
return_numpy=self._strategy.return_numpy,
)
except core.EOFException:
break
lr = get_lr(self._optimizer)
logs = self._prepare_logger(outs, epoch, step, lr, fetch_names,
fetch_indices, self._mode)
logs = self._prepare_logger(
outs,
epoch,
step,
lr,
fetch_names,
fetch_indices,
self._mode,
)
cbks.on_batch_end('train', step, logs)
if valid_data and (epoch + 1) % valid_freq == 0:
val_logs = self.evaluate(valid_data, valid_sample_split,
batch_size, valid_steps, log_freq,
collate_fn, callbacks, verbose)
val_logs = self.evaluate(
valid_data,
valid_sample_split,
batch_size,
valid_steps,
log_freq,
collate_fn,
callbacks,
verbose,
)
val_logs = {
"val_" + name: val
for name, val in val_logs.items()
"val_" + name: val for name, val in val_logs.items()
}
logs.update(val_logs)
self._switch_mode("train")
......@@ -849,15 +962,17 @@ class Engine:
cbks.on_end('train', logs)
return self.history
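For reviewers unfamiliar with the high-level API exercised by these hunks, a minimal usage sketch of `Engine.fit` follows; it assumes the public `paddle.distributed.fleet.auto.Engine` entry point, and the model and dataset below are illustrative, not taken from this change.

import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.io import Dataset

class RandomDataset(Dataset):            # illustrative dataset
    def __getitem__(self, idx):
        x = np.random.rand(128).astype("float32")
        y = np.array([np.random.randint(0, 10)], dtype="int64")
        return x, y
    def __len__(self):
        return 256

model = paddle.nn.Linear(128, 10)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())

engine = auto.Engine(model, loss, optimizer)   # assumed constructor signature
history = engine.fit(RandomDataset(), batch_size=64, epochs=2, log_freq=10)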
def evaluate(self,
valid_data,
valid_sample_split=None,
batch_size=1,
steps=None,
log_freq=10,
collate_fn=None,
callbacks=None,
verbose=2):
def evaluate(
self,
valid_data,
valid_sample_split=None,
batch_size=1,
steps=None,
log_freq=10,
collate_fn=None,
callbacks=None,
verbose=2,
):
"""
Evaluate the loss and metrics of the model on evaluation data.
......@@ -906,23 +1021,21 @@ class Engine:
"""
self._mode = 'eval'
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
valid_data, valid_sample_split, batch_size)
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
valid_data, valid_sample_split, batch_size
)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
assert self._mode in self._dist_main_progs, \
"eval model is not ready, please call `engine._prepare_program('eval')` first."
valid_dataloader = self._prepare_dataloader_from_generator(
dataset=valid_data,
capacity=70,
iterable=False,
batch_size=batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
collate_fn=collate_fn,
)
fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
......@@ -936,10 +1049,9 @@ class Engine:
)
eval_steps = valid_dataloader._steps
cbks.on_begin('eval', {
'steps': eval_steps,
'metrics': self._metrics_name()
})
cbks.on_begin(
'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
)
logs = {}
for step, _ in enumerate(valid_dataloader):
cbks.on_batch_begin('eval', step, logs)
......@@ -948,24 +1060,28 @@ class Engine:
self.main_program,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
return_numpy=self._strategy.return_numpy,
)
except core.EOFException:
break
logs = self._prepare_logger(outs, None, step, None, fetch_names,
fetch_indices, self._mode)
logs = self._prepare_logger(
outs, None, step, None, fetch_names, fetch_indices, self._mode
)
cbks.on_batch_end('eval', step, logs)
cbks.on_end('eval', logs)
self._reset_metrics()
return logs
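Assuming the illustrative `engine` and dataset from the `fit` sketch above, evaluation is a single call; the returned `logs` dict carries the loss and any configured metrics, as assembled by `_prepare_logger` above.

eval_logs = engine.evaluate(RandomDataset(), batch_size=64)
print(eval_logs)   # e.g. {'loss': ...} plus one entry per configured metric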
def predict(self,
test_data,
test_sample_split=None,
batch_size=1,
steps=None,
collate_fn=None,
callbacks=None,
verbose=2):
def predict(
self,
test_data,
test_sample_split=None,
batch_size=1,
steps=None,
collate_fn=None,
callbacks=None,
verbose=2,
):
"""
Compute the output predictions on testing data.
......@@ -1011,24 +1127,21 @@ class Engine:
"""
self._mode = 'predict'
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
test_data, test_sample_split, batch_size)
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
test_data, test_sample_split, batch_size
)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
assert self._mode in self._dist_main_progs, \
"predict model is not ready, please call `engine._prepare_program('predict')` first."
test_dataloader = self._prepare_dataloader_from_generator(
dataset=test_data,
capacity=70,
iterable=False,
batch_size=batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
collate_fn=collate_fn,
)
fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
......@@ -1044,41 +1157,45 @@ class Engine:
self.main_program,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
return_numpy=self._strategy.return_numpy,
)
except core.EOFException:
break
logs = self._prepare_logger(outs, None, step, None, fetch_names,
fetch_indices, self._mode)
logs = self._prepare_logger(
outs, None, step, None, fetch_names, fetch_indices, self._mode
)
cbks.on_batch_end('predict', step, logs)
outputs.append(list(logs["outputs"].values()))
cbks.on_end('predict', logs)
return outputs
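Again reusing the illustrative objects from the `fit` sketch, prediction collects the per-step `logs["outputs"]` values into a list, as the loop above shows.

outputs = engine.predict(RandomDataset(), batch_size=64)
# `outputs` has one entry per step; each entry lists that step's output tensors.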
def dataloader(self,
dataset,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
sample_split=1,
mode=None):
def dataloader(
self,
dataset,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
sample_split=1,
mode=None,
):
if mode is not None:
self.to_mode(mode)
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size)
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
dataset, sample_split, batch_size
)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
dataloader = self._prepare_dataloader(
dataset,
return_list=False,
......@@ -1092,32 +1209,35 @@ class Engine:
timeout=timeout,
worker_init_fn=worker_init_fn,
epochs=epochs,
steps_per_epoch=steps_per_epoch)
steps_per_epoch=steps_per_epoch,
)
return dataloader
def dataloader_from_generator(self,
dataset,
capacity=70,
use_double_buffer=True,
iterable=True,
use_multiprocess=False,
drop_last=True,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
sample_split=1,
mode=None):
def dataloader_from_generator(
self,
dataset,
capacity=70,
use_double_buffer=True,
iterable=True,
use_multiprocess=False,
drop_last=True,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
sample_split=1,
mode=None,
):
if mode is not None:
self.to_mode(mode)
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size)
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
dataset, sample_split, batch_size
)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
dataloader = self._prepare_dataloader_from_generator(
dataset=dataset,
capacity=capacity,
......@@ -1129,95 +1249,114 @@ class Engine:
batch_size=batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
collate_fn=collate_fn)
collate_fn=collate_fn,
)
return dataloader
def prepare(self,
inputs_spec=None,
labels_spec=None,
inputs=None,
labels=None,
main_program=None,
startup_program=None,
mode=None):
def prepare(
self,
inputs_spec=None,
labels_spec=None,
inputs=None,
labels=None,
main_program=None,
startup_program=None,
mode=None,
):
if mode is not None:
self.to_mode(mode)
if not self._mode:
raise ValueError(
"Please set mode to be prepared with `prepare(mode=...)`"
)
if self._has_prepared[self._mode]:
return
inputs_spec = self._validate_spec(inputs_spec)
labels_spec = self._validate_spec(labels_spec)
inputs = self._validate_vars(inputs)
labels = self._validate_vars(labels)
self._orig_main_prog = main_program
self._orig_startup_prog = startup_program
if inputs or labels:
self._skip_build = True
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec, inputs, labels)
self._orig_main_prog = main_program
inputs, labels = self._prepare_data_tensor(
inputs_spec, labels_spec, inputs, labels
)
if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = startup_program
if self._orig_startup_prog is None:
self._orig_startup_prog = static.default_startup_program()
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
elif inputs_spec or labels_spec:
self._inputs_spec = inputs_spec
self._labels_spec = labels_spec
self._outside_dataloader = True
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
self._orig_main_prog = main_program
if self._orig_main_prog is None:
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = startup_program
if self._orig_startup_prog is None:
self._orig_startup_prog = static.default_startup_program()
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
else:
assert self._inputs_spec and self._labels_spec, \
"Please call the dataloader(...) before calling prepare(...)"
assert (
self._inputs_spec and self._labels_spec
), "Please call the dataloader(...) before calling prepare(...)"
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = inputs, labels
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
self._switch_mode(self._mode)
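A hedged sketch of the spec-only branch handled above (`inputs_spec`/`labels_spec` given, no concrete tensors); the spec names are illustrative and `engine` is the object from the `fit` sketch:

from paddle.static import InputSpec

engine.prepare(
    inputs_spec=[InputSpec(shape=[None, 128], dtype="float32", name="x")],
    labels_spec=[InputSpec(shape=[None, 1], dtype="int64", name="label")],
    mode="train",
)
# This is the `_outside_dataloader` case: data is later passed through engine.run(feed=...).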
def run(self, data=None, feed=None, fetch_list=None, mode=None):
if mode is not None:
self.to_mode(mode)
feed_dict = self._prepare_feed(data, feed, self._mode)
fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)
if self._outside_dataloader and not self._has_prepared_reader[
self._mode]:
if (
self._outside_dataloader
and not self._has_prepared_reader[self._mode]
):
self._prepare_reader()
outs = self._executor.run(self.main_program,
feed=feed_dict,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
logs = self._prepare_logger(outs, None, None, None, fetch_names,
fetch_indices, self._mode)
outs = self._executor.run(
self.main_program,
feed=feed_dict,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy,
)
logs = self._prepare_logger(
outs, None, None, None, fetch_names, fetch_indices, self._mode
)
return logs
def _prepare_dataloader(self,
dataset,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None):
def _prepare_dataloader(
self,
dataset,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
......@@ -1256,31 +1395,36 @@ class Engine:
steps_per_epoch=steps_per_epoch,
split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks)
data_parallel_rank=self._dp_ranks,
)
return dataloader
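The gradient-merge branch at the top of this helper (and of the generator-based variant below) turns the user-facing batch size into a per-micro-step size; a quick arithmetic sketch with illustrative values:

batch_size, k_steps = 64, 4
assert batch_size % k_steps == 0          # mirrors the check above
micro_batch_size = batch_size // k_steps  # 16 samples fed per step, 4 steps accumulated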
def _prepare_dataloader_from_generator(self,
dataset,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None):
def _prepare_dataloader_from_generator(
self,
dataset,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
......@@ -1316,16 +1460,16 @@ class Engine:
collate_fn=collate_fn,
split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks)
data_parallel_rank=self._dp_ranks,
)
self._prepare_reader()
return dataloader
def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self._mode = 'train'
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
tune_data, tune_sample_split, batch_size)
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
tune_data, tune_sample_split, batch_size
)
self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_spec(self, specs):
......@@ -1333,46 +1477,39 @@ class Engine:
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None:
for i, spec in enumerate(specs):
assert isinstance(spec, InputSpec)
if not isinstance(spec, InputSpec):
raise TypeError(
"'spec' must be object of class `paddle.static.InputSpec`."
)
if spec.name is None:
raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}."
.format(i, spec))
"Requires Input[{}].name != None, but receive `None` with {}.".format(
i, spec
)
)
if self._k_steps > 1:
shape = list(spec.shape)
assert shape[0] % self._k_steps == 0, \
"Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps)
assert (
shape[0] % self._k_steps == 0
), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
spec.shape[0], self._k_steps
)
shape[0] //= self._k_steps
spec.shape = shape
return specs
return specs or []
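The same k_steps rule is applied to static specs above; a worked example of the shape rewrite (shape values illustrative):

from paddle.static import InputSpec

spec = InputSpec(shape=[64, 512], dtype="int64", name="tokens")
k_steps = 4
shape = list(spec.shape)
assert shape[0] % k_steps == 0
shape[0] //= k_steps
spec.shape = shape    # -> [16, 512]: the batch axis now holds one micro-step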
def _validate_vars(self, vars):
vars = to_list(vars)
if vars is not None:
for i, var in enumerate(vars):
if not isinstance(var, Variable):
raise TypeError("'var' must be a `Variable`.")
return vars or []
def _is_local_var(self, var):
var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars
def _get_input_split_info(self, var, dist_context):
# deduce how the input data is split among the cluster
from .utils import _get_comm_group, _get_corresponding_rank
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if self._cur_rank not in process_mesh.processes:
rank_id = _get_corresponding_rank(dist_context, process_mesh,
self._cur_rank)
else:
rank_id = self._cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
return len(group_ranks), group_ranks.index(rank_id)
return 1, 0
def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
......@@ -1381,10 +1518,12 @@ class Engine:
# extract ckpts by specific model
if isinstance(self._model, paddle.nn.Layer):
if hasattr(self._model,
"gpt") and self._model.__class__.__name__ in [
'GPTForPretraining', 'GPTForPretrainingAuto'
]:
if hasattr(
self._model, "gpt"
) and self._model.__class__.__name__ in [
'GPTForPretraining',
'GPTForPretrainingAuto',
]:
exact_ckpts = self._model.gpt.checkpoints
else:
exact_ckpts = recompute.checkpoints
......@@ -1396,16 +1535,10 @@ class Engine:
recompute.checkpoints = exact_ckpts[:]
logs = {
'Model Class': self._model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts
'Applied Recompute ckpts': exact_ckpts,
}
self._logger.info(logs)
def _validate_opt(self, optimizer):
if optimizer is not None:
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
def _reset_metrics(self):
for metric in self._metrics:
metric.reset()
......@@ -1417,12 +1550,18 @@ class Engine:
return metrics_name
def _switch_mode(self, mode):
assert (
mode in self._dist_main_progs
), "{} model is not ready, please call `prepare()` first.".format(mode)
self.to_mode(mode)
self._optimizer = self._dist_contexts[mode]._serial_optimizer
def to_mode(self, mode):
assert mode in ["train", "eval", "predict"], \
"mode {} should be one of ['train', 'eval', 'predict']".format(mode)
assert mode in [
"train",
"eval",
"predict",
], "mode {} should be one of ['train', 'eval', 'predict']".format(mode)
self._mode = mode
def _set_state_dict(self, mode, strict, state_dict, dist_attr):
......@@ -1483,20 +1622,24 @@ class Engine:
serial_program = self._serial_main_progs[self._mode]
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
self._saver.save(path,
serial_program=serial_program,
dist_main_program=dist_main_prog,
dist_context=dist_context)
self._saver.save(
path,
serial_program=serial_program,
dist_main_program=dist_main_prog,
dist_context=dist_context,
)
else:
assert "predict" in self._dist_main_progs
feed_vars = self._feed_vars["predict"]['inputs']
fetch_vars = self._fetch_vars["predict"]['outputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
self._saver.save_inference_model(path,
feed_vars,
fetch_vars,
self._executor,
program=dist_main_prog)
self._saver.save_inference_model(
path,
feed_vars,
fetch_vars,
self._executor,
program=dist_main_prog,
)
def load(self, path, strict=True, load_optimizer=True):
"""
......@@ -1508,10 +1651,10 @@ class Engine:
strict (bool, optional): Whether to skip the loading of mismatch
parameter or raise an error when mismatch happens (not found
the parameter in file storing model states of or receives a
mismatch shape). Default: False.
mismatch shape). Default: True.
load_optimizer (bool, optional): If True, the stored optimizer
states is restored. Otherwise, the optimizer states is initialized
from scratch. Default: False.
from scratch. Default: True.
Returns:
None
......@@ -1546,10 +1689,11 @@ class Engine:
"""
self._strict = strict
self._state_dict, self._dist_attr = self._saver.load(
path, load_optimizer)
path, load_optimizer
)
return self._state_dict, self._dist_attr
def cost(self, inputs_spec=None, labels_spec=None, mode="train"):
def cost(self, inputs_spec=None, labels_spec=None, mode=None):
"""
Get and Print cost, including memory of every rank,
max memory among all ranks, and the global cost of one step based on
......@@ -1560,7 +1704,7 @@ class Engine:
Args:
inputs_spec(InputSpec): The specification of inputs. Default: None.
labels_spec(InputSpec): The specification of labels. Default: None.
mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: "train".
mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.
Returns:
Return the global execution time (ms) and max memory (B).
......@@ -1568,29 +1712,44 @@ class Engine:
"""
# Check parallel mode
if self._strategy.auto_mode == "full":
print(
self._logger.info(
"The cost will be calcudated in the search process when the auto mode is full."
)
return
# Check mode
accepted_modes = ["train", "predict", "eval"]
if mode not in accepted_modes:
raise ValueError("The mode {} is not in accepted modes {}".format(
mode, accepted_modes))
mode = mode if mode is not None else self._mode
assert mode is not None, "Please set mode."
if mode not in self._has_prepared:
raise ValueError(
"The mode {} is not in accepted modes {}".format(
mode, list(self._has_prepared.keys())
)
)
self.to_mode(mode)
if inputs_spec is not None:
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = self._prepare_data_tensor(
self._inputs_spec, self._labels_spec)
if inputs_spec is not None and not self._has_prepared[mode]:
self._inputs_spec = self._validate_spec(inputs_spec)
self._labels_spec = self._validate_spec(labels_spec)
self._build(mode)
self._plan(mode)
else:
if _non_static_mode() or self._dygraph_mode:
raise ValueError(
"Please call `engine._prepare_program('mode')` firstly when in the static graph mode."
"Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
)
else:
self._logger.info(
"The program whose cost to be estimated must be static default program. Otherwise, please call `prepare()`before calling `cost()`."
)
program = paddle.static.default_main_program()
if (
not program.global_block().ops
) and not self._has_prepared[mode]:
raise ValueError(
"Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
)
# Estimate the exec cost and max memory
global_cost, max_memory = get_cost_from_engine(self, mode)
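A hedged usage sketch of the reworked `cost` entry point, assuming it returns the (time, memory) pair promised by the docstring; the specs are illustrative and `engine` is the object from the `fit` sketch:

from paddle.static import InputSpec

exec_time_ms, max_memory_bytes = engine.cost(
    inputs_spec=[InputSpec(shape=[64, 128], dtype="float32", name="x")],
    labels_spec=[InputSpec(shape=[64, 1], dtype="int64", name="label")],
    mode="train",
)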
......
......@@ -23,12 +23,18 @@ from functools import reduce
import paddle.fluid.core as core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.distributed.auto_parallel.process_group import (
get_all_process_groups,
)
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
from paddle.distributed.auto_parallel.dist_attribute import (
TensorDistributedAttribute,
OperatorDistributedAttribute,
)
__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
]
......@@ -39,7 +45,8 @@ def get_logger(log_level, name="auto_parallel"):
logger.setLevel(log_level)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
return logger
......@@ -114,8 +121,11 @@ def verify_shard_spec(shard_spec, tensor_shape, process_mesh):
if not verify_dims_mapping(dims_mapping, process_mesh):
return False
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0:
if (
dims_mapping[i] != -1
and tensor_shape[i] > 0
and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0
):
return False
return True
......@@ -141,14 +151,17 @@ def compute_compatible_dims_mapping(dims_mapping_list):
return None
length = len(dims_mapping_list[0])
for dims_mapping in dims_mapping_list:
assert dims_mapping is not None, \
"Dims mapping must not be None for compatible computation"
assert len(dims_mapping) == length, \
"The length of dims_mapping in list must be same for compatible computation."
assert (
dims_mapping is not None
), "Dims mapping must not be None for compatible computation"
assert (
len(dims_mapping) == length
), "The length of dims_mapping in list must be same for compatible computation."
compatible_result = []
for dim_mappings in zip(*dims_mapping_list):
compatible_dim_mapping = compute_compatible_dim_mapping(
list(dim_mappings))
list(dim_mappings)
)
if compatible_dim_mapping is None:
return None
compatible_result.append(compatible_dim_mapping)
......@@ -161,7 +174,10 @@ def compute_compatible_process_mesh(process_mesh_list):
return compatible_process_mesh
for process_mesh in process_mesh_list:
if process_mesh is not None:
if compatible_process_mesh is None or compatible_process_mesh == process_mesh:
if (
compatible_process_mesh is None
or compatible_process_mesh == process_mesh
):
compatible_process_mesh = process_mesh
else:
return None
......@@ -201,15 +217,18 @@ def remove_distributed_attr_suffix(name):
def check_distributed_attr_for_program(program, dist_context=None):
from .dist_context import get_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
assert dist_context.is_initialized_for_program(), \
"Distributed attributes must be initialized before check."
assert (
dist_context.is_initialized_for_program()
), "Distributed attributes must be initialized before check."
for block in program.blocks:
for tensor in block.vars.values():
dist_tensor = dist_context.get_dist_tensor_for_graph(tensor)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
tensor)
tensor
)
if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()):
return False
for op in block.ops:
......@@ -229,6 +248,7 @@ def print_program_with_dist_attr(program, dist_context=None):
lock.acquire()
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program, flush=True)
......@@ -242,7 +262,7 @@ def print_program_with_dist_attr(program, dist_context=None):
def _get_comm_group(processes, shape, axis, rank):
"""
Given a rank and the processes mesh the rank belongs to,
compute the communication peers of the rank based on the given axis in the mesh.
Example: 16 processes managed in a 4-Dimensional mesh with shape of [2, 2, 2, 2].
......@@ -256,7 +276,8 @@ def _get_comm_group(processes, shape, axis, rank):
# NOTE _linear_idx2coordinate assumes the processes mesh starts with 0 and is continuous
# tricks to support a processes mesh that does not start with 0 or is not continuous
assert rank in processes, "rank [{}] is NOT in processes group {}".format(
rank, processes)
rank, processes
)
rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(shape, rank_relatvie)
coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
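A worked instance of the peer computation above, assuming row-major linearization of the mesh (values illustrative):

processes = list(range(16))   # ranks 0..15
shape = [2, 2, 2, 2]          # 4-D process mesh
axis, rank = 3, 5
# rank 5 maps to coordinate [0, 1, 0, 1]; varying only axis 3 yields
# coordinates [0, 1, 0, 0] and [0, 1, 0, 1], i.e. the communication group [4, 5].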
......@@ -276,7 +297,7 @@ def _get_comm_group(processes, shape, axis, rank):
def _get_idx_in_axis(processes, shape, axis, rank):
"""
Given a rank and the processes mesh the rank belongs to,
compute the index of the rank in the given axis.
Example: 27 processes managed in a 3-Dimensional mesh with shape of [3, 3, 3].
......@@ -297,20 +318,20 @@ def _coordinate2linear_idx(mesh_shape, coordinate):
"""
convert a coordinate in multidimensional mesh space into a scalar idx in linear space.
it uses Row-major order for dimension conversion.
so it has: [most_significant_dim, ..., least_significant_dim]
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
linear_idx of a n dimensional coordinate is:
I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) +
I[n-2] * (S[n-3] * S[n-4] * .... S[0]) +
I[n-3] * (S[n-4] * .... S[0]) +
...
I[1] * (S[0]) +
I[0]
"""
......@@ -325,14 +346,19 @@ def _coordinate2linear_idx(mesh_shape, coordinate):
assert len(mesh_shape) == len(
coordinate
), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
mesh_shape, coordinate)
mesh_shape, coordinate
)
for i in range(len(mesh_shape)):
assert coordinate[
i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format(
i, coordinate)
assert coordinate[i] < mesh_shape[
i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
i, mesh_shape, coordinate)
assert (
coordinate[i] >= 0
), "index in dimension [{}] is least than zero. coordinate: {}".format(
i, coordinate
)
assert (
coordinate[i] < mesh_shape[i]
), "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
i, mesh_shape, coordinate
)
base = mesh_shape[-1]
linear_idx = coordinate[-1]
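A numeric instance of the row-major formula in the docstring above (values illustrative):

mesh_shape, coordinate = [3, 3, 3], [1, 2, 0]
linear_idx = (
    coordinate[0] * mesh_shape[1] * mesh_shape[2]
    + coordinate[1] * mesh_shape[2]
    + coordinate[2]
)
assert linear_idx == 15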
......@@ -350,7 +376,7 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
mapping a linear scalar idx into multidimensional mesh space, returning its coordinate in that space.
it is the inverse function of _coordinate2linear_idx.
assume:
the size of i-th dimension to be: S[i]
the index of j-th dimension is: I[j]
......@@ -365,11 +391,13 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
"""
assert linear_idx >= 0, "linear index [{}] is less than zero".format(
linear_idx)
linear_idx
)
assert linear_idx < np.prod(
mesh_shape
), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
mesh_shape, linear_idx)
mesh_shape, linear_idx
)
base = 1
coordinate = [-1] * len(mesh_shape)
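And the inverse mapping recovers the coordinate from the same worked example:

mesh_shape, linear_idx = [3, 3, 3], 15
coordinate = [
    linear_idx // (mesh_shape[1] * mesh_shape[2]) % mesh_shape[0],
    linear_idx // mesh_shape[2] % mesh_shape[1],
    linear_idx % mesh_shape[2],
]
assert coordinate == [1, 2, 0]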
......@@ -392,15 +420,17 @@ def _get_corresponding_rank(dist_context, target_mesh, rank):
coordinate = None
for mesh in dist_context.process_meshes:
if rank in mesh.processes and mesh.topology == target_mesh.topology:
coordinate = _linear_idx2coordinate(mesh.topology,
mesh.processes.index(rank))
coordinate = _linear_idx2coordinate(
mesh.topology, mesh.processes.index(rank)
)
break
# assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
# rank)
if coordinate is not None:
return target_mesh.processes[_coordinate2linear_idx(
mesh.topology, coordinate)]
return target_mesh.processes[
_coordinate2linear_idx(mesh.topology, coordinate)
]
else:
return target_mesh.processes[0]
......@@ -412,7 +442,8 @@ def _get_unshard_dist_shape(var, dist_attr):
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
var_shape, mapping)
var_shape, mapping
)
new_shape = []
for idx in range(len(var_shape)):
if var_shape[idx] == -1 or mapping[idx] == -1:
......@@ -425,13 +456,15 @@ def _get_unshard_dist_shape(var, dist_attr):
def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None):
from .dist_context import get_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
for var in dist_main_prog.list_vars():
if var.is_data:
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
var)
var
)
inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
var.desc.set_shape(inverse_shape)
dim_mapping = tensor_dist_attr.dims_mapping
......@@ -441,62 +474,76 @@ def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None):
def _update_addition_info(addition_info):
""" Update default addition_info with inputs """
"""Update default addition_info with inputs"""
add_info = {"epoch": 0, "batch": 0, "batch_size": 0}
if not addition_info:
return add_info
elif not isinstance(addition_info, dict):
raise TypeError("The type of 'addition_info' should be 'dict', "
"but got '{}'.".format(str(type(addition_info))))
raise TypeError(
"The type of 'addition_info' should be 'dict', "
"but got '{}'.".format(str(type(addition_info)))
)
else:
for item, value in addition_info.items():
if item not in ["epoch", "batch", "batch_size"]:
raise ValueError(
"The key of 'addition_info' should be one of the "
"['epoch', 'batch', 'batch_size'], but got '{}'.".format(
str(item)))
str(item)
)
)
if not isinstance(value, int):
raise ValueError(
"The value of 'addition_info' should be 'int', "
"but got '{}'.".format(str(type(value))))
"but got '{}'.".format(str(type(value)))
)
add_info[item] = value
return add_info
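A quick illustration of the defaulting and validation implemented above:

assert _update_addition_info(None) == {"epoch": 0, "batch": 0, "batch_size": 0}
assert _update_addition_info({"epoch": 3}) == {"epoch": 3, "batch": 0, "batch_size": 0}
# _update_addition_info({"step": 1}) raises ValueError: 'step' is not an accepted key.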
def _check_valid_path(file_path):
""" Validity check of input file path """
"""Validity check of input file path"""
if not file_path:
return file_path
elif isinstance(file_path, list):
for file in file_path:
if not isinstance(file, str):
raise TypeError("The type of file path should be 'str', "
"but got '{}'.".format(str(type(file))))
raise TypeError(
"The type of file path should be 'str', "
"but got '{}'.".format(str(type(file)))
)
if not os.path.exists(file):
raise ValueError(
"The file path '{}' does not exist.".format(file))
"The file path '{}' does not exist.".format(file)
)
return file_path
else:
raise TypeError("The type of file path should be 'list', "
"but got '{}'.".format(str(type(file_path))))
raise TypeError(
"The type of file path should be 'list', "
"but got '{}'.".format(str(type(file_path)))
)
def _check_param_dict(param_dict):
if not param_dict:
raise ValueError("'param_dict' cannot be None.")
elif not isinstance(param_dict, dict):
raise TypeError("The type of 'param_dict' should be 'dict', "
"but got '{}'.".format(str(type(param_dict))))
raise TypeError(
"The type of 'param_dict' should be 'dict', "
"but got '{}'.".format(str(type(param_dict)))
)
else:
for name, value in param_dict.items():
if not isinstance(name, str):
raise TypeError(
"The type of key of 'param_dict' should be 'str', "
"but got '{}'.".format(str(type(name))))
"but got '{}'.".format(str(type(name)))
)
if not isinstance(value, paddle.fluid.LoDTensor):
raise TypeError(
"The type of value of 'param_dict' should be 'LoDTensor', "
"but got '{}'.".format(str(type(value))))
"but got '{}'.".format(str(type(value)))
)
return param_dict
......@@ -504,35 +551,42 @@ def _check_dist_attr(dist_attr):
if not dist_attr:
return dist_attr
elif not isinstance(dist_attr, dict):
raise TypeError("The type of 'dist_attr' should be 'dict', "
"but got '{}'.".format(str(type(dist_attr))))
raise TypeError(
"The type of 'dist_attr' should be 'dict', "
"but got '{}'.".format(str(type(dist_attr)))
)
else:
for name, value in dist_attr.items():
if not isinstance(name, str):
raise TypeError(
"The type of param name of 'dist_attr' should be 'str', "
"but got '{}'.".format(str(type(name))))
"but got '{}'.".format(str(type(name)))
)
if not isinstance(value, dict):
raise TypeError(
"The type of distributed attribute should be 'dict', "
"but got '{}'".format(str(type(value))))
"but got '{}'".format(str(type(value)))
)
attr = ['process_shape', 'process_group', 'dims_mapping']
if list(value.keys()) != attr:
raise ValueError(
"The key of distributed attribute should be "
"'['process_shape', 'process_group', 'dims_mapping']', "
"but got {}.".format(str(value.keys())))
"but got {}.".format(str(value.keys()))
)
return dist_attr
def save_distributed_checkpoint(program,
checkpoint_path,
dist_attr_path,
addition_info=None,
is_integrated=False,
dist_context=None):
"""
Save model parameter state, optimizer state, distributed attribute and
def save_distributed_checkpoint(
program,
checkpoint_path,
dist_attr_path,
addition_info=None,
is_integrated=False,
dist_context=None,
):
"""
Save model parameter state, optimizer state, distributed attribute and
additional information of each rank.
Args:
......@@ -569,11 +623,12 @@ def save_distributed_checkpoint(program,
else:
# TODO: integrate param before save
raise NotImplementedError(
"Integrating parameter has not been implemented.")
"Integrating parameter has not been implemented."
)
def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
"""
"""
Load parameter, optimizer, distributed attribute and addition_info.
Args:
......@@ -583,7 +638,7 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
Returns:
param_dict(dict): parameters' value of all ranks.
dist_attr(dict): parameters' distributed attribute.
addition_info(dict): additional information the user saved in the last training.
Notes:
The returned 'addition_info' belongs to the first file of checkpoint_path by default.
......@@ -591,16 +646,16 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
Examples:
.. code-block:: python
ckpt_path = ['./model_state_rank0.pdmodel',
'./model_state_rank1.pdmodel']
dist_attr_path = ['./dist_attr_rank0.pdattr',
'./dist_attr_rank1.pdattr']
param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path)
"""
assert _check_valid_path(checkpoint_path), \
"'checkpoint_path' cannot be None."
assert _check_valid_path(dist_attr_path), \
"'dist_attr_path' cannot be None."
assert _check_valid_path(
checkpoint_path
), "'checkpoint_path' cannot be None."
assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None."
state_dict_info = _load_distributed_state_dict(checkpoint_path)
dist_attr = _load_distributed_attribute(dist_attr_path)
......@@ -609,11 +664,10 @@ def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
return param_dict, dist_attr, addition_info
def load_checkpoint_into_program(checkpoint_path,
dist_attr_path,
program,
dist_context=None):
"""
def load_checkpoint_into_program(
checkpoint_path, dist_attr_path, program, dist_context=None
):
"""
Load parameter, optimizer, distributed attribute and addition_info into model.
Args:
......@@ -624,7 +678,7 @@ def load_checkpoint_into_program(checkpoint_path,
Returns:
addition_info(dict): additional information the user saved in the last training.
Notes:
The returned 'addition_info' belongs to the first file of checkpoint_path by default.
......@@ -632,19 +686,19 @@ def load_checkpoint_into_program(checkpoint_path,
.. code-block:: python
exe.run(startup_program)
ckpt_path = ['./model_state_rank0.pdmodel',
'./model_state_rank1.pdmodel']
dist_attr_path = ['./dist_attr_rank0.pdattr',
'./dist_attr_rank1.pdattr']
load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program)
"""
from .dist_context import get_default_distributed_context
assert isinstance(program, paddle.fluid.framework.Program)
assert _check_valid_path(checkpoint_path), \
"'checkpoint_path' cannot be None."
assert _check_valid_path(dist_attr_path), \
"'dist_attr_path' cannot be None."
assert _check_valid_path(
checkpoint_path
), "'checkpoint_path' cannot be None."
assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None."
if dist_context is None:
dist_context = get_default_distributed_context()
all_state_dict_info = _load_distributed_state_dict(checkpoint_path)
......@@ -652,16 +706,16 @@ def load_checkpoint_into_program(checkpoint_path,
all_cur_dist_attr = get_dist_attr(program, dist_context)
all_param_dict = all_state_dict_info["model"]
addition_info = all_state_dict_info["addition_info"]
sliced_param_dict = merge_and_slice_parameter(all_param_dict,
all_pre_dist_attr,
all_cur_dist_attr)
sliced_param_dict = merge_and_slice_parameter(
all_param_dict, all_pre_dist_attr, all_cur_dist_attr
)
load_parameter_into_program(sliced_param_dict, program)
return addition_info
def load_parameter_into_program(param_dict, program):
"""
"""
Load parameters into program.
Args:
......@@ -676,28 +730,31 @@ def load_parameter_into_program(param_dict, program):
def _save_distributed_attribute(program, dist_attr_path, dist_context):
""" Save distributed attribute of all parameters """
"""Save distributed attribute of all parameters"""
# TODO: just save a complete distributed attribute file
rank_id = paddle.distributed.get_rank()
dist_attr_name = os.path.join(dist_attr_path,
"dist_attr_rank{}.pdattr".format(rank_id))
dist_attr_name = os.path.join(
dist_attr_path, "dist_attr_rank{}.pdattr".format(rank_id)
)
dist_attr_dict = {
"model": get_dist_attr(program, dist_context),
"world_size": paddle.distributed.get_world_size()
"world_size": paddle.distributed.get_world_size(),
}
paddle.save(dist_attr_dict, dist_attr_name)
logging.info(
"Already saved distributed attribute to '{}'.".format(dist_attr_path))
"Already saved distributed attribute to '{}'.".format(dist_attr_path)
)
def _load_distributed_attribute(dist_attr_path):
""" Load parameters' distributed attribute from dist_attr_path """
"""Load parameters' distributed attribute from dist_attr_path"""
total_dist_attr = {}
for dist_attr_file in dist_attr_path:
dist_attr = paddle.load(dist_attr_file)
pre_world_size = dist_attr["world_size"]
assert pre_world_size == len(dist_attr_path), \
"The number of 'dist_attr_path' must be equal to the last training world size."
assert pre_world_size == len(
dist_attr_path
), "The number of 'dist_attr_path' must be equal to the last training world size."
for name, attr in dist_attr["model"].items():
if name not in total_dist_attr:
total_dist_attr[name] = attr
......@@ -706,27 +763,29 @@ def _load_distributed_attribute(dist_attr_path):
def _save_distributed_state_dict(program, addition_info, checkpoint_path):
""" Save parameters' state_dict """
"""Save parameters' state_dict"""
rank = paddle.distributed.get_rank()
ckpt_file_name = os.path.join(checkpoint_path,
"model_state_rank{}.pdmodel".format(rank))
ckpt_file_name = os.path.join(
checkpoint_path, "model_state_rank{}.pdmodel".format(rank)
)
state_dict = {
"model": program.state_dict(),
"world_size": paddle.distributed.get_world_size(),
"addition_info": addition_info
"addition_info": addition_info,
}
paddle.save(state_dict, ckpt_file_name)
logging.info("Already saved model to '{}'.".format(checkpoint_path))
def _load_distributed_state_dict(checkpoint_path):
""" Load parameters' state_dict from checkpoint_path """
"""Load parameters' state_dict from checkpoint_path"""
all_state_dict = {}
for idx, ckpt_file in enumerate(checkpoint_path):
state_dict_info = paddle.load(ckpt_file, return_numpy=True)
pre_world_size = state_dict_info["world_size"]
assert pre_world_size == len(checkpoint_path), \
"The number of 'checkpoint_path' must be equal to the last training world size."
assert pre_world_size == len(
checkpoint_path
), "The number of 'checkpoint_path' must be equal to the last training world size."
if idx == 0:
addition_info = state_dict_info["addition_info"]
for name, value in state_dict_info["model"].items():
......@@ -737,13 +796,13 @@ def _load_distributed_state_dict(checkpoint_path):
all_state_dict_info = {
"model": all_state_dict,
"addition_info": addition_info
"addition_info": addition_info,
}
return all_state_dict_info
def get_dist_attr(program, dist_context=None):
"""
"""
Get distributed attribute of current rank.
Args:
......@@ -758,13 +817,14 @@ def get_dist_attr(program, dist_context=None):
for var in program.list_vars():
if is_parameter(var) or is_belong_to_optimizer(var):
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
var)
var
)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
dist_attr[var.name] = {
"process_shape": process_mesh.topology,
"process_group": process_mesh.processes,
"dims_mapping": dims_mapping
"dims_mapping": dims_mapping,
}
return dist_attr
......@@ -782,19 +842,26 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
dist_param_dict(dict): parameters' value of current rank.
"""
assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None."
assert isinstance(dist_param_dict, dict), \
"The type of 'dist_param_dict' should be 'dict', but got {}.".format(
str(type(dist_param_dict)))
assert isinstance(
dist_param_dict, dict
), "The type of 'dist_param_dict' should be 'dict', but got {}.".format(
str(type(dist_param_dict))
)
for name, value in dist_param_dict.items():
if not isinstance(name, str):
raise TypeError("The key of 'dist_param_dict' is parameter's name, "
"and its type should be 'str', but got {}.".format(
str(type(name))))
raise TypeError(
"The key of 'dist_param_dict' is parameter's name, "
"and its type should be 'str', but got {}.".format(
str(type(name))
)
)
if not isinstance(value, list) or not all(
isinstance(v, np.ndarray) for v in value):
isinstance(v, np.ndarray) for v in value
):
raise TypeError(
"The value of 'dist_param_dict' is parameter's value of all ranks, "
"and its type should be 'list(numpy.ndarray)'.")
"and its type should be 'list(numpy.ndarray)'."
)
if cur_dist_attr is None:
return {}
......@@ -822,7 +889,8 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
cur_dims_mapping = cur_attr["dims_mapping"]
if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping:
complete_param = _merge_parameter_with_dist_attr(
pre_param, pre_attr)
pre_param, pre_attr
)
dist_param_dict[var_name] = complete_param
else:
complete_param = pre_param[0]
......@@ -830,7 +898,8 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping:
sliced_param = _slice_parameter_with_dist_attr(
complete_param, cur_attr)
complete_param, cur_attr
)
dist_param_dict[var_name] = sliced_param
for var_name in pre_dist_attr:
......@@ -841,67 +910,81 @@ def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
if param_not_in_pre:
warnings.warn(
"Parameters '{}' are not found in last training process.".format(
str(param_not_in_pre)))
str(param_not_in_pre)
)
)
if param_not_in_cur:
warnings.warn(
"Parameters '{}' are not found in current training process.".format(
str(param_not_in_cur)))
str(param_not_in_cur)
)
)
return dist_param_dict
def _merge_parameter_with_dist_attr(param_list, dist_attr):
""" Merge parameter with distributed attribute """
"""Merge parameter with distributed attribute"""
from .reshard import Resharder
dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"]
# get the complete shape of the parameter
complete_shape = Resharder.compute_complete_shape(param_list[0].shape,
process_shape,
dims_mapping)
complete_shape = Resharder.compute_complete_shape(
param_list[0].shape, process_shape, dims_mapping
)
# merge the parameter with dist_attr
partition_param_list = []
merged_partiton = []
for process in process_group:
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group)
process, complete_shape, dims_mapping, process_shape, process_group
)
index = process_group.index(process)
if partition_index not in merged_partiton:
merged_partiton.append(partition_index)
_merge_parameter(partition_param_list, param_list[index],
partition_index, complete_shape)
_merge_parameter(
partition_param_list,
param_list[index],
partition_index,
complete_shape,
)
assert len(partition_param_list) == 1 or not partition_param_list, \
"Fail to merge parameter"
assert (
len(partition_param_list) == 1 or not partition_param_list
), "Fail to merge parameter"
complete_param = partition_param_list[0][0]
return complete_param
def _slice_parameter_with_dist_attr(param, dist_attr):
""" Slice parameter with distributed attribute """
param = np.array(param) if isinstance(param,
paddle.fluid.LoDTensor) else param
"""Slice parameter with distributed attribute"""
param = (
np.array(param) if isinstance(param, paddle.fluid.LoDTensor) else param
)
dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"]
# slice the parameter with dist_attr
partition_index_list = _get_split_indices(param.shape, dims_mapping,
process_shape, process_group)
sliced_param_list = _slice_parameter(param, partition_index_list,
len(partition_index_list))
partition_index_list = _get_split_indices(
param.shape, dims_mapping, process_shape, process_group
)
sliced_param_list = _slice_parameter(
param, partition_index_list, len(partition_index_list)
)
# get the current parameter's index in sliced_param_list
rank_id = paddle.distributed.get_rank()
sliced_param_index = _get_sliced_param_index(rank_id, param.shape,
dims_mapping, process_shape,
process_group)
sliced_param_index = _get_sliced_param_index(
rank_id, param.shape, dims_mapping, process_shape, process_group
)
sliced_param = sliced_param_list[sliced_param_index]
return sliced_param
def _merge_parameter(partition_param_list, param, partition_index,
complete_shape):
def _merge_parameter(
partition_param_list, param, partition_index, complete_shape
):
"""
Merge partial parameters into a complete one.
......@@ -935,19 +1018,30 @@ def _merge_parameter(partition_param_list, param, partition_index,
else:
i = 0
while i < len(partition_param_list):
concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_param_list[i][1], partition_index)
(
concat_axis,
first_order,
new_partition,
) = Resharder.compute_concat_info(
partition_param_list[i][1], partition_index
)
if concat_axis != -1:
if first_order == 0:
new_param = np.concatenate(
(partition_param_list[i][0], param), axis=concat_axis)
(partition_param_list[i][0], param), axis=concat_axis
)
else:
new_param = np.concatenate(
(param, partition_param_list[i][0]), axis=concat_axis)
(param, partition_param_list[i][0]), axis=concat_axis
)
partition_param_list.pop(i)
_merge_parameter(partition_param_list, new_param, new_partition,
complete_shape)
_merge_parameter(
partition_param_list,
new_param,
new_partition,
complete_shape,
)
break
i += 1
......@@ -975,19 +1069,21 @@ def _slice_parameter(complete_param, partition_index_list, length):
"""
sliced_param_list = []
axis = len(complete_param.shape) - length
sliced_param = np.split(complete_param,
partition_index_list[axis],
axis=axis)
sliced_param = np.split(
complete_param, partition_index_list[axis], axis=axis
)
if length == 1:
return sliced_param
for param in sliced_param:
sliced_param_list.extend(
_slice_parameter(param, partition_index_list, length - 1))
_slice_parameter(param, partition_index_list, length - 1)
)
return sliced_param_list
def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
process_group):
def _get_sliced_param_index(
rank, complete_shape, dims_mapping, process_shape, process_group
):
"""
Get sliced_param's index of current rank in all sliced parameters list.
......@@ -1006,7 +1102,7 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
process_group = [0, 1, 2]
slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3)
# slice_param:
# [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])]
index = _get_sliced_param_index(rank, complete_shape, dims_mapping
......@@ -1015,10 +1111,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
"""
from .reshard import Resharder
partition_index = Resharder.compute_partition_index(rank, complete_shape,
dims_mapping,
process_shape,
process_group)
partition_index = Resharder.compute_partition_index(
rank, complete_shape, dims_mapping, process_shape, process_group
)
sliced_param_index = 0
for i, shape in enumerate(complete_shape):
if dims_mapping[i] == -1:
......@@ -1033,8 +1128,9 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
return sliced_param_index
def _get_split_indices(complete_shape, dims_mapping, process_shape,
process_group):
def _get_split_indices(
complete_shape, dims_mapping, process_shape, process_group
):
"""
Get split indices of every dimension.
......@@ -1059,15 +1155,20 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
split_indices_list = []
for process in process_group:
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group)
process, complete_shape, dims_mapping, process_shape, process_group
)
if split_indices_list:
for dim in range(len(partition_index)):
split_indices_list[dim].extend(partition_index[dim])
else:
split_indices_list = partition_index
split_indices_list = list(
map(lambda x, y: list(set(x) - set([y]) - set([0])), split_indices_list,
complete_shape))
map(
lambda x, y: list(set(x) - set([y]) - set([0])),
split_indices_list,
complete_shape,
)
)
split_indices_list = [sorted(x) for x in split_indices_list]
return split_indices_list
......@@ -1086,8 +1187,10 @@ def set_grad_var_shape(program, dist_context):
if int(op.attr('op_role')) != int(OpRole.Backward):
continue
if int(block.ops[idx-1].attr('op_role')) == int(OpRole.Forward) or \
int(block.ops[idx-1].attr('op_role')) == 257:
if (
int(block.ops[idx - 1].attr('op_role')) == int(OpRole.Forward)
or int(block.ops[idx - 1].attr('op_role')) == 257
):
appended_grad_times += 1
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
......@@ -1105,61 +1208,102 @@ def set_grad_var_shape(program, dist_context):
continue
if var_name in grad_var_to_var[appended_grad_times]:
forward_var_name = grad_var_to_var[appended_grad_times][
var_name]
var_name
]
else:
forward_var_name = var_name[:var_name.find("@GRAD")]
forward_var_name = var_name[: var_name.find("@GRAD")]
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast",
"fill_any_like"
"c_allreduce_sum",
"c_identity",
"scale",
"cast",
"fill_any_like",
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
elif (
op.type == "matmul_v2_grad"
or op.type == "matmul_grad"
or op.type == "mul_grad"
):
forward_var_name = None
for output_name in op.output_names:
if var_name in op.output(output_name):
assert "@GRAD" in output_name
input_name = output_name[:output_name.find("@GRAD")]
input_name = output_name[: output_name.find("@GRAD")]
assert len(op.input(input_name)) == 1
forward_var_name = op.input(input_name)[0]
assert forward_var_name is not None
need_set_shape_list = [
"reshape2_grad", "softmax_with_cross_entropy_grad",
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad", "tanh_grad", "slice", "assign",
"matmul_v2_triple_grad", "elementwise_add_triple_grad",
"fill_constant", "sqrt_grad",
"reshape2_grad",
"softmax_with_cross_entropy_grad",
"transpose2_grad",
"softmax_grad",
"cross_entropy_grad2",
"dropout_grad",
"tanh_grad",
"slice",
"assign",
"matmul_v2_triple_grad",
"elementwise_add_triple_grad",
"fill_constant",
"sqrt_grad",
"fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad", "relu_grad"
"flatten_contiguous_range_grad",
"relu_grad",
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "tanh",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
"elementwise_add_grad_grad", "shape", "sqrt",
"fused_softmax_mask_upper_triangle", "flatten_contiguous_range",
"relu"
"reshape2",
"softmax_with_cross_entropy",
"transpose2",
"softmax",
"cross_entropy2",
"dropout",
"tanh",
["slice_grad", "c_allgather"],
"assign",
"matmul_v2_grad_grad",
"elementwise_add_grad_grad",
"shape",
"sqrt",
"fused_softmax_mask_upper_triangle",
"flatten_contiguous_range",
"relu",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
idx = need_set_shape_list.index(op.type)
forward_op_name = forward_list[idx]
if forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names:
op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
if (
forward_op.type in forward_op_name
and forward_var_name in forward_op.input_arg_names
):
op_dist_attr = (
dist_context.get_op_dist_attr_for_program(
forward_op
)
)
break
forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
forward_var_name
)
assert (
forward_input_dist_attr is not None
), f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
forward_var_dist_attr = (
dist_context.get_tensor_dist_attr_for_program(forward_var)
)
assert forward_var_dist_attr is not None
grad_var = vars[var_name]
ref_shape = infer_shape(block, forward_var, forward_var_dist_attr,
forward_input_dist_attr)
ref_shape = infer_shape(
block,
forward_var,
forward_var_dist_attr,
forward_input_dist_attr,
)
if list(grad_var.shape) != ref_shape:
grad_var.desc.set_shape(ref_shape)
......@@ -1171,28 +1315,33 @@ OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op):
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward)
or op_role == int(OpRole.Loss))
return OP_ROLE_KEY in op.attr_names and (
op_role == int(OpRole.Forward) or op_role == int(OpRole.Loss)
)
def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY]
) & int(OpRole.Backward)
def is_optimize_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY]
) & int(OpRole.Optimize)
def is_lr_sched_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched)
return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY]
) & int(OpRole.Optimize.LRSched)
def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))
return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY]
) == (int(OpRole.Forward) | int(OpRole.Loss))
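# The helpers above treat 'op_role' as a bit mask. A small sketch of that
# arithmetic, with the flag values restated locally for illustration (they are
# assumed to match the usual Paddle encoding, not imported from the framework):
FORWARD, BACKWARD, LOSS = 0x0000, 0x0001, 0x0100

loss_grad_role = BACKWARD | LOSS         # == 257, the magic number seen above
assert loss_grad_role == 257
assert loss_grad_role & BACKWARD         # an is_backward_op-style check passes
assert (FORWARD | LOSS) & BACKWARD == 0  # a pure loss (forward) op does not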
def is_loss_grad_op(op):
......@@ -1203,8 +1352,9 @@ def is_loss_grad_op(op):
def is_gradient_clip_op(op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
return op.desc.has_attr("op_namescope") and op.desc.attr(
"op_namescope"
).startswith("/gradient_clip")
def is_prim_op(op):
......@@ -1215,8 +1365,9 @@ def get_loss_op(block):
loss_ops = []
for op in block.ops:
if is_loss_op(op):
assert len(op.desc.output_arg_names()
) == 1, "loss op should only output loss var"
assert (
len(op.desc.output_arg_names()) == 1
), "loss op should only output loss var"
loss_ops.append(op)
assert len(loss_ops) == 1, "num of loss op is not equal to one"
......@@ -1236,7 +1387,8 @@ def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs):
def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_op, process_mesh, ref_mapping, ctx):
new_op, process_mesh, ref_mapping, ctx
):
assert process_mesh is not None
assert ref_mapping is not None
......@@ -1270,9 +1422,11 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
assert (
mapping == -1
), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
......@@ -1282,23 +1436,31 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
assert (
mapping == -1
), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[0])
else:
assert dims_mapping[0] == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
.format(op_desc.type(), mapping)
assert (
dims_mapping[0] == -1
), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part.".format(
op_desc.type(), mapping
)
if len(dims_mapping) > 2:
for idx, mapping in enumerate(dims_mapping[2:]):
assert mapping == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
assert (
mapping == -1
), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
assert (
compatible_dim_mapping is not None
), "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
......@@ -1344,8 +1506,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_idx = (
max_dims_mapping_len - input_dims_mapping_lens[arg_name]
) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
......@@ -1357,7 +1520,9 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
assert (
compatible_dims_mapping is not None
), "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
......@@ -1365,55 +1530,64 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_idx = (
max_dims_mapping_len - input_dims_mapping_lens[arg_name]
) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
op_dist_attr.set_input_dims_mapping(
arg_name, compatible_dims_mapping
)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
op_dist_attr.set_output_dims_mapping(
arg_name, compatible_dims_mapping
)
changed = True
return changed
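# The right-alignment used above when an input has fewer dims than the longest
# one, shown on toy values: a shorter dims_mapping is padded on the left with
# -1 (replicated) so the trailing axes line up, mirroring elementwise
# broadcasting. Illustrative only.
max_dims_mapping_len = 4
short_mapping = [-1, 0]                 # a 2-D input sharded on its last axis
aligned = [-1 for _ in range(max_dims_mapping_len)]
for i in range(len(short_mapping)):
    new_idx = (max_dims_mapping_len - len(short_mapping)) + i
    aligned[new_idx] = short_mapping[i]
assert aligned == [-1, -1, -1, 0]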
def get_all_distributed_main_program(serial_program_info, dist_context,
parallelizer):
def get_all_distributed_main_program(
serial_program_info, dist_context, parallelizer
):
"Get all distributed main programs by dist_context."
from .dist_context import DistributedOperatorContext, DistributedContext
cluster = serial_program_info.cluster
copied_parallelizer = copy.deepcopy(parallelizer)
all_dist_main_program = []
ranks = paddle.distributed.get_world_size() if cluster is None else len(
cluster.get_all_devices("GPU"))
ranks = (
paddle.distributed.get_world_size()
if cluster is None
else len(cluster.get_all_devices("GPU"))
)
for rank_id in range(ranks):
used_dist_context = copy.deepcopy(dist_context)
used_dist_context._dist_op_context = DistributedOperatorContext()
_, _, dist_startup_program, dist_main_program, _ = copied_parallelizer._get_dist_program(
rank_id, used_dist_context)
(
_,
_,
dist_startup_program,
dist_main_program,
_,
) = copied_parallelizer._get_dist_program(rank_id, used_dist_context)
all_dist_main_program.append(dist_main_program)
return all_dist_main_program
class SerialProgramInfo:
def __init__(self,
train_program,
satrtup_program,
loss,
optimizer,
cluster=None):
def __init__(
self, train_program, satrtup_program, loss, optimizer, cluster=None
):
self._train_program = train_program
self._startup_program = satrtup_program
self._loss = loss
......@@ -1442,7 +1616,6 @@ class SerialProgramInfo:
def get_standalone_cost_data(distributed_programs):
def _compute_runtime(op_cost, op, vars):
runtime = 0
try:
......@@ -1455,32 +1628,47 @@ def get_standalone_cost_data(distributed_programs):
parsed_info = op_config.split("\n")
variable = "(Variable)"
for info in parsed_info:
variable = "(Variable)" if "(Variable)" in info else "(list<Variable>"
variable = (
"(Variable)" if "(Variable)" in info else "(list<Variable>"
)
if variable in info:
arg_name_lower = info[:info.find(variable) - 1]
arg_name_lower = info[: info.find(variable) - 1]
shape_left_boundary = info.find("[")
shape_right_boundary = info.find("]")
assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed."
shape = info[shape_left_boundary +
1:shape_right_boundary].split(",")
assert (
shape_left_boundary > 0
and shape_right_boundary > 0
and shape_right_boundary > shape_left_boundary
), "Get shape failed."
shape = info[
shape_left_boundary + 1 : shape_right_boundary
].split(",")
shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape)
if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
arg_name_lower = (
"w" if arg_name_lower == "weight" else "ids"
)
for arg_name in op.input_names:
if arg_name.lower() == arg_name_lower:
for var_name in op.input(arg_name):
var = vars[var_name]
total_actual_input_size += reduce(
lambda x, y: x * y, var.shape)
lambda x, y: x * y, var.shape
)
break
assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed."
assert (
total_static_input_size > 0 and total_actual_input_size > 0
), "Get input size failed."
actual_runtime = total_actual_input_size / total_static_input_size * runtime
actual_runtime = (
total_actual_input_size / total_static_input_size * runtime
)
return actual_runtime
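# The scaling performed by _compute_runtime above, with made-up numbers: the
# profiled time corresponds to the "static" input volume recorded in the op
# config, and the estimate is rescaled by the actual (e.g. sharded) volume.
total_static_input_size = 1024 * 1024   # elements described in the op config
total_actual_input_size = 512 * 1024    # elements of the op's actual inputs
runtime = 2.0                           # ms, from the static cost data
actual_runtime = total_actual_input_size / total_static_input_size * runtime
assert actual_runtime == 1.0            # ms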
import paddle.cost_model as cm
cost_model = cm.CostModel()
cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2
......@@ -1491,13 +1679,16 @@ def get_standalone_cost_data(distributed_programs):
"reshape2": "reshape",
"unsqueeze2": "unsqueeze",
"reduce_sum": "sum",
"elementwise_div": "divide"
"elementwise_div": "divide",
}
standalone_cost_data = []
# skip ops
not_enum_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "assign"
"create_py_reader",
"create_double_buffer_reader",
"read",
"assign",
]
for distributed_program in distributed_programs:
cost_data = {}
......@@ -1507,26 +1698,33 @@ def get_standalone_cost_data(distributed_programs):
if op.type in not_enum_ops:
cost_data[op.desc.id()] = runtime
continue
dtype = str(vars[op.input_arg_names[0]].dtype
) if op.input_arg_names else "float32"
dtype = (
str(vars[op.input_arg_names[0]].dtype)
if op.input_arg_names
else "float32"
)
if int(op.attr('op_role')) == int(OpRole.Backward):
if "_grad" in op.type:
forward_op_name = op.type[:-5]
if forward_op_name in OP_NAME_MAPPING.keys():
forward_op_name = OP_NAME_MAPPING[forward_op_name]
op_cost = cost_model.get_static_op_time(forward_op_name,
forward=False,
dtype=dtype)
op_cost = cost_model.get_static_op_time(
forward_op_name, forward=False, dtype=dtype
)
if op_cost:
runtime = _compute_runtime(op_cost, op, vars)
else:
op_cost = cost_model.get_static_op_time(forward_op_name,
dtype=dtype)
op_cost = cost_model.get_static_op_time(
forward_op_name, dtype=dtype
)
if op_cost:
runtime = 2 * _compute_runtime(op_cost, op, vars)
elif int(op.attr('op_role')) == int(OpRole.Forward):
op_name = OP_NAME_MAPPING[
op.type] if op.type in OP_NAME_MAPPING.keys() else op.type
op_name = (
OP_NAME_MAPPING[op.type]
if op.type in OP_NAME_MAPPING.keys()
else op.type
)
op_cost = cost_model.get_static_op_time(op_name)
if op_cost:
runtime = _compute_runtime(op_cost, op, vars)
......@@ -1565,7 +1763,8 @@ def to_list(value):
def debug_program(program, path, name):
filename = os.path.join(
path, name + '_program' + ".%d" % (paddle.distributed.get_rank()))
path, name + '_program' + ".%d" % (paddle.distributed.get_rank())
)
with open(filename, 'w') as f:
f.write(str(program))
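# Example (paths are illustrative): on rank 0 the call
#     debug_program(dist_main_prog, "/tmp/auto_parallel", "dist_main")
# writes the program text to "/tmp/auto_parallel/dist_main_program.0".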
......@@ -1599,9 +1798,11 @@ def get_lr(optimizer):
return optimizer._learning_rate()
else:
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
"'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(
type(optimizer)
)
)
def initialize_pg_in_full_mode(all_process_groups, cur_rank):
......@@ -1631,21 +1832,28 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
if is_send:
recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":")
recv_rank
].split(":")
connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
client_socket = socket.socket(
socket.AF_INET, socket.SOCK_STREAM
)
client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank)
if rank != recv_rank:
raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}."
.format(recv_rank, rank))
"Please check comm pair, the recv rank should be {} but got {}.".format(
recv_rank, rank
)
)
else:
print("It is able to instantiate {} as sender now.".format(
process_group.ranks))
print(
"It is able to instantiate {} as sender now.".format(
process_group.ranks
)
)
client_socket.close()
else:
send_rank = process_group.ranks[0]
......@@ -1657,10 +1865,45 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):
has_recv_by_socket.append(rank)
else:
client_sockets[send_rank].send(
str(cur_rank).encode("utf-8"))
str(cur_rank).encode("utf-8")
)
client_sockets[send_rank].close()
print("It is able to instantiate {} as recver now.".
format(process_group.ranks))
print(
"It is able to instantiate {} as recver now.".format(
process_group.ranks
)
)
break
process_group.instantiate()
server_socket.close()
def get_input_split_info(cur_rank, var, dist_context):
# deduce how the input data is split among the cluster
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if cur_rank not in process_mesh.processes:
rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank)
else:
rank_id = cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
group_ranks = _get_comm_group(
process_mesh.processes,
process_mesh.topology,
batch_size_axis,
rank_id,
)
return len(group_ranks), group_ranks.index(rank_id)
return 1, 0
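# One plausible way a caller could consume the (split_nums, split_index) pair
# returned above: slice the global batch along dim 0 so each rank feeds only
# its own shard. numpy is used purely for illustration.
import numpy as np

global_batch = np.arange(8 * 4).reshape(8, 4)   # a global batch of 8 samples
split_nums, split_index = 2, 1                  # e.g. 2-way data parallel, rank 1
per_rank = global_batch.shape[0] // split_nums
local_batch = global_batch[split_index * per_rank:(split_index + 1) * per_rank]
assert local_batch.shape == (4, 4)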
def validate_opt(optimizer):
if optimizer is not None:
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
......@@ -20,12 +20,31 @@ from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op, OP_ROLE_KEY, OpRole
from paddle.distributed.auto_parallel.utils import (
set_var_dist_attr,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
_keep_layer_norm_scale_bias_to_fp32,
_need_keep_fp32,
_valid_types,
_dtype_to_str,
)
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.utils import (
is_forward_op,
is_backward_op,
OP_ROLE_KEY,
OpRole,
)
from .auto_parallel_amp import AMPPass
world_process_group = get_world_process_group()
......@@ -39,11 +58,15 @@ __amp_skip_ops__ = [
def set_op_dtype_to_fp16(op):
if op.has_attr('in_dtype') and op.attr(
'in_dtype') == core.VarDesc.VarType.FP32:
if (
op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype') and op.attr(
'out_dtype') == core.VarDesc.VarType.FP32:
if (
op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
......@@ -63,7 +86,12 @@ def _keep_fp32_input(op, in_name):
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
'LnScale',
'LnBias',
'Ln2Scale',
'Ln2Bias',
"Ln1Scale",
"Ln1Bias",
}
# backward
if op_type in ['batch_norm_grad']:
......@@ -83,8 +111,12 @@ def _keep_fp32_output(op, out_name):
return out_name not in {'Y', 'ConvX', 'ConvZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return out_name in {
'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean',
'Ln1Variance'
'LnMean',
'LnVariance',
'Ln2Mean',
'Ln2Variance',
'Ln1Mean',
'Ln1Variance',
}
# backward
if op_type in ['layer_norm_grad']:
......@@ -95,24 +127,28 @@ def _keep_fp32_output(op, out_name):
class FP16State(object):
def __init__(self,
program,
amp_list,
dist_context,
use_fp16_guard,
input_data_var_names=None):
def __init__(
self,
program,
amp_list,
dist_context,
use_fp16_guard,
input_data_var_names=None,
):
self.program = program
self.amp_list = amp_list
self.use_fp16_guard = use_fp16_guard
self.dist_context = dist_context
self.grad_op_to_op_map = self.dist_context.dist_op_context.grad_op_id_to_op_id
self.grad_op_to_op_map = (
self.dist_context.dist_op_context.grad_op_id_to_op_id
)
if input_data_var_names:
self.input_data_var_names = input_data_var_names
else:
self.input_data_var_names = []
self._op_fp16_dict = {
} # op_id --> True/False. 'True' means that the op should run in fp16 mode.
self._op_fp16_dict = (
{}
)  # op_id --> True/False. 'True' means that the op should run in fp16 mode.
# a trick to determine leaf tensor node in program {varname: generator_op_id}
self.forward_non_leaf_tensors = {}
# record the cast ops that are inserted for a forward
......@@ -126,7 +162,7 @@ class FP16State(object):
def _build_state(self):
"""
mark the execution mode (fp16 or fp32) for ops in all blocks
mark the execution mode (fp16 or fp32) for ops in all blocks
include forward ops & backward ops
"""
# mark op dtype
......@@ -156,8 +192,9 @@ class FP16State(object):
if op.type == "assign" and "array_" in op.input_arg_names[0]:
self._op_fp16_dict[op.desc.original_id()] = False
return
if _need_keep_fp32(op, self.amp_list.unsupported_list,
self.use_fp16_guard):
if _need_keep_fp32(
op, self.amp_list.unsupported_list, self.use_fp16_guard
):
self._op_fp16_dict[op.desc.original_id()] = False
else:
self._op_fp16_dict[op.desc.original_id()] = True
......@@ -170,8 +207,9 @@ class FP16State(object):
if op.desc.original_id() in self.grad_op_to_op_map:
fwd_op_id = self.grad_op_to_op_map[op.desc.original_id()]
assert fwd_op_id in self._op_fp16_dict, "{}".format(str(op))
self._op_fp16_dict[
op.desc.original_id()] = self._op_fp16_dict[fwd_op_id]
self._op_fp16_dict[op.desc.original_id()] = self._op_fp16_dict[
fwd_op_id
]
if int(op.attr('op_role')) == 257:
self.is_train = True
......@@ -181,7 +219,8 @@ class FP16State(object):
try:
var = block.var(var_name)
except ValueError as e:
var = self.program.global_block().var(var_name)
var = block._var_recursive(var_name)
# var = self.program.global_block().var(var_name)
# NOTE(JZ-LIANG) "array_" is a hack adopted for ernie3.0 inference, since there is
# a trick which makes the LOD_TENSOR_ARRAY float32 in the while block in order to reset the LOD_TENSOR_ARRAY
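# The switch to block._var_recursive above lets set_var_to_fp16 resolve vars
# that are defined in an enclosing block (e.g. the main block around a while
# sub-block). A toy model of that lookup, not the real fluid Block API:
class ToyBlock:
    def __init__(self, parent=None):
        self.vars = {}
        self.parent = parent

    def var_recursive(self, name):
        if name in self.vars:
            return self.vars[name]
        if self.parent is not None:
            return self.parent.var_recursive(name)
        raise ValueError("var {} is not found".format(name))

main_block = ToyBlock()
main_block.vars["x"] = "fp32 tensor"
sub_block = ToyBlock(parent=main_block)          # e.g. the body of a while op
assert sub_block.var_recursive("x") == "fp32 tensor"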
......@@ -196,13 +235,18 @@ class FP16State(object):
for op in block.ops:
if is_forward_op(op):
# NOTE (JZ-LIANG) unexpected cast op when the user calls "+, -, *, /" in python
if self._is_fp16_op(op.desc.original_id()) == True \
or op.type == "cast":
if (
self._is_fp16_op(op.desc.original_id()) == True
or op.type == "cast"
):
for in_name in op.input_names:
if _keep_fp32_input(op, in_name):
continue
for in_var_name in op.input(in_name):
if in_var_name not in self.forward_non_leaf_tensors and in_var_name not in self.input_data_var_names:
if (
in_var_name not in self.forward_non_leaf_tensors
and in_var_name not in self.input_data_var_names
):
self.set_var_to_fp16(in_var_name, block)
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
......@@ -248,22 +292,42 @@ class FP16State(object):
elif is_forward_op(op):
if self._is_fp16_op(op.desc.original_id()) == False:
num_cast_ops = self._insert_forward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, self.dist_context)
op,
idx,
block,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
self.dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) == True:
num_cast_ops = self._insert_forward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, self.dist_context)
op,
idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
self.dist_context,
)
elif is_backward_op(op):
if op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(op.desc.original_id()) == False:
num_cast_ops = self._insert_backward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, self.dist_context)
op,
idx,
block,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
self.dist_context,
)
elif self._is_fp16_op(op.desc.original_id()) == True:
num_cast_ops = self._insert_backward_cast_ops(
op, idx, block, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, self.dist_context)
op,
idx,
block,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
self.dist_context,
)
elif op.type == "sum":
# all input dtypes of sum should be equal and the output dtype should follow the input
out_var_name = op.output_arg_names[0]
......@@ -271,41 +335,51 @@ class FP16State(object):
out_var = block.var(out_var_name)
in_var = block._find_var_recursive(in_var_name)
for in_var_name in op.input_arg_names:
assert in_var.dtype == block.var(
in_var_name).dtype, "{}, {}, {}".format(
in_var, block.var(in_var_name), str(op))
assert (
in_var.dtype == block.var(in_var_name).dtype
), "{}, {}, {}".format(
in_var, block.var(in_var_name), str(op)
)
out_var.desc.set_dtype(in_var.dtype)
idx += num_cast_ops + 1
block._sync_with_cpp()
def _insert_forward_cast_ops(self, op, idx, block, src_dtype, dst_dtype,
dist_context):
def _insert_forward_cast_ops(
self, op, idx, block, src_dtype, dst_dtype, dist_context
):
num_cast_ops = 0
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name):
op, in_name
):
continue
consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
assert consume_op_attr is not None
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
if in_var is None or in_var.type not in _valid_types or in_var.dtype == dst_dtype:
if (
in_var is None
or in_var.type not in _valid_types
or in_var.dtype == dst_dtype
):
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(
dst_dtype)
cast_name = (
in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
)
cast_var = block.vars.get(cast_name)
self.forward_input_cast_ops[op.desc.original_id()] += [
(cast_name, in_var.name, dst_dtype, src_dtype, in_name)
]
in_var_dist_attr = consume_op_attr.get_input_dist_attr(
in_var.name)
in_var.name
)
assert in_var_dist_attr is not None
# truly insert cast op
if cast_var is None or cast_var.dtype != dst_dtype:
......@@ -319,9 +393,11 @@ class FP16State(object):
name=cast_name,
dtype=dst_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient)
set_var_dist_attr(dist_context, cast_var, ref_mapping,
ref_mesh)
stop_gradient=in_var.stop_gradient,
)
set_var_dist_attr(
dist_context, cast_var, ref_mapping, ref_mesh
)
cast_op = block._insert_op_without_sync(
idx,
......@@ -331,23 +407,27 @@ class FP16State(object):
attrs={
"in_dtype": in_var.dtype,
"out_dtype": cast_var.dtype,
OP_ROLE_KEY: OpRole.Forward
})
OP_ROLE_KEY: OpRole.Forward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
op._rename_input(in_var.name, cast_name)
consume_op_attr.set_input_dist_attr(cast_name,
in_var_dist_attr)
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr
)
if op.has_attr('out_dtype') and op.attr('out_dtype') != -1:
assert op.attr('out_dtype') == dst_dtype
return num_cast_ops
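# The name-based reuse applied above, in miniature: the cast of an input is
# keyed by "<var_name>.cast_<dtype>", so a second consumer of the same input
# finds the cached cast var instead of inserting another cast op. The dtype
# suffix and the dict stand in for _dtype_to_str and block.vars (illustrative).
block_vars = {}

def get_or_create_cast(var_name, dtype_suffix):
    cast_name = var_name + '.cast_' + dtype_suffix
    if cast_name not in block_vars:
        block_vars[cast_name] = "new cast var"   # a real pass would insert a cast op here
    return cast_name

first = get_or_create_cast("linear_0.tmp_0", "fp16")
second = get_or_create_cast("linear_0.tmp_0", "fp16")
assert first == second and len(block_vars) == 1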
def _insert_backward_cast_ops(self, op, idx, block, src_dtype, dst_dtype,
dist_context):
def _insert_backward_cast_ops(
self, op, idx, block, src_dtype, dst_dtype, dist_context
):
num_cast_ops = 0
op_id = op.desc.id()
......@@ -363,15 +443,21 @@ class FP16State(object):
if _keep_fp32_output(op, out_var.name):
continue
assert out_var.dtype == dst_dtype, "{}, {}".format(
str(out_var), dst_dtype)
str(out_var), dst_dtype
)
for cast_name, src_name, dst_dtype, src_dtype, slot_name in self.forward_input_cast_ops[
forward_op_id]:
for (
cast_name,
src_name,
dst_dtype,
src_dtype,
slot_name,
) in self.forward_input_cast_ops[forward_op_id]:
# rename input
assert src_name in op.input(
slot_name), "var: {} not in op's {}. {}".format(
src_name, slot_name, str(op))
slot_name
), "var: {} not in op's {}. {}".format(src_name, slot_name, str(op))
src_var_dist_attr = grad_op_attr.get_input_dist_attr(src_name)
assert src_var_dist_attr is not None
op._rename_input(src_name, cast_name)
......@@ -393,15 +479,18 @@ class FP16State(object):
ref_mapping = grad_dist_attr.dims_mapping
cast_grad = block.create_var(
name=unique_name.generate_with_ignorable_key("".join(
[cast_name, '@GRAD'])),
name=unique_name.generate_with_ignorable_key(
"".join([cast_name, '@GRAD'])
),
dtype=dst_dtype,
shape=grad.shape,
type=grad.type,
persistable=grad.persistable,
stop_gradient=grad.stop_gradient)
stop_gradient=grad.stop_gradient,
)
dist_context.set_tensor_dist_attr_for_program(
cast_grad, grad_dist_attr)
cast_grad, grad_dist_attr
)
op._rename_output(grad_name, cast_grad.name)
grad_op_attr.set_output_dist_attr(cast_grad.name, grad_dist_attr)
......@@ -414,12 +503,14 @@ class FP16State(object):
attrs={
"in_dtype": dst_dtype,
"out_dtype": src_dtype,
OP_ROLE_KEY: OpRole.Backward
})
OP_ROLE_KEY: OpRole.Backward,
},
)
grad.desc.set_dtype(src_dtype)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
cast_op, ref_mesh, ref_mapping, dist_context
)
num_cast_ops += 1
return num_cast_ops
......@@ -432,26 +523,34 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64'],
'check_finite_and_unscale',
)
found_inf = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
['find_infinite_scale', name])),
name=unique_name.generate_with_ignorable_key(
".".join(['find_infinite_scale', name])
),
shape=[1],
dtype='bool',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
stop_gradient=False,
)
set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs)
new_op = main_block.append_op(
type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
attrs=attrs,
)
new_op_dist_attr = OperatorDistributedAttribute()
new_op_dist_attr.process_mesh = world_process_group.ranks
......@@ -461,10 +560,12 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
for g in grads:
g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
assert g_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(g.name,
g_dist_attr.dims_mapping)
new_op_dist_attr.set_input_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
new_op_dist_attr.set_output_dims_mapping(
g.name, g_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
return grads, found_inf
......@@ -473,8 +574,9 @@ def _split_grads(params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
assert len(fp32_grads) + len(fp16_grads) == len(grads), \
"Data types of all grads must be either fp16 or fp32."
assert len(fp32_grads) + len(fp16_grads) == len(
grads
), "Data types of all grads must be either fp16 or fp32."
return grads, fp32_grads, fp16_grads
......@@ -486,37 +588,45 @@ def _set_op_dist_attr_with_ranks(new_op, ranks, block, dist_context):
var = block.var(var_name)
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
assert var_dist_attr is not None
new_op_dist_attr.set_input_dims_mapping(var_name,
var_dist_attr.dims_mapping)
new_op_dist_attr.set_input_dims_mapping(
var_name, var_dist_attr.dims_mapping
)
for var_name in new_op.output_arg_names:
var = block.var(var_name)
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
assert var_dist_attr is not None
new_op_dist_attr.set_output_dims_mapping(var_name,
var_dist_attr.dims_mapping)
new_op_dist_attr.set_output_dims_mapping(
var_name, var_dist_attr.dims_mapping
)
dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def _get_memcopy_idx(block, found_inf_var):
# use reduce_any op for check_nan_inf as the anchor for now
for idx, op in enumerate(block.ops):
if op.type == 'reduce_any' and op.output_arg_names[
0] == found_inf_var.name:
if (
op.type == 'reduce_any'
and op.output_arg_names[0] == found_inf_var.name
):
return idx + 1
raise RuntimeError(
"not found the correct location for memcopy for found_inf_var.")
"not found the correct location for memcopy for found_inf_var."
)
def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
src_name = src_var.name
output_var = block.create_var(name=unique_name.generate_with_ignorable_key(
src_name.join(['memcopy_'])),
dtype=src_var.dtype,
shape=src_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=src_var.stop_gradient)
output_var = block.create_var(
name=unique_name.generate_with_ignorable_key(
src_name.join(['memcopy_'])
),
dtype=src_var.dtype,
shape=src_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=src_var.stop_gradient,
)
set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)
......@@ -527,16 +637,20 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
dst_place_type = 1
else:
raise NotImplementedError(
"direction [{}] is not supported yet.".format(direction))
"direction [{}] is not supported yet.".format(direction)
)
attrs = {'dst_place_type': dst_place_type}
new_op = block._insert_op_without_sync(index=idx,
type='memcpy',
inputs={'X': [src_var]},
outputs={'Out': [output_var]},
attrs=attrs)
_set_op_dist_attr_with_ranks(new_op, world_process_group.ranks, block,
dist_context)
new_op = block._insert_op_without_sync(
index=idx,
type='memcpy',
inputs={'X': [src_var]},
outputs={'Out': [output_var]},
attrs=attrs,
)
_set_op_dist_attr_with_ranks(
new_op, world_process_group.ranks, block, dist_context
)
block._sync_with_cpp()
return output_var
......@@ -564,19 +678,21 @@ def cast_startup_program():
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name,
None) == core.VarDesc.VarType.FP16:
if (
param_to_dtype.get(output_name, None)
== core.VarDesc.VarType.FP16
):
assert op.has_attr(
'dtype'
), "initialization op is supposed to have a dtype attribute but got {}.".format(
str(op))
str(op)
)
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
def __init__(self):
super(FP16Pass, self).__init__()
......@@ -589,16 +705,22 @@ class FP16Pass(AMPPass):
amp_list = AutoMixedPrecisionLists(
set(self.get_attr("custom_white_list")),
set(self.get_attr("custom_black_list")), None)
set(self.get_attr("custom_black_list")),
None,
)
# NOTE do not change the input data dtype, since it is controlled by the dataloader
# and is out of the control of the FP16 pass
input_data_var_names = [var.name for var in self.get_attr("input_data")]
with paddle.static.program_guard(main_program, startup_program):
fp16_state = FP16State(main_program, amp_list, self.dist_context,
self.get_attr("use_fp16_guard"),
input_data_var_names)
fp16_state = FP16State(
main_program,
amp_list,
self.dist_context,
self.get_attr("use_fp16_guard"),
input_data_var_names,
)
is_train = fp16_state._build_state()
cast_startup_program()
......@@ -611,41 +733,63 @@ class FP16Pass(AMPPass):
grads, fp32_grads, fp16_grads = _split_grads(params_grads)
if self.get_attr("use_dynamic_loss_scaling"
) or self.get_attr("init_loss_scaling") != 1.0:
if (
self.get_attr("use_dynamic_loss_scaling")
or self.get_attr("init_loss_scaling") != 1.0
):
found_infs = []
if fp32_grads:
with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32",
self.dist_context)
fp32_grads,
self._loss_scaling,
"@fp32",
self.dist_context,
)
found_infs.append(found_inf_fp32)
if fp16_grads:
with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16",
self.dist_context)
fp16_grads,
self._loss_scaling,
"@fp16",
self.dist_context,
)
found_infs.append(found_inf_fp16)
with main_program._optimized_guard([]):
block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs)
set_var_dist_attr(self.dist_context, all_infs, [-1],
world_process_group.ranks)
set_var_dist_attr(
self.dist_context,
all_infs,
[-1],
world_process_group.ranks,
)
new_op = block.ops[-1]
assert new_op.type == "concat"
_set_op_dist_attr_with_ranks(new_op,
world_process_group.ranks,
block, self.dist_context)
_set_op_dist_attr_with_ranks(
new_op,
world_process_group.ranks,
block,
self.dist_context,
)
found_inf = paddle.fluid.layers.reduce_any(all_infs)
set_var_dist_attr(self.dist_context, found_inf, [-1],
world_process_group.ranks)
set_var_dist_attr(
self.dist_context,
found_inf,
[-1],
world_process_group.ranks,
)
new_op = block.ops[-1]
assert new_op.type == "reduce_any"
_set_op_dist_attr_with_ranks(new_op,
world_process_group.ranks,
block, self.dist_context)
_set_op_dist_attr_with_ranks(
new_op,
world_process_group.ranks,
block,
self.dist_context,
)
if self.get_attr("use_dynamic_loss_scaling"):
with main_program._optimized_guard([]):
......@@ -660,14 +804,15 @@ class FP16Pass(AMPPass):
if self.get_attr("use_optimizer_fp16"):
base_opt._multi_precision = False
if isinstance(
base_opt,
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)):
base_opt, (paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)
):
with main_program._optimized_guard([]):
# found_inf = paddle.tensor.creation._memcpy(
# found_inf, paddle.CPUPlace())
insert_idx = _get_memcopy_idx(block, found_inf)
found_inf = _insert_memcopy(block, insert_idx, found_inf,
self.dist_context)
found_inf = _insert_memcopy(
block, insert_idx, found_inf, self.dist_context
)
base_opt._set_auxiliary_var('found_inf', found_inf.name)
elif hasattr(base_opt, "_set_auxiliary_var"):
base_opt._set_auxiliary_var('found_inf', found_inf.name)
......@@ -63,6 +63,15 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
......@@ -90,6 +99,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
......@@ -99,20 +109,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
test_conditional_block_reshard)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_engine_api_error MODULES test_engine_api_error)
endif()
......@@ -31,7 +31,10 @@ from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.interface import get_collection, CollectionNames
from paddle.distributed.auto_parallel.interface import (
get_collection,
CollectionNames,
)
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
......@@ -56,7 +59,6 @@ my_feed_vars = []
class MyDataset(Dataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
......@@ -77,38 +79,38 @@ def get_random_inputs_and_labels(image_shape, label_shape):
def batch_generator_creator():
def __reader__():
for _ in range(batch_num):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, image_size], [batch_size, 1])
[batch_size, image_size], [batch_size, 1]
)
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
......@@ -132,16 +134,20 @@ class MLPLayer(nn.Layer):
def train_high_level(fetch):
global is_fetch
is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
......@@ -153,11 +159,13 @@ def train_high_level(fetch):
train_dataset = MyDataset(batch_num * batch_size)
eval_dataset1 = MyDataset(5 * batch_size)
history = engine.fit(train_data=train_dataset,
epochs=2,
batch_size=batch_size,
valid_data=eval_dataset1,
log_freq=1)
history = engine.fit(
train_data=train_dataset,
epochs=2,
batch_size=batch_size,
valid_data=eval_dataset1,
log_freq=1,
)
# eval
eval_dataset2 = MyDataset(batch_size)
......@@ -176,16 +184,20 @@ def train_high_level(fetch):
def train_low_level():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
......@@ -200,18 +212,18 @@ def train_low_level():
# Build normal dataloader
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader(train_dataset,
batch_size=batch_size,
mode="train")
train_dataloader = engine.dataloader(
train_dataset, batch_size=batch_size, mode="train"
)
engine.prepare(mode="train")
for data in train_dataloader:
outs = engine.run(data, feed=feed_dict, mode="train")
# eval
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader(eval_dataset2,
batch_size=batch_size,
mode="eval")
eval_dataloader = engine.dataloader(
eval_dataset2, batch_size=batch_size, mode="eval"
)
engine.prepare(mode="eval")
for data in eval_dataloader:
outs = engine.run(data, feed=feed_dict, mode="eval")
......@@ -234,9 +246,9 @@ def train_low_level():
# Build dataloader from generator
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader_from_generator(train_dataset,
batch_size=batch_size,
mode="train")
train_dataloader = engine.dataloader_from_generator(
train_dataset, batch_size=batch_size, mode="train"
)
engine.prepare(mode="train")
for data in train_dataloader:
outs = engine.run(data, feed=feed_dict, mode="train")
......@@ -244,17 +256,18 @@ def train_low_level():
# eval
engine.to_mode("eval")
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader_from_generator(eval_dataset2,
batch_size=batch_size)
eval_dataloader = engine.dataloader_from_generator(
eval_dataset2, batch_size=batch_size
)
engine.prepare()
for data in eval_dataloader:
outs = engine.run(data, feed=feed_dict)
# predict
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader_from_generator(test_dataset,
batch_size=batch_size,
mode="predict")
predict_dataloader = engine.dataloader_from_generator(
test_dataset, batch_size=batch_size, mode="predict"
)
engine.prepare(mode="predict")
for data in predict_dataloader:
outs = engine.run(data, feed=feed_dict, mode="predict")
......@@ -268,16 +281,20 @@ def train_low_level():
def train_builtin_data_vars():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
......@@ -295,9 +312,9 @@ def train_builtin_data_vars():
with static.program_guard(engine.main_program, engine.startup_program):
feed_list = engine.inputs + engine.labels
print(feed_list)
loader = paddle.io.DataLoader.from_generator(feed_list=feed_list,
capacity=4 * batch_size,
iterable=False)
loader = paddle.io.DataLoader.from_generator(
feed_list=feed_list, capacity=4 * batch_size, iterable=False
)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
......@@ -308,36 +325,40 @@ def train_builtin_data_vars():
while True:
engine.run()
except paddle.fluid.core.EOFException:
loader.reset(
) # call DataLoader.reset() after catching EOFException
loader.reset() # call DataLoader.reset() after catching EOFException
def train_non_builtin_data_vars():
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
with static.program_guard(
main_program, startup_program
), utils.unique_name.guard():
input = static.data(
name="input", shape=[batch_size, image_size], dtype='float32'
)
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
loader = paddle.io.DataLoader.from_generator(
feed_list=[input, label], capacity=4 * batch_size, iterable=False
)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
......@@ -345,53 +366,109 @@ def train_non_builtin_data_vars():
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
engine = auto.Engine(
loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
)
# train
engine.to_mode("train")
engine.prepare(inputs=[input],
labels=[label],
main_program=main_program,
startup_program=startup_program)
engine.prepare(
inputs=[input],
labels=[label],
main_program=main_program,
startup_program=startup_program,
)
for _ in range(epoch_num):
loader.start() # call DataLoader.start() before each epoch starts
try:
while True:
engine.run()
except paddle.fluid.core.EOFException:
loader.reset(
) # call DataLoader.reset() after catching EOFException
loader.reset() # call DataLoader.reset() after catching EOFException
def get_cost():
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(
main_program, startup_program
), utils.unique_name.guard():
input = static.data(
name="input", shape=[batch_size, image_size], dtype='float32'
)
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(
feed_list=[input, label], capacity=4 * batch_size, iterable=False
)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(
loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
)
engine.prepare(
main_program=main_program,
startup_program=startup_program,
inputs=[input],
labels=[label],
mode="train",
)
engine.cost()
def get_cost_by_default_program():
main_program = static.default_main_program()
startup_program = static.default_startup_program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
with static.program_guard(
main_program, startup_program
), utils.unique_name.guard():
input = static.data(
name="input", shape=[batch_size, image_size], dtype='float32'
)
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
loader = paddle.io.DataLoader.from_generator(
feed_list=[input, label], capacity=4 * batch_size, iterable=False
)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
......@@ -399,24 +476,27 @@ def get_cost():
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
engine.cost()
engine = auto.Engine(
loss=loss_var, optimizer=optimizer, metrics=metric, strategy=strategy
)
engine.cost(mode="train")
def get_cost_by_spec():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
......@@ -436,4 +516,5 @@ if __name__ == "__main__":
train_builtin_data_vars()
train_non_builtin_data_vars()
get_cost()
get_cost_by_default_program()
get_cost_by_spec()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static as static
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset
from paddle.distributed.fleet import auto
paddle.enable_static()
epoch_num = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10
is_fetch = True
is_feed = True
my_feed_vars = []
class TrainDataset(Dataset):
def __init__(self, num_samples):
super(TrainDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
def __len__(self):
return self.num_samples
class TestDataset(Dataset):
def __init__(self, num_samples):
super(TestDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
if is_feed:
my_feed_vars.append((out, out.shape))
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
if is_feed:
my_feed_vars.append((out, out.shape))
if is_fetch:
auto.fetch(out, "my_fetch", logging=True)
return out
class TestEngineErrorRaise(unittest.TestCase):
def setUp(self):
class NoSupportData1:
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
class NoSupportData2(TrainDataset):
def __getitem__(self, index):
input = [
list(np.random.uniform(size=image_size).astype("float32"))
]
label = [np.random.randint(0, class_num - 1, dtype="int64")]
return input, label
class NoSupportData3:
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
return input
class NoSupportData4(TestDataset):
def __getitem__(self, index):
input = [
list(np.random.uniform(size=image_size).astype("float32"))
]
return input
self.no_support_data_1 = NoSupportData1()
self.no_support_data_2 = NoSupportData2(10)
self.no_support_data_3 = NoSupportData3()
self.no_support_data_4 = NoSupportData4(10)
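    # Constructor arguments of the wrong type should be rejected up front.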
def test_Engine(self):
with self.assertRaises(TypeError):
auto.Engine(model=paddle.static.Program())
with self.assertRaises(TypeError):
auto.Engine(loss="CrossEntropyLoss")
with self.assertRaises(TypeError):
auto.Engine(optimizer="adam")
with self.assertRaises(TypeError):
auto.Engine(metrics=["acc"])
with self.assertRaises(TypeError):
auto.Engine(cluster="cluster")
with self.assertRaises(TypeError):
auto.Engine(strategy="strategy")
def test_fit(self):
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
optimizer=paddle.optimizer.AdamW(0.00001),
)
engine.fit(train_data=self.no_support_data_1)
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
optimizer=paddle.optimizer.AdamW(0.00001),
)
engine.fit(train_data=self.no_support_data_2)
def test_evaluate(self):
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
engine.evaluate(valid_data=self.no_support_data_3)
with self.assertRaises(TypeError):
engine = auto.Engine(
model=MLPLayer(),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy(),
)
engine.evaluate(
valid_data=self.no_support_data_4, valid_sample_split=1
)
def test_predict(self):
with self.assertRaises(TypeError):
engine = auto.Engine(model=MLPLayer())
engine.predict(
test_data=self.no_support_data_3, test_sample_split=1
)
with self.assertRaises(TypeError):
engine = auto.Engine(model=MLPLayer())
engine.predict(
test_data=self.no_support_data_4, test_sample_split=1
)
def build_program(self):
main_prog = static.Program()
startup_prog = static.Program()
with static.program_guard(main_prog, startup_prog):
input = static.data(
name="input",
shape=[batch_size // 2, image_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size // 2, 1], dtype='int64'
)
mlp = MLPLayer()
loss = paddle.nn.CrossEntropyLoss()
predict = mlp(input)
loss_var = loss(predict, label)
return main_prog, startup_prog, input, label, loss_var
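    # prepare() should reject missing specs, raw variables passed as specs,
    # and conflicting spec/variable/program combinations.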
def test_prepare(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.prepare()
with self.assertRaises(AssertionError):
engine = auto.Engine(model=MLPLayer())
engine.prepare(mode="train")
with self.assertRaises(TypeError):
input = static.data(
name="input",
shape=[batch_size / 2, image_size],
dtype='float32',
)
label = static.data(
name="label", shape=[batch_size / 2, 1], dtype='int64'
)
engine = auto.Engine(model=MLPLayer())
engine.prepare(inputs_spec=input, labels_spec=label, mode="eval")
input_spec = static.InputSpec(
shape=[batch_size, image_size], dtype="float32", name="input"
)
label_spec = static.InputSpec(
shape=[batch_size, image_size], dtype="float32", name="input"
)
(
main_prog,
startup_prog,
input_var,
label_var,
loss_var,
) = self.build_program()
with self.assertRaises(TypeError):
engine = auto.Engine(loss=loss_var)
engine.prepare(
inputs=input_spec,
labels=label_spec,
main_program=main_prog,
startup_program=startup_prog,
mode="eval",
)
with self.assertRaises(AssertionError):
engine = auto.Engine(loss=loss_var)
engine.prepare(
inputs_spec=[input_spec, input_spec],
labels_spec=[label_spec, label_spec],
inputs=input_var,
labels=label_var,
main_program=main_prog,
startup_program=startup_prog,
mode="predict",
)
def test_cost(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.cost(mode="predict")
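# The same cost() misuse should also raise when dygraph mode is enabled.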
class TestEngineDynamicErrorRaise(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def tearDown(self):
paddle.enable_static()
def test_cost(self):
with self.assertRaises(ValueError):
engine = auto.Engine(model=MLPLayer())
engine.cost(mode="predict")
if __name__ == "__main__":
unittest.main()
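For contrast with the error cases above, here is a minimal sketch of a valid call sequence, assuming the spec-based cost API exercised by get_cost_by_spec() earlier in this PR; the shapes, dtypes, and mode are illustrative rather than taken from the test files:

    # Minimal sketch (assumed usage, not part of the test files).
    input_spec = static.InputSpec(
        shape=[batch_size, image_size], dtype="float32", name="input"
    )
    label_spec = static.InputSpec(shape=[batch_size, 1], dtype="int64", name="label")
    engine = auto.Engine(
        model=MLPLayer(),
        loss=paddle.nn.CrossEntropyLoss(),
        optimizer=paddle.optimizer.Adam(learning_rate=0.00001),
        metrics=paddle.metric.Accuracy(),
    )
    # With valid specs and a mode the engine can build, cost() is expected to
    # return an estimate instead of raising.
    engine.cost(inputs_spec=input_spec, labels_spec=label_spec, mode="train")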