提交 497395d3 编写于 作者: N niuliling123

modified for SD Layer

上级 7849d58d
...@@ -171,6 +171,11 @@ class Linear(Layer): ...@@ -171,6 +171,11 @@ class Linear(Layer):
self.name = name self.name = name
def forward(self, input): def forward(self, input):
with paddle.amp.auto_cast(custom_white_list={'elementwise_add','fused_gemm_epilogue'}, dtype='bfloat16'):
out = paddle.incubate.nn.functional.fused_linear(
x=input, weight=self.weight, bias=self.bias, name=self.name
)
return out
out = F.linear( out = F.linear(
x=input, weight=self.weight, bias=self.bias, name=self.name x=input, weight=self.weight, bias=self.bias, name=self.name
) )
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# TODO: define classes of convolutional neural network # TODO: define classes of convolutional neural network
import numpy as np import numpy as np
import paddle
from paddle import get_flags from paddle import get_flags
from ...device import ( from ...device import (
...@@ -704,29 +704,30 @@ class Conv2D(_ConvNd): ...@@ -704,29 +704,30 @@ class Conv2D(_ConvNd):
) )
def forward(self, x): def forward(self, x):
if self._padding_mode != 'zeros': with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}, level='O1', dtype='bfloat16'):
x = F.pad( if self._padding_mode != 'zeros':
x = F.pad(
x,
self._reversed_padding_repeated_twice,
mode=self._padding_mode,
data_format=self._data_format,
)
out = F.conv._conv_nd(
x, x,
self._reversed_padding_repeated_twice, self.weight,
mode=self._padding_mode, bias=self.bias,
stride=self._stride,
padding=self._updated_padding,
padding_algorithm=self._padding_algorithm,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format, data_format=self._data_format,
channel_dim=self._channel_dim,
op_type=self._op_type,
use_cudnn=self._use_cudnn,
) )
return out
out = F.conv._conv_nd(
x,
self.weight,
bias=self.bias,
stride=self._stride,
padding=self._updated_padding,
padding_algorithm=self._padding_algorithm,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format,
channel_dim=self._channel_dim,
op_type=self._op_type,
use_cudnn=self._use_cudnn,
)
return out
class Conv2DTranspose(_ConvNd): class Conv2DTranspose(_ConvNd):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册