diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 1a0cb322cf499458323da2ab570d4a0f09e31a78..61ccbbcc448b9e3101537640a2ad05dc2d7e216c 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -450,6 +450,14 @@ class ColumnParallelLinear(paddle.nn.Layer):
             and mp_configs.mp_async_allreduce
             and mp_configs.mp_fused_linear_param_grad_add
         )
+        if (
+            self.mp_async_allreduce
+            or self.mp_skip_c_identity
+            or self.mp_fused_linear_param_grad_add
+        ):
+            assert (
+                paddle.in_dynamic_mode()
+            ), "mp_async_allreduce, mp_skip_c_identity and mp_fused_linear_param_grad_add are only available under dygraph mode"
         if self.fuse_matmul_bias:
             if not is_fused_matmul_bias_supported():
                 raise NotImplementedError(
@@ -614,6 +622,14 @@ class RowParallelLinear(paddle.nn.Layer):
             and mp_configs.mp_async_allreduce
             and mp_configs.mp_fused_linear_param_grad_add
         )
+        if (
+            self.mp_async_allreduce
+            or self.mp_skip_c_identity
+            or self.mp_fused_linear_param_grad_add
+        ):
+            assert (
+                paddle.in_dynamic_mode()
+            ), "mp_async_allreduce, mp_skip_c_identity and mp_fused_linear_param_grad_add are only available under dygraph mode"
         assert in_features % self.world_size == 0, (
             "Number of row of the weight for linear ({}) must be"
             " divisible by model parallel size ({})".format(
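
For reference, below is a minimal sketch of the behavior this guard adds. The flag values and the `check_dygraph_only` helper are illustrative assumptions; in `ColumnParallelLinear` / `RowParallelLinear` the flags are derived from the `mp_configs` passed to the layer, and the check runs in `__init__`.

```python
import paddle

# Illustrative flag values (assumptions for this sketch); in the real layers
# they come from the hybrid-parallel mp_configs.
mp_async_allreduce = True
mp_skip_c_identity = False
mp_fused_linear_param_grad_add = False


def check_dygraph_only():
    # Mirrors the assertion added in this diff: these optimizations rely on
    # dygraph-only mechanics, so they are rejected under static graph mode.
    if mp_async_allreduce or mp_skip_c_identity or mp_fused_linear_param_grad_add:
        assert paddle.in_dynamic_mode(), (
            "mp_async_allreduce, mp_skip_c_identity and "
            "mp_fused_linear_param_grad_add are only available under dygraph mode"
        )


check_dygraph_only()  # passes: Paddle starts in dygraph mode by default

paddle.enable_static()
try:
    check_dygraph_only()  # now raises, surfacing the misconfiguration early
except AssertionError as e:
    print("rejected:", e)
finally:
    paddle.disable_static()
```

The point of asserting at construction time is to fail fast with a clear message instead of letting an unsupported configuration reach the parallel matmul/backward paths under static graph mode.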