Unverified · Commit b96dada4 authored by huangxu96, committed by GitHub

add static.amp into setup.py.in (#29621)

* add static.amp into setup.py.in

* add unittest for api
Parent 1e9127f6
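Not part of the diff: a minimal sketch of the new paddle.static.amp.decorate entry point that the updated tests below exercise. The toy fc network, the variable names, and the Lamb/loss-scaling settings are illustrative assumptions mirroring the test code, not a definitive recipe.

import paddle
import paddle.fluid as fluid
from paddle.static.amp import decorate

paddle.enable_static()

# Toy static-graph network; layers and names are illustrative only.
x = fluid.data(name='x', shape=[None, 32], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='int64')
logits = fluid.layers.fc(input=x, size=10)
loss = fluid.layers.mean(fluid.layers.softmax_with_cross_entropy(logits, y))

optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
    custom_black_varnames={loss.name})
# decorate() wraps the optimizer so training runs in mixed precision
# with dynamic loss scaling, matching the arguments used in the tests.
mp_optimizer = decorate(
    optimizer=optimizer,
    amp_lists=amp_lists,
    init_loss_scaling=8.0,
    use_dynamic_loss_scaling=True)
mp_optimizer.minimize(loss)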

@@ -24,6 +24,7 @@ import unittest
 import os
 import copy
 import numpy as np
+from paddle.static.amp import decorate
 paddle.enable_static()

@@ -138,7 +139,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
         amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
             custom_black_varnames={"loss", "conv2d_0.w_0"})
-        mp_optimizer = fluid.contrib.mixed_precision.decorate(
+        mp_optimizer = decorate(
             optimizer=optimizer,
             amp_lists=amp_lists,
             init_loss_scaling=8.0,

@@ -442,7 +443,7 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
         optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
         amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
             custom_black_varnames={"loss", "conv2d_0.w_0"})
-        mp_optimizer = fluid.contrib.mixed_precision.decorate(
+        mp_optimizer = decorate(
             optimizer=optimizer,
             amp_lists=amp_lists,
             init_loss_scaling=8.0,

@@ -19,8 +19,8 @@ import paddle.fluid as fluid
 import contextlib
 import unittest
 import numpy as np
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
+from paddle.static.amp import cast_model_to_fp16
+from paddle.static.amp import cast_parameters_to_fp16
 paddle.enable_static()
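
Also not part of the diff: a minimal sketch of the two casting utilities under their new import path. The trivial program below is an assumption for illustration, and only the minimal positional arguments are passed.

import paddle
import paddle.fluid as fluid
from paddle.static.amp import cast_model_to_fp16, cast_parameters_to_fp16

paddle.enable_static()

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 32], dtype='float32')
    out = fluid.layers.fc(input=x, size=10)

# Rewrite the program so supported ops compute in float16.
cast_model_to_fp16(main_prog)

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() \
    else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
# Cast the initialized parameters to float16 so they match the
# rewritten program before running it.
cast_parameters_to_fp16(place, main_prog)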

@@ -24,6 +24,7 @@ __all__ = [
 ]
 from . import nn
+from . import amp
 from .io import save_inference_model #DEFINE_ALIAS
 from .io import load_inference_model #DEFINE_ALIAS
 from ..fluid import Scope #DEFINE_ALIAS

@@ -210,6 +210,7 @@ packages=['paddle',
           'paddle.metric',
           'paddle.static',
           'paddle.static.nn',
+          'paddle.static.amp',
           'paddle.tensor',
           'paddle.onnx',
          ]
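
With 'paddle.static.amp' added to the packages list in setup.py.in, the subpackage ships in built wheels. A quick smoke test against an installed build (hypothetical session):

# Confirms the subpackage is importable and exposes the APIs the
# tests above rely on.
import paddle.static.amp
print(paddle.static.amp.decorate)
print(paddle.static.amp.cast_model_to_fp16)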