diff --git a/paddleslim/quant/observers/channel_wise.py b/paddleslim/quant/observers/channel_wise.py index 7d270cdd9a7677b0f5e09a3647b0ae23e6b96572..9962af83573e25953af0e24fad5b1eaaa328393b 100644 --- a/paddleslim/quant/observers/channel_wise.py +++ b/paddleslim/quant/observers/channel_wise.py @@ -19,7 +19,12 @@ from .mse import MSEObserverLayer from .uniform import UniformObserver from paddle.quantization.factory import ObserverFactory -CHANNEL_AXIS: Dict[type, int] = {paddle.nn.Conv2D: 0, paddle.nn.Linear: 1} +CHANNEL_AXIS: Dict[type, int] = { + paddle.nn.Conv2D: 0, + paddle.nn.Linear: 1, + paddle.distributed.fleet.meta_parallel.ColumnParallelLinear: 1, + paddle.distributed.fleet.meta_parallel.RowParallelLinear: 1 +} class ChannelWiseObserver(UniformObserver):