Unverified commit a12c806f, authored by Weilong Wu, committed by GitHub

[Eager, Performance optimization] Cast, uniform_random, assign_value use final state (#45637)

* [Eager] Cast, uniform_random, assign_value use final state

* fix mistake

* fix mistake

* fix CI errors

* add dygraph test for uniform

* add fp16 test case for bilinear_initializer
Parent 62033f25
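Every hunk below applies the same three-way dispatch used throughout Paddle's eager-mode migration: call the final-state `_C_ops` kernel when the new eager dygraph is active, fall back to `_legacy_C_ops` under legacy dygraph, and leave the static-graph `append_op` path untouched. A minimal sketch of the idiom, assuming the private `_C_ops`/`_legacy_C_ops` modules of the Paddle version this commit targets (the helper name `cast_compat` is made up for illustration):

    from paddle import _C_ops, _legacy_C_ops
    from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

    def cast_compat(x, dtype):
        # Hypothetical helper showing the dispatch idiom used in this commit.
        if in_dygraph_mode():
            # Final-state (eager) op: positional args, no attribute strings.
            return _C_ops.cast(x, dtype)
        if _in_legacy_dygraph():
            # Legacy dygraph op: attributes passed as name/value string pairs.
            return _legacy_C_ops.cast(x, 'in_dtype', x.dtype,
                                      'out_dtype', dtype)
        # The static-graph path would go through block.append_op(...) instead.
        raise RuntimeError("static graph path omitted in this sketch")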
@@ -30,7 +30,7 @@ from paddle import _C_ops, _legacy_C_ops
 from paddle.fluid import core
 from paddle.fluid.data_feeder import (check_dtype, check_type,
                                       check_variable_and_dtype, convert_dtype)
-from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
                                  elementwise_mul, elementwise_sub, nn, ops,
                                  tensor)
@@ -221,8 +221,11 @@ class Distribution(object):
                 warnings.warn(
                     "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                 )
-                return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
-                                          'out_dtype', param.dtype)
+                if in_dygraph_mode():
+                    return _C_ops.cast(value, param.dtype)
+                if _in_legacy_dygraph():
+                    return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
+                                              'out_dtype', param.dtype)
             return value

         check_variable_and_dtype(value, 'value', ['float32', 'float64'],
...
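The behavior of this dtype check is unchanged: a mismatched `value` dtype is still warned about and cast to the parameter dtype; only the kernel performing the cast differs by mode. A quick way to exercise this path, sketched with the public `paddle.distribution.Uniform` API of this Paddle version:

    import paddle
    from paddle.distribution import Uniform

    u = Uniform(paddle.to_tensor(0.), paddle.to_tensor(1.))  # float32 params
    v = paddle.to_tensor([0.25, 0.5], dtype='float64')       # mismatched dtype
    # Emits the warning above, then casts v via the (now final-state) cast op.
    print(u.log_prob(v))                                     # [0., 0.]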
@@ -21,7 +21,7 @@ from paddle.distribution import distribution
 from paddle.fluid import core
 from paddle.fluid.data_feeder import (check_dtype, check_type,
                                       check_variable_and_dtype, convert_dtype)
-from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
                                  elementwise_mul, elementwise_sub, nn, ops,
                                  tensor)
@@ -191,11 +191,17 @@ class Uniform(distribution.Distribution):
             lb_bool = self.low < value
             ub_bool = value < self.high

-            lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
-                                    'out_dtype', value.dtype)
-            ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
-                                    'out_dtype', value.dtype)
-            return nn.log(lb * ub) - nn.log(self.high - self.low)
+            if in_dygraph_mode():
+                lb = _C_ops.cast(lb_bool, value.dtype)
+                ub = _C_ops.cast(ub_bool, value.dtype)
+                return nn.log(lb * ub) - nn.log(self.high - self.low)
+
+            if _in_legacy_dygraph():
+                lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
+                                        'out_dtype', value.dtype)
+                ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
+                                        'out_dtype', value.dtype)
+                return nn.log(lb * ub) - nn.log(self.high - self.low)

         name = self.name + '_log_prob'
         lb_bool = self.low < value
@@ -221,11 +227,17 @@ class Uniform(distribution.Distribution):
             lb_bool = self.low < value
             ub_bool = value < self.high

-            lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
-                                    'out_dtype', value.dtype)
-            ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
-                                    'out_dtype', value.dtype)
-            return (lb * ub) / (self.high - self.low)
+            if in_dygraph_mode():
+                lb = _C_ops.cast(lb_bool, value.dtype)
+                ub = _C_ops.cast(ub_bool, value.dtype)
+                return (lb * ub) / (self.high - self.low)
+
+            if _in_legacy_dygraph():
+                lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
+                                        'out_dtype', value.dtype)
+                ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
+                                        'out_dtype', value.dtype)
+                return (lb * ub) / (self.high - self.low)

         name = self.name + '_probs'
         lb_bool = self.low < value
...
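For reference, the math these branches implement: `lb` and `ub` are the indicators 1(low < v) and 1(v < high) after the boolean-to-float cast, so inside the support `log_prob` reduces to `-log(high - low)` and `probs` to `1 / (high - low)`; outside the support `lb * ub` is 0 and the log yields `-inf`. A plain-numpy check of the same formula (numpy warns on `log(0)` but still returns `-inf`):

    import numpy as np

    low, high = 0.0, 2.0
    for v in (0.5, 2.5):
        inside = float(low < v) * float(v < high)      # lb * ub after the cast
        print(v, np.log(inside) - np.log(high - low))  # -0.693..., then -inf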
@@ -278,13 +278,23 @@ class UniformInitializer(Initializer):
             out_var = var

         if framework._non_static_mode():
-            out_var = _legacy_C_ops.uniform_random(
-                'shape', var.shape, 'min', self._low, 'max', self._high, 'seed',
-                self._seed, 'dtype', out_dtype, 'diag_num', self._diag_num,
-                'diag_step', self._diag_step, 'diag_val', self._diag_val)
+            if in_dygraph_mode():
+                out_var = _C_ops.uniform_random(var.shape, out_dtype, self._low,
+                                                self._high, self._seed,
+                                                _current_expected_place())
+            elif _in_legacy_dygraph():
+                out_var = _legacy_C_ops.uniform_random(
+                    'shape', var.shape, 'min', self._low, 'max', self._high,
+                    'seed', self._seed, 'dtype', out_dtype, 'diag_num',
+                    self._diag_num, 'diag_step', self._diag_step, 'diag_val',
+                    self._diag_val)
             if var.dtype == VarDesc.VarType.FP16:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                 var_tmp._share_underline_tensor_to(var)
             else:
                 out_var._share_underline_tensor_to(var)
@@ -828,8 +838,12 @@ class MSRAInitializer(Initializer):
             if var.dtype == VarDesc.VarType.FP16 or (
                     var.dtype == VarDesc.VarType.BF16 and not self._uniform):
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                 var_tmp._share_underline_tensor_to(var)
             else:
                 out_var._share_underline_tensor_to(var)
@@ -989,14 +1003,23 @@ class BilinearInitializer(Initializer):
             raise ValueError("The size of input is too big. ")

         if framework._non_static_mode():
-            _legacy_C_ops.assign_value(out_var, 'shape', list(shape), 'dtype',
-                                       out_dtype, value_name, values)
+            if in_dygraph_mode():
+                _C_ops.assign_value_(out_var, list(shape), out_dtype, values,
+                                     _current_expected_place())
+            elif _in_legacy_dygraph():
+                _legacy_C_ops.assign_value(out_var, 'shape', list(shape),
+                                           'dtype', out_dtype, value_name,
+                                           values)
             if var.dtype in [
                     VarDesc.VarType.FP16, VarDesc.VarType.BF16,
                     VarDesc.VarType.FP64
             ]:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                 var_tmp._share_underline_tensor_to(var)
             else:
                 out_var._share_underline_tensor_to(var)
@@ -1096,12 +1119,21 @@ class NumpyArrayInitializer(Initializer):
                 "saving it to file and 'load_op' to load it")

         if framework._non_static_mode():
-            _legacy_C_ops.assign_value(out_var, 'shape',
-                                       list(self._value.shape), 'dtype',
-                                       out_dtype, value_name, values)
+            if in_dygraph_mode():
+                _C_ops.assign_value_(out_var,
+                                     list(self._value.shape), out_dtype, values,
+                                     _current_expected_place())
+            elif _in_legacy_dygraph():
+                _legacy_C_ops.assign_value(out_var, 'shape',
+                                           list(self._value.shape), 'dtype',
+                                           out_dtype, value_name, values)
             if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                 var_tmp._share_underline_tensor_to(var)
             else:
                 out_var._share_underline_tensor_to(var)
...
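In these initializer hunks the FP16/BF16 handling explains the paired cast: the random/assign kernels run on a float32 intermediate (`out_var` with `out_dtype` set to FP32 earlier in these methods) and the result is cast back to the parameter's dtype, so the cast op sits on the hot path of half-precision initialization as well. A sketch of how `UniformInitializer` is reached from the public API, assuming `paddle.nn.initializer.Uniform` routes to this class as it does in this codebase:

    import paddle

    paddle.set_default_dtype('float16')
    w_attr = paddle.ParamAttr(
        initializer=paddle.nn.initializer.Uniform(-0.1, 0.1))
    linear = paddle.nn.Linear(4, 4, weight_attr=w_attr)
    # Under eager mode the weight was filled by _C_ops.uniform_random at
    # float32 and then cast to float16 by _C_ops.cast, per the hunk above.
    print(linear.weight.dtype)  # paddle.float16
    paddle.set_default_dtype('float32')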
@@ -20,6 +20,7 @@ import paddle
 from paddle import fluid
 from paddle.distribution import *
 from paddle.fluid import layers
+from paddle.fluid.framework import _test_eager_guard
 from test_distribution import DistributionNumpy
@@ -114,17 +115,6 @@ class UniformTest(unittest.TestCase):
                                        atol=tolerance)
         np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)

-    def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
-        paddle.disable_static(self.place)
-        uniform = Uniform(self.dynamic_low, self.dynamic_high)
-        sample = uniform.sample([sample_shape]).numpy()
-        entropy = uniform.entropy().numpy()
-        log_prob = uniform.log_prob(self.dynamic_values).numpy()
-        probs = uniform.probs(self.dynamic_values).numpy()
-        fetch_list = [sample, entropy, log_prob, probs]
-        self.compare_with_numpy(fetch_list)
-
     def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6):
         paddle.enable_static()
         with fluid.program_guard(self.test_program):
@@ -148,6 +138,24 @@ class UniformTest(unittest.TestCase):
         self.compare_with_numpy(fetch_list)

+    def func_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
+        paddle.disable_static()
+        uniform = Uniform(self.dynamic_low, self.dynamic_high)
+        sample = uniform.sample([sample_shape]).numpy()
+        entropy = uniform.entropy().numpy()
+        log_prob = uniform.log_prob(self.dynamic_values).numpy()
+        probs = uniform.probs(self.dynamic_values).numpy()
+        fetch_list = [sample, entropy, log_prob, probs]
+        self.compare_with_numpy(fetch_list)
+
+    def test_uniform_distribution_dygraph(self):
+        with _test_eager_guard():
+            self.setUp()
+            self.func_uniform_distribution_dygraph()
+        self.setUp()
+        self.func_uniform_distribution_dygraph()
+
 class UniformTest2(UniformTest):
...
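The test refactor follows the repo's eager-coverage convention: the body moves into a `func_*` method and the `test_*` entry point runs it twice, once inside `_test_eager_guard()` (final-state eager) and once outside (legacy dygraph). `setUp()` is re-run before each pass because fixture tensors created in one mode cannot be reused in the other. The same idiom, reduced to a skeleton with illustrative class and method names:

    import unittest
    from paddle.fluid.framework import _test_eager_guard

    class SomeEagerCoveredTest(unittest.TestCase):

        def func_case(self):
            pass  # mode-agnostic test body goes here

        def test_case(self):
            with _test_eager_guard():  # pass 1: final-state eager mode
                self.setUp()
                self.func_case()
            self.setUp()               # rebuild fixtures for pass 2
            self.func_case()           # pass 2: legacy dygraph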
@@ -23,6 +23,7 @@ import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import paddle.fluid.initializer as initializer
 from paddle.fluid.core import VarDesc
+from paddle.regularizer import L2Decay

 DELTA = 0.00001
@@ -563,6 +564,55 @@ class TestBilinearInitializer(unittest.TestCase):
         self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')

+class TestBilinearInitializerDygraphAPI(unittest.TestCase):
+
+    def func_test_case(self):
+        factor = 2
+        C = 2
+        B = 8
+        H = W = 32
+        w_attr = paddle.ParamAttr(learning_rate=0.,
+                                  regularizer=L2Decay(0.),
+                                  initializer=initializer.BilinearInitializer())
+        data = paddle.rand([B, 3, H, W], dtype='float32')
+        conv_up = paddle.nn.Conv2DTranspose(3,
+                                            out_channels=C,
+                                            kernel_size=2 * factor - factor % 2,
+                                            padding=int(
+                                                math.ceil((factor - 1) / 2.)),
+                                            stride=factor,
+                                            weight_attr=w_attr,
+                                            bias_attr=False)
+        x = conv_up(data)
+        return x
+
+    def func_test_case_fp16(self):
+        paddle.set_default_dtype("float16")
+        paddle.seed(1234)
+        w_attr = paddle.ParamAttr(learning_rate=0.,
+                                  regularizer=L2Decay(0.),
+                                  initializer=initializer.BilinearInitializer())
+        conv2d = paddle.nn.Conv2D(1, 2, 3, weight_attr=w_attr)
+        paddle.set_default_dtype("float32")
+        return conv2d.weight
+
+    def test_bilinear_initializer(self):
+        paddle.disable_static()
+        with framework._test_eager_guard():
+            eager_x = self.func_test_case()
+        legacy_x = self.func_test_case()
+        self.assertEqual(eager_x.numpy().all(), legacy_x.numpy().all())
+        paddle.enable_static()
+
+    def test_bilinear_initializer_fp16(self):
+        paddle.disable_static()
+        with framework._test_eager_guard():
+            eager_x = self.func_test_case_fp16()
+        legacy_x = self.func_test_case_fp16()
+        self.assertEqual(eager_x.numpy().all(), legacy_x.numpy().all())
+        paddle.enable_static()
+
 class TestNumpyArrayInitializer(unittest.TestCase):

     def test_numpy_array_initializer(self, dtype="float32"):
...
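For context on what the new test initializes: `BilinearInitializer` fills a transposed-conv kernel with bilinear-upsampling weights, which is why `func_test_case` uses `kernel_size = 2 * factor - factor % 2` and `padding = ceil((factor - 1) / 2)`. A numpy sketch of the classic 1-D bilinear weight profile for `factor = 2`, following the standard FCN formula and shown for illustration rather than as Paddle's exact implementation:

    import numpy as np

    factor = 2
    size = 2 * factor - factor % 2    # kernel_size in the test above: 4
    center = (2 * factor - 1 - factor % 2) / (2. * factor)
    w = 1 - np.abs(np.arange(size) / float(factor) - center)
    print(w)                          # [0.25 0.75 0.75 0.25]
    # The 2-D kernel is the outer product w[:, None] * w[None, :].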
@@ -198,9 +198,9 @@ class Dirac(Initializer):
         if framework.in_dygraph_mode():
             with fluid.dygraph.no_grad():
                 tmp_tensor = framework._varbase_creator()
-                _legacy_C_ops.assign_value(tmp_tensor, 'shape', [len(idx_list)],
-                                           'dtype', VarDesc.VarType.INT64,
-                                           'int64_values', idx_list)
+                _C_ops.assign_value_(tmp_tensor, [len(idx_list)],
+                                     VarDesc.VarType.INT64, idx_list,
+                                     _current_expected_place())
                 tmp_tensor._share_underline_tensor_to(index_tensor)
         else:
             block.append_op(type='assign_value',
@@ -220,10 +220,10 @@ class Dirac(Initializer):
         if framework.in_dygraph_mode():
             with fluid.dygraph.no_grad():
                 tmp_tensor = framework._varbase_creator()
-                _legacy_C_ops.assign_value(tmp_tensor, 'shape',
-                                           [len(value_list)], 'dtype',
-                                           VarDesc.VarType.FP32, 'fp32_values',
-                                           value_list)
+                _C_ops.assign_value_(tmp_tensor, [len(value_list)],
+                                     VarDesc.VarType.FP32, value_list,
+                                     _current_expected_place())
                 tmp_tensor._share_underline_tensor_to(value_tensor)
         else:
             block.append_op(type='assign_value',
...
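Here the replacement is unconditional because the surrounding guard is already `framework.in_dygraph_mode()`, i.e. the final-state eager check, so no `_in_legacy_dygraph()` fallback is needed; note the trailing underscore in `_C_ops.assign_value_`, marking an in-place write into `tmp_tensor`. A usage sketch through the public wrapper, assuming `paddle.nn.initializer.Dirac` of this Paddle version:

    import paddle

    # Dirac weights make a conv layer the identity map over channels.
    w_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
    conv = paddle.nn.Conv2D(3, 3, kernel_size=3, padding=1,
                            weight_attr=w_attr, bias_attr=False)
    x = paddle.rand([1, 3, 8, 8])
    print(paddle.allclose(conv(x), x))  # True: output reproduces the input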