diff --git a/python/paddle/distribution/distribution.py b/python/paddle/distribution/distribution.py
index 937b0171722fd72a8f7cbb43c9758050e71426da..8c5843521b0173736f824b1f900a5ab1462452ef 100644
--- a/python/paddle/distribution/distribution.py
+++ b/python/paddle/distribution/distribution.py
@@ -30,7 +30,7 @@ from paddle import _C_ops, _legacy_C_ops
 from paddle.fluid import core
 from paddle.fluid.data_feeder import (check_dtype, check_type,
                                       check_variable_and_dtype, convert_dtype)
-from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
                                  elementwise_mul, elementwise_sub, nn, ops,
                                  tensor)
@@ -221,8 +221,11 @@ class Distribution(object):
                 warnings.warn(
                     "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                 )
-                return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
-                                          'out_dtype', param.dtype)
+                if in_dygraph_mode():
+                    return _C_ops.cast(value, param.dtype)
+                if _in_legacy_dygraph():
+                    return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
+                                              'out_dtype', param.dtype)
             return value
 
         check_variable_and_dtype(value, 'value', ['float32', 'float64'],
diff --git a/python/paddle/distribution/uniform.py b/python/paddle/distribution/uniform.py
index aa7f0bde4c830b1f731850f311838703b8397fc0..7c085da3156866a5660c358bfd7aea2531259ca9 100644
--- a/python/paddle/distribution/uniform.py
+++ b/python/paddle/distribution/uniform.py
@@ -21,7 +21,7 @@ from paddle.distribution import distribution
 from paddle.fluid import core
 from paddle.fluid.data_feeder import (check_dtype, check_type,
                                       check_variable_and_dtype, convert_dtype)
-from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
+from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
                                  elementwise_mul, elementwise_sub, nn, ops,
                                  tensor)
@@ -191,11 +191,17 @@ class Uniform(distribution.Distribution):
             lb_bool = self.low < value
             ub_bool = value < self.high
 
-            lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
-                                    'out_dtype', value.dtype)
-            ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
-                                    'out_dtype', value.dtype)
-            return nn.log(lb * ub) - nn.log(self.high - self.low)
+            if in_dygraph_mode():
+                lb = _C_ops.cast(lb_bool, value.dtype)
+                ub = _C_ops.cast(ub_bool, value.dtype)
+                return nn.log(lb * ub) - nn.log(self.high - self.low)
+
+            if _in_legacy_dygraph():
+                lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
+                                        'out_dtype', value.dtype)
+                ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
+                                        'out_dtype', value.dtype)
+                return nn.log(lb * ub) - nn.log(self.high - self.low)
 
         name = self.name + '_log_prob'
         lb_bool = self.low < value
@@ -221,11 +227,17 @@ class Uniform(distribution.Distribution):
             lb_bool = self.low < value
             ub_bool = value < self.high
 
-            lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
-                                    'out_dtype', value.dtype)
-            ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
-                                    'out_dtype', value.dtype)
-            return (lb * ub) / (self.high - self.low)
+            if in_dygraph_mode():
+                lb = _C_ops.cast(lb_bool, value.dtype)
+                ub = _C_ops.cast(ub_bool, value.dtype)
+                return (lb * ub) / (self.high - self.low)
+
+            if _in_legacy_dygraph():
+                lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype,
+                                        'out_dtype', value.dtype)
+                ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype,
+                                        'out_dtype', value.dtype)
+                return (lb * ub) / (self.high - self.low)
 
         name = self.name + '_probs'
         lb_bool = self.low < value
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 2dd2a74000bfaa41a0559119a5fd610e8957ddcc..215f6a5330d25e210ff38e0a5a915b707f0ce825 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -278,13 +278,23 @@ class UniformInitializer(Initializer):
             out_var = var
 
         if framework._non_static_mode():
-            out_var = _legacy_C_ops.uniform_random(
-                'shape', var.shape, 'min', self._low, 'max', self._high, 'seed',
-                self._seed, 'dtype', out_dtype, 'diag_num', self._diag_num,
-                'diag_step', self._diag_step, 'diag_val', self._diag_val)
+            if in_dygraph_mode():
+                out_var = _C_ops.uniform_random(var.shape, out_dtype, self._low,
+                                                self._high, self._seed,
+                                                _current_expected_place())
+            elif _in_legacy_dygraph():
+                out_var = _legacy_C_ops.uniform_random(
+                    'shape', var.shape, 'min', self._low, 'max', self._high,
+                    'seed', self._seed, 'dtype', out_dtype, 'diag_num',
+                    self._diag_num, 'diag_step', self._diag_step, 'diag_val',
+                    self._diag_val)
             if var.dtype == VarDesc.VarType.FP16:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                 var_tmp._share_underline_tensor_to(var)
             else:
                 out_var._share_underline_tensor_to(var)
@@ -828,8 +838,12 @@ class MSRAInitializer(Initializer):
 
             if var.dtype == VarDesc.VarType.FP16 or (
                     var.dtype == VarDesc.VarType.BF16 and not self._uniform):
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                 var_tmp._share_underline_tensor_to(var)
             else:
                 out_var._share_underline_tensor_to(var)
@@ -989,14 +1003,23 @@ class BilinearInitializer(Initializer):
            raise ValueError("The size of input is too big. ")
 
        if framework._non_static_mode():
-            _legacy_C_ops.assign_value(out_var, 'shape', list(shape), 'dtype',
-                                       out_dtype, value_name, values)
+            if in_dygraph_mode():
+                _C_ops.assign_value_(out_var, list(shape), out_dtype, values,
+                                     _current_expected_place())
+            elif _in_legacy_dygraph():
+                _legacy_C_ops.assign_value(out_var, 'shape', list(shape),
+                                           'dtype', out_dtype, value_name,
+                                           values)
            if var.dtype in [
                    VarDesc.VarType.FP16, VarDesc.VarType.BF16,
                    VarDesc.VarType.FP64
            ]:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                var_tmp._share_underline_tensor_to(var)
            else:
                out_var._share_underline_tensor_to(var)
@@ -1096,12 +1119,21 @@ class NumpyArrayInitializer(Initializer):
                             "saving it to file and 'load_op' to load it")
 
        if framework._non_static_mode():
-            _legacy_C_ops.assign_value(out_var, 'shape',
-                                       list(self._value.shape), 'dtype',
-                                       out_dtype, value_name, values)
+            if in_dygraph_mode():
+                _C_ops.assign_value_(out_var,
+                                     list(self._value.shape), out_dtype, values,
+                                     _current_expected_place())
+            elif _in_legacy_dygraph():
+                _legacy_C_ops.assign_value(out_var, 'shape',
+                                           list(self._value.shape), 'dtype',
+                                           out_dtype, value_name, values)
            if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
-                var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype', out_var.dtype,
-                                             'out_dtype', var.dtype)
+                if in_dygraph_mode():
+                    var_tmp = _C_ops.cast(out_var, var.dtype)
+                elif _in_legacy_dygraph():
+                    var_tmp = _legacy_C_ops.cast(out_var, 'in_dtype',
+                                                 out_var.dtype, 'out_dtype',
+                                                 var.dtype)
                var_tmp._share_underline_tensor_to(var)
            else:
                out_var._share_underline_tensor_to(var)
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py
index c592e53e9befc8194537ca3a88b06db6bd947aa2..329fda4fb3feb3fe9c351b2d3d350fb38e9a1231 100644
--- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py
+++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_uniform.py
@@ -20,6 +20,7 @@ import paddle
 from paddle import fluid
 from paddle.distribution import *
 from paddle.fluid import layers
+from paddle.fluid.framework import _test_eager_guard
 from test_distribution import DistributionNumpy
 
 
@@ -114,17 +115,6 @@ class UniformTest(unittest.TestCase):
                                    atol=tolerance)
         np.testing.assert_allclose(probs, np_p, rtol=tolerance, atol=tolerance)
 
-    def test_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
-        paddle.disable_static(self.place)
-        uniform = Uniform(self.dynamic_low, self.dynamic_high)
-        sample = uniform.sample([sample_shape]).numpy()
-        entropy = uniform.entropy().numpy()
-        log_prob = uniform.log_prob(self.dynamic_values).numpy()
-        probs = uniform.probs(self.dynamic_values).numpy()
-        fetch_list = [sample, entropy, log_prob, probs]
-
-        self.compare_with_numpy(fetch_list)
-
     def test_uniform_distribution_static(self, sample_shape=7, tolerance=1e-6):
         paddle.enable_static()
         with fluid.program_guard(self.test_program):
@@ -148,6 +138,24 @@
 
         self.compare_with_numpy(fetch_list)
 
+    def func_uniform_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
+        paddle.disable_static()
+        uniform = Uniform(self.dynamic_low, self.dynamic_high)
+        sample = uniform.sample([sample_shape]).numpy()
+        entropy = uniform.entropy().numpy()
+        log_prob = uniform.log_prob(self.dynamic_values).numpy()
+        probs = uniform.probs(self.dynamic_values).numpy()
+        fetch_list = [sample, entropy, log_prob, probs]
+
+        self.compare_with_numpy(fetch_list)
+
+    def test_uniform_distribution_dygraph(self):
+        with _test_eager_guard():
+            self.setUp()
+            self.func_uniform_distribution_dygraph()
+        self.setUp()
+        self.func_uniform_distribution_dygraph()
+
 
 class UniformTest2(UniformTest):
diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py
index 767441e55a49896f4894cecd610758be6750ce19..5ab219df5ec8b8bb988ac881fa1abff8ab6bf5e2 100644
--- a/python/paddle/fluid/tests/unittests/test_initializer.py
+++ b/python/paddle/fluid/tests/unittests/test_initializer.py
@@ -23,6 +23,7 @@ import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import paddle.fluid.initializer as initializer
 from paddle.fluid.core import VarDesc
+from paddle.regularizer import L2Decay
 
 DELTA = 0.00001
 
@@ -563,6 +564,55 @@ class TestBilinearInitializer(unittest.TestCase):
         self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')
 
 
+class TestBilinearInitializerDygraphAPI(unittest.TestCase):
+
+    def func_test_case(self):
+        factor = 2
+        C = 2
+        B = 8
+        H = W = 32
+        w_attr = paddle.ParamAttr(learning_rate=0.,
+                                  regularizer=L2Decay(0.),
+                                  initializer=initializer.BilinearInitializer())
+        data = paddle.rand([B, 3, H, W], dtype='float32')
+        conv_up = paddle.nn.Conv2DTranspose(3,
+                                            out_channels=C,
+                                            kernel_size=2 * factor - factor % 2,
+                                            padding=int(
+                                                math.ceil((factor - 1) / 2.)),
+                                            stride=factor,
+                                            weight_attr=w_attr,
+                                            bias_attr=False)
+        x = conv_up(data)
+        return x
+
+    def func_test_case_fp16(self):
+        paddle.set_default_dtype("float16")
+        paddle.seed(1234)
+        w_attr = paddle.ParamAttr(learning_rate=0.,
+                                  regularizer=L2Decay(0.),
+                                  initializer=initializer.BilinearInitializer())
+        conv2d = paddle.nn.Conv2D(1, 2, 3, weight_attr=w_attr)
+        paddle.set_default_dtype("float32")
+        return conv2d.weight
+
+    def test_bilinear_initializer(self):
+        paddle.disable_static()
+        with framework._test_eager_guard():
+            eager_x = self.func_test_case()
+        legacy_x = self.func_test_case()
+        self.assertEqual(eager_x.numpy().all(), legacy_x.numpy().all())
+        paddle.enable_static()
+
+    def test_bilinear_initializer_fp16(self):
+        paddle.disable_static()
+        with framework._test_eager_guard():
+            eager_x = self.func_test_case_fp16()
+        legacy_x = self.func_test_case_fp16()
+        self.assertEqual(eager_x.numpy().all(), legacy_x.numpy().all())
+        paddle.enable_static()
+
+
 class TestNumpyArrayInitializer(unittest.TestCase):
 
     def test_numpy_array_initializer(self, dtype="float32"):
diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py
index b0ad44b41ef7b578f50e7d32918b125721905801..3d6cfc009cafca25c1e60dfaca5135375db9bdfb 100644
--- a/python/paddle/nn/initializer/dirac.py
+++ b/python/paddle/nn/initializer/dirac.py
@@ -198,9 +198,9 @@ class Dirac(Initializer):
         if framework.in_dygraph_mode():
             with fluid.dygraph.no_grad():
                 tmp_tensor = framework._varbase_creator()
-                _legacy_C_ops.assign_value(tmp_tensor, 'shape', [len(idx_list)],
-                                           'dtype', VarDesc.VarType.INT64,
-                                           'int64_values', idx_list)
+                _C_ops.assign_value_(tmp_tensor, [len(idx_list)],
+                                     VarDesc.VarType.INT64, idx_list,
+                                     _current_expected_place())
                 tmp_tensor._share_underline_tensor_to(index_tensor)
         else:
             block.append_op(type='assign_value',
@@ -220,10 +220,10 @@ class Dirac(Initializer):
         if framework.in_dygraph_mode():
             with fluid.dygraph.no_grad():
                 tmp_tensor = framework._varbase_creator()
-                _legacy_C_ops.assign_value(tmp_tensor, 'shape',
-                                           [len(value_list)], 'dtype',
-                                           VarDesc.VarType.FP32, 'fp32_values',
-                                           value_list)
+                _C_ops.assign_value_(tmp_tensor, [len(value_list)],
+                                     VarDesc.VarType.FP32, value_list,
+                                     _current_expected_place())
+
                 tmp_tensor._share_underline_tensor_to(value_tensor)
         else:
             block.append_op(type='assign_value',
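Note: every hunk above applies the same dispatch: call the new eager (final-state) op when in_dygraph_mode() is true, fall back to the legacy dygraph op when _in_legacy_dygraph() is true, and leave the static-graph append_op path unchanged. A minimal sketch of that pattern is below, using the cast calls exactly as they appear in the patch; cast_compat is a hypothetical helper written only for illustration and is not part of this patch.

    import paddle
    from paddle import _C_ops, _legacy_C_ops
    from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph


    def cast_compat(value, dtype):
        """Cast `value` to `dtype` using whichever dygraph API is active."""
        if in_dygraph_mode():
            # New eager mode: the final-state op takes the target dtype directly.
            return _C_ops.cast(value, dtype)
        if _in_legacy_dygraph():
            # Legacy dygraph: attributes are passed as flat name/value pairs.
            return _legacy_C_ops.cast(value, 'in_dtype', value.dtype,
                                      'out_dtype', dtype)
        # Static graph: callers above keep using append_op / layer helpers instead.
        return value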