"Flags_fused_linear_param_grad_add is {}, which is used to support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
"Flags_skip_mp_c_identity is {}, which is used to support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
# Flags_mp_aysnc_allreduce: support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
# Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
# Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
assertflagin[
"Flags_mp_aysnc_allreduce",
"Flags_fused_linear_param_grad_add",
"Flags_skip_mp_c_identity",
],"Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)"