diff --git a/README.md b/README.md index deca21129946ac1748fcb7e7996ea5663d920cab..6b817214d4c1b9af50903a5ea1a83239ab449f0f 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,6 @@ PLSC具备以下特点: ### 高级功能 -* [混合精度训练] +* [混合精度训练](docs/mixed_precision.md) * [分布式参数转换] * [Base64格式图像预处理] diff --git a/docs/mixed_precision.md b/docs/mixed_precision.md new file mode 100755 index 0000000000000000000000000000000000000000..7b3f6d08c7b4635818507c4ed6aba6c1f1ce347b --- /dev/null +++ b/docs/mixed_precision.md @@ -0,0 +1,28 @@ +# 混合精度训练 + +PLSC支持混合精度训练。使用混合精度训练可以提升训练的速度,同时减少训练使用的内存。 + +可以通过下面的代码设置开启混合精度训练: + +```python +from __future__ import print_function +import plsc.entry as entry + +def main(): + ins = entry.Entry() + ins.set_mixed_precision(True, 1.0) + ins.train() + +if __name__ == "__main__": + main() +``` +其中,`set_mixed_precision`函数介绍如下: + +| API | 描述 | 参数说明 | +| :------------------- | :--------------------| :---------------------- | +| set_mixed_precision(use_fp16, loss_scaling) | 设置混合精度训练 | `use_fp16`为是否开启混合精度训练,默认为False;`loss_scaling`为初始的损失缩放值,默认为1.0| + +- `use_fp16`:bool类型,当想要开启混合精度训练时,可将此参数设为True即可。 +- `loss_scaling`:float类型,为初始的损失缩放值,这个值有可能会影响混合精度训练的精度,建议设为默认值1.0。 + +为了提高混合精度训练的稳定性和精度,默认开启了动态损失缩放机制。更多关于混合精度训练的介绍可参考:[混合精度训练](https://arxiv.org/abs/1710.03740)