如何使用半精度训练？设置 dtype='float16' 会报错
Created by: Exception-star
- 问题描述：如何使用半精度（float16）训练？将卷积层的 dtype 设为 'float16' 后会报错，代码如下：
self._conv = fluid.dygraph.Conv2D(self.full_name(),
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
dilation=dilation,
act=None,
dtype='float16',
bias_attr=bias_attr,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.MSRAInitializer(False))
)
def forward(self, inputs):
# inputs.dtype = 'float16'
x = self._conv(inputs)
