diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py
index 1efcbe4ee88712462c2feb725bcd00c7e648376c..f3ebfb9de10cfc4acdb9364b0bd2f39bcdb7c9af 100644
--- a/python/paddle/fluid/layers/loss.py
+++ b/python/paddle/fluid/layers/loss.py
@@ -1463,6 +1463,10 @@ def sigmoid_cross_entropy_with_logits(x,
                                                     ignore_index=-1, normalize=True)
             print(loss)
     """
+
+    if in_dygraph_mode():
+        return _C_ops.final_state_sigmoid_cross_entropy_with_logits(
+            x, label, normalize, int(ignore_index))
     check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
                              'sigmoid_cross_entropy_with_logits')
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 60064340b198a2be672c99918b6f2cadfaccf02a..cfe0d4e32ef7ac130102a04d785433ce0be2f279 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -2106,7 +2106,7 @@ class OpTest(unittest.TestCase):
             grad_outputs = []
             for grad_out_value in user_defined_grad_outputs:
                 grad_outputs.append(paddle.to_tensor(grad_out_value))
-        # delete the inputs which no need to calculate grad
+            # delete the inputs which no need to calculate grad
             for no_grad_val in no_grad_set:
                 del (inputs[no_grad_val])
diff --git a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py
index 153b8fd3e7f6b0cba199455c28df1ed3cc46ff94..ea6d82d15ce0cf9a4d1725a7580c4370a7e6c336 100644
--- a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py
@@ -17,6 +17,7 @@ import paddle.fluid as fluid
 import numpy as np
 import unittest
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard
 
 
 def call_bce_layer(logit, label, weight=None, reduction='mean',
@@ -81,23 +82,22 @@ def test_dygraph(place,
                  reduction='mean',
                  pos_weight_np=None,
                  functional=False):
-    paddle.disable_static()
-    logit = paddle.to_tensor(logit_np)
-    label = paddle.to_tensor(label_np)
-    weight = None
-    pos_weight = None
-    if weight_np is not None:
-        weight = paddle.to_tensor(weight_np)
-    if pos_weight_np is not None:
-        pos_weight = paddle.to_tensor(pos_weight_np)
-    if functional:
-        dy_res = call_bce_functional(logit, label, weight, reduction,
-                                     pos_weight)
-    else:
-        dy_res = call_bce_layer(logit, label, weight, reduction, pos_weight)
-    dy_result = dy_res.numpy()
-    paddle.enable_static()
-    return dy_result
+    with paddle.fluid.dygraph.base.guard():
+        logit = paddle.to_tensor(logit_np)
+        label = paddle.to_tensor(label_np)
+        weight = None
+        pos_weight = None
+        if weight_np is not None:
+            weight = paddle.to_tensor(weight_np)
+        if pos_weight_np is not None:
+            pos_weight = paddle.to_tensor(pos_weight_np)
+        if functional:
+            dy_res = call_bce_functional(logit, label, weight, reduction,
+                                         pos_weight)
+        else:
+            dy_res = call_bce_layer(logit, label, weight, reduction, pos_weight)
+        dy_result = dy_res.numpy()
+        return dy_result
 
 
 def calc_bce_with_logits_loss(logit_np,
@@ -154,9 +154,19 @@ class TestBCEWithLogitsLoss(unittest.TestCase):
                 label_np,
                 reduction=reduction,
                 functional=True)
+
+            with _test_eager_guard():
+                eager_functional = test_dygraph(
+                    place,
+                    logit_np,
+                    label_np,
+                    reduction=reduction,
+                    functional=True)
+
             self.assertTrue(np.allclose(static_functional, expected))
             self.assertTrue(np.allclose(static_functional, dy_functional))
             self.assertTrue(np.allclose(dy_functional, expected))
+            self.assertTrue(np.allclose(eager_functional, expected))
 
     def test_BCEWithLogitsLoss_weight(self):
         logit_np = np.random.uniform(
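For reference, the loss.py hunk above only adds the eager-mode branch in front of the existing static-graph code. The snippet below restates that dispatch as a standalone sketch; it is illustrative only (the wrapper name is made up, and the final_state_* symbol exists once the YAML entries later in this patch have been generated):

    # Sketch of the dispatch idiom introduced above (not part of the patch).
    import paddle
    from paddle import _C_ops
    from paddle.fluid.framework import in_dygraph_mode

    def sigmoid_ce_with_logits(x, label, ignore_index=-100, normalize=False):
        if in_dygraph_mode():
            # eager mode: call the generated final-state kernel directly
            return _C_ops.final_state_sigmoid_cross_entropy_with_logits(
                x, label, normalize, int(ignore_index))
        # otherwise fall back to the existing fluid layer
        # (static graph and legacy dygraph are unchanged by this hunk)
        return paddle.fluid.layers.sigmoid_cross_entropy_with_logits(
            x, label, ignore_index=ignore_index, normalize=normalize)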
diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py
index 416a60b8ba20016df7bbb2eef104f68e5fd74d90..3bf6868fed9c943c4c89c9ddeeafee04348e6f1d 100755
--- a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py
@@ -21,78 +21,63 @@ import paddle
 import paddle.fluid as fluid
 
 
-class TestExpandAsOpRank1(OpTest):
+class TestExpandAsBasic(OpTest):
     def setUp(self):
         self.op_type = "expand_as_v2"
         self.python_api = paddle.expand_as
         x = np.random.rand(100).astype("float64")
         target_tensor = np.random.rand(2, 100).astype("float64")
-        self.inputs = {'X': x}
+        self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [2, 1]
         output = np.tile(self.inputs['X'], bcast_dims)
         self.outputs = {'Out': output}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
-class TestExpandAsOpRank2(OpTest):
+class TestExpandAsOpRank2(TestExpandAsBasic):
     def setUp(self):
         self.op_type = "expand_as_v2"
+        self.python_api = paddle.expand_as
         x = np.random.rand(10, 12).astype("float64")
         target_tensor = np.random.rand(10, 12).astype("float64")
-        self.inputs = {'X': x}
+        self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [1, 1]
         output = np.tile(self.inputs['X'], bcast_dims)
         self.outputs = {'Out': output}
 
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
 
-class TestExpandAsOpRank3(OpTest):
+class TestExpandAsOpRank3(TestExpandAsBasic):
     def setUp(self):
         self.op_type = "expand_as_v2"
+        self.python_api = paddle.expand_as
         x = np.random.rand(2, 3, 20).astype("float64")
         target_tensor = np.random.rand(2, 3, 20).astype("float64")
-        self.inputs = {'X': x}
+        self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [1, 1, 1]
         output = np.tile(self.inputs['X'], bcast_dims)
         self.outputs = {'Out': output}
 
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
 
-class TestExpandAsOpRank4(OpTest):
+class TestExpandAsOpRank4(TestExpandAsBasic):
     def setUp(self):
         self.op_type = "expand_as_v2"
+        self.python_api = paddle.expand_as
         x = np.random.rand(1, 1, 7, 16).astype("float64")
         target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
-        self.inputs = {'X': x}
+        self.inputs = {'X': x, "Y": target_tensor}
         self.attrs = {'target_shape': target_tensor.shape}
         bcast_dims = [4, 6, 1, 1]
         output = np.tile(self.inputs['X'], bcast_dims)
         self.outputs = {'Out': output}
 
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
 
 class TestExpandAsV2Error(unittest.TestCase):
     def test_errors(self):
@@ -130,4 +115,5 @@ class TestExpandAsV2API(unittest.TestCase):
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
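The reworked expand_as tests above now feed the target tensor as input "Y" and check the eager kernel through python_api / check_eager. For orientation, a plain dygraph use of the API they exercise (arbitrary shapes; this snippet is not part of the patch):

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.random.rand(100).astype("float64"))
    y = paddle.to_tensor(np.random.rand(2, 100).astype("float64"))
    # with this patch, eager mode dispatches to final_state_expand_as
    out = paddle.expand_as(x, y)
    np.testing.assert_allclose(out.numpy(), np.tile(x.numpy(), [2, 1]))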
diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py
index 51751588f7b94447080f80002ceb29dac2429529..e5406f4d0c22400b6d51caaa136103b0eecee748 100644
--- a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py
@@ -22,6 +22,12 @@ import paddle.fluid.core as core
 import unittest
 from paddle.fluid import compiler, Program, program_guard
 import paddle.fluid as fluid
+import paddle
+
+
+def test_fluid_sigmoid(x, label, normalize=False, ignore_index=-100):
+    return paddle.fluid.layers.sigmoid_cross_entropy_with_logits(
+        x, label, int(ignore_index), normalize=normalize)
 
 
 class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
@@ -30,6 +36,7 @@ class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
 
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
         batch_size = 64
         num_classes = 20
         self.inputs = {
@@ -49,10 +56,10 @@ class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
         self.outputs = {'Out': -term1 - term2}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
@@ -61,6 +68,7 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
 
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
         batch_size = 64
         num_classes = 20
         ignore_index = -1
@@ -83,10 +91,10 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
         self.outputs = {'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
@@ -95,6 +103,7 @@ class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
 
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
         batch_size = 64
         num_classes = 20
         self.inputs = {
@@ -114,15 +123,16 @@ class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
         self.outputs = {'Out': -term1 - term2}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestSigmoidCrossEntropyWithNorm(OpTest):
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
         batch_size = 64
         num_classes = 20
         ignore_index = -1
@@ -145,10 +155,10 @@ class TestSigmoidCrossEntropyWithNorm(OpTest):
         self.outputs = {'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
@@ -157,6 +167,7 @@ class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
 
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
         batch_size = [10, 10]
         num_classes = 20
         self.inputs = {
@@ -176,15 +187,16 @@ class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
         self.outputs = {'Out': -term1 - term2}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
 
 
 class TestSigmoidCrossEntropyWithNorm2(OpTest):
     def setUp(self):
         self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
         batch_size = [10, 10]
         num_classes = 20
         ignore_index = -1
@@ -207,68 +219,71 @@ class TestSigmoidCrossEntropyWithNorm2(OpTest):
         self.outputs = {'Out': out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-
-class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
-    """Test sigmoid_cross_entropy_with_logit_op with binary label
-    """
-
-    def setUp(self):
-        self.op_type = "sigmoid_cross_entropy_with_logits"
-        batch_size = [10, 10]
-        num_classes = 20
-        self.inputs = {
-            'X': logit(
-                np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
-                .astype("float64")),
-            'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes]))
-            .astype("float64")
-        }
-
-        # Fw Pass is implemented as elementwise sigmoid followed by
-        # elementwise logistic loss
-        # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
-        sigmoid_X = expit(self.inputs['X'])
-        term1 = self.inputs['Label'] * np.log(sigmoid_X)
-        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
-        self.outputs = {'Out': -term1 - term2}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
-
-
-class TestSigmoidCrossEntropyWithLogitsOpError(unittest.TestCase):
-    def test_errors(self):
-        with program_guard(Program(), Program()):
-
-            def test_Variable():
-                # the input of sigmoid_cross_entropy_with_logits must be Variable.
-                x1 = fluid.create_lod_tensor(
-                    np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
-                lab1 = fluid.create_lod_tensor(
-                    np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace())
-                fluid.layers.sigmoid_cross_entropy_with_logits(x1, lab1)
-
-            self.assertRaises(TypeError, test_Variable)
-
-            def test_dtype():
-                # the input dtype of sigmoid_cross_entropy_with_logits must be float16 or float32 or float64
-                # float16 only can be set on GPU place
-                x2 = fluid.layers.data(
-                    name='x2', shape=[3, 4, 5, 6], dtype="int32")
-                lab2 = fluid.layers.data(
-                    name='lab2', shape=[3, 4, 5, 6], dtype="int32")
-                fluid.layers.sigmoid_cross_entropy_with_logits(x2, lab2)
-
-            self.assertRaises(TypeError, test_dtype)
+        self.check_grad(['X'], 'Out', check_eager=True)
+
+
+class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
+    """Test sigmoid_cross_entropy_with_logit_op with binary label
+    """
+
+    def setUp(self):
+        self.op_type = "sigmoid_cross_entropy_with_logits"
+        self.python_api = test_fluid_sigmoid
+        batch_size = [10, 10]
+        num_classes = 20
+        self.inputs = {
+            'X': logit(
+                np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                .astype("float64")),
+            'Label':
+            np.random.randint(0, 2, tuple(batch_size + [num_classes]))
+            .astype("float64")
+        }
+
+        # Fw Pass is implemented as elementwise sigmoid followed by
+        # elementwise logistic loss
+        # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
+        sigmoid_X = expit(self.inputs['X'])
+        term1 = self.inputs['Label'] * np.log(sigmoid_X)
+        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+        self.outputs = {'Out': -term1 - term2}
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_eager=True)
+
+
+class TestSigmoidCrossEntropyWithLogitsOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+
+            def test_Variable():
+                # the input of sigmoid_cross_entropy_with_logits must be Variable.
+                x1 = fluid.create_lod_tensor(
+                    np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]],
+                    fluid.CPUPlace())
+                lab1 = fluid.create_lod_tensor(
+                    np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]],
+                    fluid.CPUPlace())
+                fluid.layers.sigmoid_cross_entropy_with_logits(x1, lab1)
+
+            self.assertRaises(TypeError, test_Variable)
+
+            def test_dtype():
+                # the input dtype of sigmoid_cross_entropy_with_logits must be float16 or float32 or float64
+                # float16 only can be set on GPU place
+                x2 = fluid.layers.data(
+                    name='x2', shape=[3, 4, 5, 6], dtype="int32")
+                lab2 = fluid.layers.data(
+                    name='lab2', shape=[3, 4, 5, 6], dtype="int32")
+                fluid.layers.sigmoid_cross_entropy_with_logits(x2, lab2)
+
+            self.assertRaises(TypeError, test_dtype)
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py
index 2ef04d9cbfa73fe21a960e4af13cb7efdee316c7..15a4827cecba344c2004446c308e7336a4ccf3f2 100644
--- a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py
@@ -18,6 +18,7 @@ import numpy as np
 import unittest
 from op_test import OpTest
 from test_sigmoid_focal_loss_op import sigmoid_focal_loss_forward
+from paddle.fluid.framework import _test_eager_guard
 
 
 def call_sfl_functional(logit,
@@ -140,6 +141,10 @@ class TestSigmoidFocalLoss(unittest.TestCase):
                     dy_result = test_dygraph(place, logit_np, label_np,
                                              normalizer_np, alpha, gamma,
                                              reduction)
+                    with _test_eager_guard():
+                        eager_result = test_dygraph(
+                            place, logit_np, label_np, normalizer_np,
+                            alpha, gamma, reduction)
                     expected = calc_sigmoid_focal_loss(
                         logit_np, label_np, normalizer_np, alpha, gamma,
                         reduction)
@@ -148,6 +153,7 @@ class TestSigmoidFocalLoss(unittest.TestCase):
                     self.assertTrue(
                         np.allclose(static_result, dy_result))
                     self.assertTrue(np.allclose(dy_result, expected))
+                    self.assertTrue(np.allclose(eager_result, expected))
 
     def test_SigmoidFocalLoss_error(self):
         paddle.disable_static()
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 8a2b5cbb8b334590ad05140bfd40f5c54d752697..593cea2d2cf643310f9e1b9d7a0b35be679eb6fb 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -259,12 +259,16 @@ def binary_cross_entropy_with_logits(logit,
             "should be 'sum', 'mean' or 'none', but received %s, which is not allowed."
             % reduction)
 
-    if in_dynamic_mode():
+    if _non_static_mode():
         one = _varbase_creator(dtype=logit.dtype)
         _C_ops.fill_constant(one, 'value',
                              float(1.0), 'force_cpu', False, 'dtype', one.dtype,
                              'str_value', '1.0', 'shape', [1])
-        out = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
+        if in_dygraph_mode():
+            out = _C_ops.final_state_sigmoid_cross_entropy_with_logits(
+                logit, label, False, -100)
+        else:
+            out = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
         if pos_weight is not None:
             log_weight = _C_ops.elementwise_add(
                 _C_ops.elementwise_mul(label,
@@ -2024,12 +2028,16 @@ def sigmoid_focal_loss(logit,
             "Expected one dimension of normalizer in sigmoid_focal_loss but got {}.".
             format(normalizer_dims))
 
-    if in_dynamic_mode():
+    if _non_static_mode():
         one = _varbase_creator(dtype=logit.dtype)
         _C_ops.fill_constant(one, 'value',
                              float(1.0), 'force_cpu', False, 'dtype', one.dtype,
                              'str_value', '1.0', 'shape', logit.shape)
-        loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
+        if in_dygraph_mode():
+            loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits(
+                logit, label, False, -100)
+        else:
+            loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
         pred = _C_ops.sigmoid(logit)
         p_t = _C_ops.elementwise_add(
             _C_ops.elementwise_mul(pred, label),
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index b055abcf845f98375609e4e341e64f6a966ab66a..92fec23c6c769fe69e4d6aae4328e95e1d87b569 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1837,6 +1837,9 @@ def expand_as(x, y, name=None):
             np_out = out.numpy()
             # [[1, 2, 3], [1, 2, 3]]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_expand_as(x, None, y.shape)
+
     if _non_static_mode():
         return _C_ops.expand_as_v2(x, 'target_shape', y.shape)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index af4e7a5b3bb32175f7e504d7318c9a483ab91a97..4c17644792fbdeee9e19559d6f50b801be6303ef 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -566,6 +566,17 @@
     func : erfinv
   backward : erfinv_grad
 
+# expand_as
+- api : expand_as
+  args : (Tensor x, Tensor y, int[] target_shape)
+  output : Tensor
+  infer_meta :
+    func : ExpandAsInferMeta
+  kernel :
+    func : expand_as
+  optional : y
+  backward : expand_as_grad
+
 - api : expm1
   args : (Tensor x)
   output : Tensor
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index f94d0a9e50523b1e35179a151fdb15d1e98110e9..da60dae43169542909a4d108d95630eeaf3aa635 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -373,6 +373,16 @@
   kernel :
     func : erfinv_grad
 
+- backward_api : expand_as_grad
+  forward : expand_as (Tensor x, Tensor y, int[] target_shape) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, int[] target_shape)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : expand_as_grad
+
 - backward_api : expm1_grad
   forward : expm1 (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json
index 74cb6fb0e535659024898ea905c20430db21dd5b..5638cf506c84d5670eb0d4f53d88bccd8c22dcb9 100644
--- a/tools/infrt/skipped_phi_api.json
+++ b/tools/infrt/skipped_phi_api.json
@@ -1,4 +1,4 @@
 {
-"phi_apis":["conj", "nll_loss", "dropout", "flatten"],
+"phi_apis":["conj", "nll_loss", "flatten", "expand_as", "dropout"],
 "phi_kernels":["equal_all"]
 }
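As a rough end-to-end check of the Python-side changes, the functional loss can be compared between the legacy dygraph path and the final-state path this patch wires up, mirroring the unit tests above. This is a sketch with arbitrary data, not part of the patch; _test_eager_guard is the same internal switch the tests use:

    import numpy as np
    import paddle
    import paddle.nn.functional as F
    from paddle.fluid.framework import _test_eager_guard

    paddle.disable_static()
    logit_np = np.random.uniform(0.1, 0.8, (3, 4)).astype("float64")
    label_np = np.random.randint(0, 2, (3, 4)).astype("float64")

    # legacy dygraph result (old _C_ops.sigmoid_cross_entropy_with_logits path)
    dy_loss = F.binary_cross_entropy_with_logits(
        paddle.to_tensor(logit_np), paddle.to_tensor(label_np))

    # final-state (eager) result, routed through the new YAML-generated kernel
    with _test_eager_guard():
        eager_loss = F.binary_cross_entropy_with_logits(
            paddle.to_tensor(logit_np), paddle.to_tensor(label_np))

    np.testing.assert_allclose(dy_loss.numpy(), eager_loss.numpy(), rtol=1e-6)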