未验证 提交 4389a804 编写于 作者: C Chang Xu 提交者: GitHub

Add Parallel Linear in ChannelWise Quant (#1739)

上级 8ff2de56
...@@ -19,7 +19,12 @@ from .mse import MSEObserverLayer ...@@ -19,7 +19,12 @@ from .mse import MSEObserverLayer
from .uniform import UniformObserver from .uniform import UniformObserver
from paddle.quantization.factory import ObserverFactory from paddle.quantization.factory import ObserverFactory
CHANNEL_AXIS: Dict[type, int] = {paddle.nn.Conv2D: 0, paddle.nn.Linear: 1} CHANNEL_AXIS: Dict[type, int] = {
paddle.nn.Conv2D: 0,
paddle.nn.Linear: 1,
paddle.distributed.fleet.meta_parallel.ColumnParallelLinear: 1,
paddle.distributed.fleet.meta_parallel.RowParallelLinear: 1
}
class ChannelWiseObserver(UniformObserver): class ChannelWiseObserver(UniformObserver):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册