未验证 提交 f9b90dda 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Add yaml and unittest for update_loss_scaling and check_finite_and_unscale (#46130)

* add amp yaml

* fix ci bugs
上级 8ff7df8f
...@@ -45,9 +45,10 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']} ...@@ -45,9 +45,10 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
tensor_type_map = { tensor_type_map = {
'const Tensor&': 'const MetaTensor&', 'const Tensor&': 'const MetaTensor&',
'const std::vector<Tensor>&': 'const std::vector<MetaTensor>&', 'const std::vector<Tensor>&':
'const std::vector<const MetaTensor*>&',
'Tensor': 'MetaTensor*', 'Tensor': 'MetaTensor*',
'std::vector<Tensor>': 'std::vector<MetaTensor>*', 'std::vector<Tensor>': 'std::vector<MetaTensor*>',
'const paddle::optional<Tensor>&': 'const MetaTensor&' 'const paddle::optional<Tensor>&': 'const MetaTensor&'
} }
......
...@@ -463,6 +463,18 @@ ...@@ -463,6 +463,18 @@
func : celu func : celu
backward : celu_grad backward : celu_grad
- op : check_finite_and_unscale_
args : (Tensor[] x, Tensor scale, Tensor input_found_infinite)
output : Tensor[](out){x.size()}, Tensor(output_found_infinite)
infer_meta :
func : CheckFiniteAndUnscaleInferMeta
param : [x, scale]
kernel :
func : check_finite_and_unscale
param : [x, scale]
data_type : x
inplace : (x -> out), (input_found_infinite -> output_found_infinite)
- op : class_center_sample - op : class_center_sample
args : (Tensor label, int num_classes, int num_samples, int ring_id, int rank, int nranks, bool fix_seed, int seed) args : (Tensor label, int num_classes, int num_samples, int ring_id, int rank, int nranks, bool fix_seed, int seed)
output : Tensor(remapped_label), Tensor(sampled_local_class_center) output : Tensor(remapped_label), Tensor(sampled_local_class_center)
...@@ -2763,6 +2775,17 @@ ...@@ -2763,6 +2775,17 @@
backend : place backend : place
data_type : dtype data_type : dtype
- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update)
output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps)
infer_meta :
func : UpdateLossScalingInferMeta
param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps]
kernel :
func : update_loss_scaling
data_type : x
inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps)
- op : unbind - op : unbind
args : (Tensor input, int axis) args : (Tensor input, int axis)
output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]} output : Tensor[] {axis<0 ? input.dims()[input.dims().size()+axis]:input.dims()[axis]}
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core from paddle.fluid import core
from paddle import _C_ops
__all__ = ['check_finite_and_unscale', 'update_loss_scaling'] __all__ = ['check_finite_and_unscale', 'update_loss_scaling']
...@@ -42,8 +43,13 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None): ...@@ -42,8 +43,13 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
'check_finite_and_unscale') 'check_finite_and_unscale')
helper = LayerHelper("check_finite_and_unscale", **locals()) helper = LayerHelper("check_finite_and_unscale", **locals())
found_inf = helper.create_variable_for_type_inference(dtype='bool') found_inf = helper.create_variable_for_type_inference(dtype='bool')
if in_dygraph_mode():
_C_ops.check_finite_and_unscale_(x, scale, found_inf)
return x, found_inf
inputs = {'X': x, 'Scale': scale} inputs = {'X': x, 'Scale': scale}
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
check_variable_and_dtype(float_status, "float_status", check_variable_and_dtype(float_status, "float_status",
...@@ -108,6 +114,13 @@ def update_loss_scaling(x, ...@@ -108,6 +114,13 @@ def update_loss_scaling(x,
else: else:
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
if in_dygraph_mode():
_C_ops.update_loss_scaling_(x, found_inf, prev_loss_scaling,
num_good_steps, num_bad_steps,
incr_every_n_steps, decr_every_n_nan_or_inf,
incr_ratio, decr_ratio, stop_update)
return x
helper = LayerHelper("update_loss_scaling", **locals()) helper = LayerHelper("update_loss_scaling", **locals())
inputs = { inputs = {
......
...@@ -16,12 +16,20 @@ import unittest ...@@ -16,12 +16,20 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest, skip_check_grad_ci from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
def check_finite_and_unscale_wrapper(x, scale):
_, found_inf = amp_nn.check_finite_and_unscale([x], scale)
return x, found_inf
class TestCheckFiniteAndUnscaleOp(OpTest): class TestCheckFiniteAndUnscaleOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "check_finite_and_unscale" self.op_type = "check_finite_and_unscale"
self.python_api = check_finite_and_unscale_wrapper
self.python_out_sig = ["out0", "FoundInfinite"]
self.init_dtype() self.init_dtype()
x = np.random.random((1024, 1024)).astype(self.dtype) x = np.random.random((1024, 1024)).astype(self.dtype)
scale = np.random.random((1)).astype(self.dtype) scale = np.random.random((1)).astype(self.dtype)
...@@ -36,7 +44,7 @@ class TestCheckFiniteAndUnscaleOp(OpTest): ...@@ -36,7 +44,7 @@ class TestCheckFiniteAndUnscaleOp(OpTest):
self.dtype = np.float32 self.dtype = np.float32
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class TestCheckFiniteAndUnscaleOpWithNan(OpTest): class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
...@@ -44,6 +52,8 @@ class TestCheckFiniteAndUnscaleOpWithNan(OpTest): ...@@ -44,6 +52,8 @@ class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
def setUp(self): def setUp(self):
self.op_type = "check_finite_and_unscale" self.op_type = "check_finite_and_unscale"
self.init_dtype() self.init_dtype()
self.python_api = check_finite_and_unscale_wrapper
self.python_out_sig = ["out0", "FoundInfinite"]
x = np.random.random((1024, 1024)).astype(self.dtype) x = np.random.random((1024, 1024)).astype(self.dtype)
x[128][128] = np.nan x[128][128] = np.nan
scale = np.random.random((1)).astype(self.dtype) scale = np.random.random((1)).astype(self.dtype)
...@@ -60,7 +70,7 @@ class TestCheckFiniteAndUnscaleOpWithNan(OpTest): ...@@ -60,7 +70,7 @@ class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
def test_check_output(self): def test_check_output(self):
# When input contains nan, do not check the output, # When input contains nan, do not check the output,
# since the output may be nondeterministic and will be discarded. # since the output may be nondeterministic and will be discarded.
self.check_output(no_check_set=['Out']) self.check_output(no_check_set=['Out'], check_eager=True)
class TestCheckFiniteAndUnscaleOpWithInf(OpTest): class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
...@@ -68,6 +78,8 @@ class TestCheckFiniteAndUnscaleOpWithInf(OpTest): ...@@ -68,6 +78,8 @@ class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
def setUp(self): def setUp(self):
self.op_type = "check_finite_and_unscale" self.op_type = "check_finite_and_unscale"
self.init_dtype() self.init_dtype()
self.python_api = check_finite_and_unscale_wrapper
self.python_out_sig = ["out0", "FoundInfinite"]
x = np.random.random((1024, 1024)).astype(self.dtype) x = np.random.random((1024, 1024)).astype(self.dtype)
x[128][128] = np.inf x[128][128] = np.inf
scale = np.random.random((1)).astype(self.dtype) scale = np.random.random((1)).astype(self.dtype)
...@@ -84,7 +96,7 @@ class TestCheckFiniteAndUnscaleOpWithInf(OpTest): ...@@ -84,7 +96,7 @@ class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
def test_check_output(self): def test_check_output(self):
# When input contains inf, do not check the output, # When input contains inf, do not check the output,
# since the output may be nondeterministic and will be discarded. # since the output may be nondeterministic and will be discarded.
self.check_output(no_check_set=['Out']) self.check_output(no_check_set=['Out'], check_eager=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -19,11 +19,32 @@ import paddle.fluid as fluid ...@@ -19,11 +19,32 @@ import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
def update_loss_scaling_wrapper(x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False):
amp_nn.update_loss_scaling([x], found_inf, prev_loss_scaling,
num_good_steps, num_bad_steps,
incr_every_n_steps, decr_every_n_nan_or_inf,
incr_ratio, decr_ratio, stop_update)
return x, prev_loss_scaling, num_good_steps, num_bad_steps
class TestUpdateLossScalingOp(OpTest): class TestUpdateLossScalingOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "update_loss_scaling" self.op_type = "update_loss_scaling"
self.init() self.init()
self.python_api = update_loss_scaling_wrapper
self.python_out_sig = [
"out0", "LossScaling", "OutGoodSteps", "OutBadSteps"
]
found_inf = np.array([False], dtype=np.bool_) found_inf = np.array([False], dtype=np.bool_)
x = np.random.random((1024, 1024)).astype(self.dtype) x = np.random.random((1024, 1024)).astype(self.dtype)
...@@ -59,7 +80,7 @@ class TestUpdateLossScalingOp(OpTest): ...@@ -59,7 +80,7 @@ class TestUpdateLossScalingOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output(no_check_set=['Out']) self.check_output(no_check_set=['Out'], check_eager=True)
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
...@@ -67,6 +88,10 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): ...@@ -67,6 +88,10 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
def setUp(self): def setUp(self):
self.op_type = "update_loss_scaling" self.op_type = "update_loss_scaling"
self.init() self.init()
self.python_api = update_loss_scaling_wrapper
self.python_out_sig = [
"out0", "LossScaling", "OutGoodSteps", "OutBadSteps"
]
found_inf = np.array([True], dtype=np.bool_) found_inf = np.array([True], dtype=np.bool_)
x = np.random.random((1024, 1024)).astype(self.dtype) x = np.random.random((1024, 1024)).astype(self.dtype)
i = np.random.randint(0, 1024, 1) i = np.random.randint(0, 1024, 1)
...@@ -90,7 +115,7 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): ...@@ -90,7 +115,7 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class TestUpdateLossScalingLayer(unittest.TestCase): class TestUpdateLossScalingLayer(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册