Unverified commit b96dada4, authored by huangxu96 and committed by GitHub

add static.amp into setup.py.in (#29621)

* add static.amp into setup.py.in

* add unittest for api
Parent 1e9127f6
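For context, here is a minimal sketch (not part of the commit) of how the newly packaged `paddle.static.amp.decorate` import path is exercised by the updated tests below; the network, variable names, and hyperparameters are illustrative only.

```python
# Illustrative sketch, assuming Paddle at this commit; static graph mode.
import paddle
import paddle.fluid as fluid
from paddle.static.amp import decorate  # import path packaged by this change

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    image = fluid.data(name='image', shape=[None, 784], dtype='float32')
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    logits = fluid.layers.fc(input=image, size=10)
    loss = fluid.layers.mean(
        fluid.layers.softmax_with_cross_entropy(logits=logits, label=label))

    optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
    amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
        custom_black_varnames={"loss"})  # varname set is illustrative
    # decorate() wraps the optimizer with fp16 casting and loss scaling.
    mp_optimizer = decorate(
        optimizer=optimizer,
        amp_lists=amp_lists,
        init_loss_scaling=8.0,
        use_dynamic_loss_scaling=True)
    mp_optimizer.minimize(loss)
```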
@@ -24,6 +24,7 @@ import unittest
 import os
 import copy
 import numpy as np
+from paddle.static.amp import decorate
 paddle.enable_static()
@@ -138,7 +139,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
     amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
         custom_black_varnames={"loss", "conv2d_0.w_0"})
-    mp_optimizer = fluid.contrib.mixed_precision.decorate(
+    mp_optimizer = decorate(
         optimizer=optimizer,
         amp_lists=amp_lists,
         init_loss_scaling=8.0,
@@ -442,7 +443,7 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
         optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
         amp_lists = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
             custom_black_varnames={"loss", "conv2d_0.w_0"})
-        mp_optimizer = fluid.contrib.mixed_precision.decorate(
+        mp_optimizer = decorate(
             optimizer=optimizer,
             amp_lists=amp_lists,
             init_loss_scaling=8.0,
......
@@ -19,8 +19,8 @@ import paddle.fluid as fluid
 import contextlib
 import unittest
 import numpy as np
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
+from paddle.static.amp import cast_model_to_fp16
+from paddle.static.amp import cast_parameters_to_fp16
 paddle.enable_static()
......
@@ -24,6 +24,7 @@ __all__ = [
 ]
 from . import nn
+from . import amp
 from .io import save_inference_model  #DEFINE_ALIAS
 from .io import load_inference_model  #DEFINE_ALIAS
 from ..fluid import Scope  #DEFINE_ALIAS
......
@@ -210,6 +210,7 @@ packages=['paddle',
           'paddle.metric',
           'paddle.static',
           'paddle.static.nn',
+          'paddle.static.amp',
           'paddle.tensor',
           'paddle.onnx',
           ]
......
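With 'paddle.static.amp' added to the package list above, the module ships with the wheel. A quick, illustrative check (not part of the commit) that the symbols used by the updated tests are reachable from the packaged path:

```python
# Illustrative check only: these are the symbols the updated tests import.
import paddle.static.amp as amp

print(amp.decorate)                 # optimizer decorator for mixed precision
print(amp.cast_model_to_fp16)       # casts a program's ops/vars to fp16
print(amp.cast_parameters_to_fp16)  # casts parameter values to fp16
```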