diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py
index ccd184e15c0cef42880a5893efa4e60b0bf51c2a..bf10ee5c37730b70f35914031d676855150df0b4 100755
--- a/python/paddle/fluid/layers/layer_function_generator.py
+++ b/python/paddle/fluid/layers/layer_function_generator.py
@@ -260,7 +260,8 @@ def generate_activation_fn(op_type):
         if in_dygraph_mode():
             if _in_eager_mode():
-                op = getattr(_C_ops, "final_state_" + op_type)
-                return op(x)
+                op = getattr(_C_ops, "final_state_" + op_type, None)
+                if op:
+                    return op(x)
             op = getattr(_C_ops, op_type)
             return op(x)
 
diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py
index 2d2ae9aa31c04f0519437c08752a0161f8c0d5b2..f1bc2a352620f4af291ad8481faf5d970bab79a2 100644
--- a/python/paddle/fluid/layers/loss.py
+++ b/python/paddle/fluid/layers/loss.py
@@ -21,7 +21,7 @@ from paddle.utils import deprecated
 from . import nn
 from .layer_function_generator import templatedoc
 from ..layer_helper import LayerHelper
-from ..framework import Variable, in_dygraph_mode, static_only, in_dygraph_mode
+from ..framework import Variable, in_dygraph_mode, static_only, _in_eager_mode
 from .. import core
 from ..data_feeder import check_variable_and_dtype, check_type
 from ..param_attr import ParamAttr
diff --git a/python/paddle/fluid/layers/metric_op.py b/python/paddle/fluid/layers/metric_op.py
index d6aead1d8df4c9530b9d8f5c1ae6694cd8088481..37a321173ba9b6b98bf966f2673b2482f0ace1ba 100644
--- a/python/paddle/fluid/layers/metric_op.py
+++ b/python/paddle/fluid/layers/metric_op.py
@@ -87,9 +87,9 @@ def accuracy(input, label, k=1, correct=None, total=None):
         _k = k.numpy().item(0) if isinstance(k, Variable) else k
         topk_out, topk_indices = _C_ops.top_k_v2(input, 'k', _k, 'sorted',
                                                  False)
-        if _in_eager_mode():
-            _acc = _C_ops.final_state_accuracy(topk_out, topk_indices, label)
-            return _acc
+        # if _in_eager_mode():
+        #     _acc = _C_ops.final_state_accuracy(topk_out, topk_indices, label)
+        #     return _acc
         _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct,
                                      total)
         return _acc
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 673333fdff78b6dfaeb924155a5809008e075acd..73f73ad16399d8a502546c167a796f927d702635 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1501,6 +1501,7 @@ class OpTest(unittest.TestCase):
                     .recursive_sequence_lengths(), expect[1],
                     "Output (" + out_name + ") has different lod at " +
                     str(place) + " in eager dygraph mode")
+        if check_eager:
             with fluid.dygraph.base.guard():
                 with _test_eager_guard():
                     self.assertListEqual(
diff --git a/python/paddle/fluid/tests/unittests/test_filter_by_instag_op.py b/python/paddle/fluid/tests/unittests/test_filter_by_instag_op.py
index ecd2e2cd6c3cfc7f161cb53588c391c29934798c..6f139cb54d3d85455fa454bc4552cd292e66a28f 100644
--- a/python/paddle/fluid/tests/unittests/test_filter_by_instag_op.py
+++ b/python/paddle/fluid/tests/unittests/test_filter_by_instag_op.py
@@ -285,4 +285,5 @@ class TestFilterByInstagOp7(OpTest):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py
index 3fe8bca2f192e959e2364912af657134ee123443..8eac60db734566a11c2224c8485a81653146a14d 100644
--- a/python/paddle/fluid/tests/unittests/test_hash_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hash_op.py
@@ -16,6 +16,7 @@ import unittest
 import numpy as np
 from op_test import OpTest
 import paddle.fluid as fluid
+import paddle
 
 
 class TestHashOp(OpTest):
@@ -140,4 +141,5 @@ class TestHashOpError(unittest.TestCase):
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py
index aac904dc2e15d47d2d2439142363afcaae9e2d67..8335ff868d0320f7801183e708fe48bf3186cfef 100644
--- a/python/paddle/fluid/tests/unittests/test_split_op.py
+++ b/python/paddle/fluid/tests/unittests/test_split_op.py
@@ -460,4 +460,5 @@ class API_TestDygraphSplit(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/metric/metrics.py b/python/paddle/metric/metrics.py
index 67ff62355cc6fd8e5f60211edb9b346c76f8b0e0..3ff91aa077954510a8e38bdee06b03968796b0f0 100644
--- a/python/paddle/metric/metrics.py
+++ b/python/paddle/metric/metrics.py
@@ -798,9 +798,9 @@ def accuracy(input, label, k=1, correct=None, total=None, name=None):
             total = _varbase_creator(dtype="int32")
 
         topk_out, topk_indices = paddle.topk(input, k=k)
-        if _in_eager_mode():
-            _acc = _C_ops.final_state_accuracy(topk_out, topk_indices, label)
-            return _acc
+        # if _in_eager_mode():
+        #     _acc = _C_ops.final_state_accuracy(topk_out, topk_indices, label)
+        #     return _acc
         _acc, _, _ = _C_ops.accuracy(topk_out, topk_indices, label, correct,
                                      total)
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index fc4d882ff001c52836ca756e36438b0f8617c248..084e8ec2c2da2bb4ed278f95e5b4bca6ec92a1f8 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -948,9 +948,8 @@ def split(x, num_or_sections, axis=0, name=None):
             print(out1.shape)  # [3, 3, 5]
             print(out2.shape)  # [3, 3, 5]
     """
-    if paddle.in_dygraph_mode():
-        if _in_eager_mode():
-            return _C_ops.final_state_split(x, num_or_sections, dim)
+    if paddle.in_dynamic_mode() and _in_eager_mode():
+        return _C_ops.final_state_split(x, num_or_sections, axis)
     return paddle.fluid.layers.split(
         input=x, num_or_sections=num_or_sections, dim=axis, name=name)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 94ca3a0bfd0901fc777d6df51f2bc63485be1117..83ea2ae5a42404715c55a285b60541b50b333c4a 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -166,6 +166,7 @@
   kernel :
     func : relu
   inplace : (x -> out)
+  backward : relu_grad
 
 - api : scale
   args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
@@ -191,7 +192,8 @@
   infer_meta :
     func : SoftmaxInferMeta
   kernel :
-    func : sotfmax
+    func : softmax
+  backward : softmax_grad
 
 - api : split
   args : (Tensor x, ScalarArray num_or_sections, Scalar axis)
@@ -342,15 +344,15 @@
   backward : segment_pool_grad
 
-# accuracy
-- api : accuracy
-  args : (Tensor x, Tensor indices, Tensor label)
-  output : Tensor(accuracy), Tensor(correct), Tensor(total)
-  infer_meta :
-    func : AccuracyInferMeta
-  kernel :
-    func : accuracy
-    dtype : x
+# # accuracy
+# - api : accuracy
+#   args : (Tensor x, Tensor indices, Tensor label)
+#   output : Tensor(accuracy), Tensor(correct), Tensor(total)
+#   infer_meta :
+#     func : AccuracyInferMeta
+#   kernel :
+#     func : accuracy
+#     dtype : x
 
 # sin
 - api : sin
@@ -475,6 +477,126 @@
     func : sigmoid
   backward : sigmoid_grad
 
+# tan
+- api : tan
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : tan
+  backward : tan_grad
+
+# tanh_shrink
+- api : tanh_shrink
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : tanh_shrink
+  backward : tanh_shrink_grad
+
+# silu
+- api : silu
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : silu
+  backward : silu_grad
+
+# logsigmoid
+- api : logsigmoid
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : logsigmoid
+  backward : logsigmoid_grad
+
+# leaky_relu
+- api : leaky_relu
+  args : (Tensor x, float alpha)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : leaky_relu
+  backward : leaky_relu_grad
+
+# thresholded_relu
+- api : thresholded_relu
+  args : (Tensor x, float threshold)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : thresholded_relu
+  backward : thresholded_relu_grad
+
+# soft_shrink
+- api : soft_shrink
+  args : (Tensor x, float lambda)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : soft_shrink
+  backward : soft_shrink_grad
+
+# hard_shrink
+- api : hard_shrink
+  args : (Tensor x, float threshold)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hard_shrink
+  backward : hard_shrink_grad
+
+# elu
+- api : elu
+  args : (Tensor x, float alpha)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : elu
+  backward : elu_grad
+
+# brelu
+- api : brelu
+  args : (Tensor x, float t_min, float t_max)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : brelu
+  backward : brelu_grad
+
+# hard_sigmoid
+- api : hard_sigmoid
+  args : (Tensor x, float slope, float offset)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hard_sigmoid
+  backward : hard_sigmoid_grad
+
 # arg_min # int64 ???? dtype
 - api : argmin
   args : (Tensor x, int64 axis, bool keepdims, bool flatten, int dtype)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 30d782e6539bf3572f57f1f1c8ecea031647454e..42b5b13777dc76002f5323f0de7323ca85618aae 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -75,11 +75,11 @@
   kernel :
     func : diagonal_grad
 
-# - backward_api : split_grad
-#   forward : split (Tensor x, ScalarArray num_or_sections, Scalar axis) -> Tensor[](out)
-#   args : (Tensor[] out_grad, Scalar axis)
-#   output : Tensor(x_grad)
-#   invoke : concat( out_grad, axis)
+- backward_api : split_grad
+  forward : split (Tensor x, ScalarArray num_or_sections, Scalar axis) -> Tensor[](out)
+  args : (Tensor[] out_grad, Scalar axis)
+  output : Tensor(x_grad)
+  invoke : concat(out_grad, axis)
 
 # TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
 # - backward_api : matmul_triple_grad
@@ -165,11 +165,11 @@
 - backward_api : cos_grad
   forward : cos (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : cos_grad
 
@@ -185,91 +185,91 @@
 - backward_api : acos_grad
   forward : acos (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : acos_grad
 
 - backward_api : sin_grad
   forward : sin (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : sin_grad
 
 - backward_api : asin_grad
   forward : asin (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : asin_grad
 
 - backward_api : atan_grad
   forward : atan (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : atan_grad
 
 - backward_api : sinh_grad
   forward : sinh (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : sinh_grad
 
 - backward_api : cosh_grad
   forward : cosh (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : cosh_grad
 
 - backward_api : asinh_grad
   forward : asinh (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : asinh_grad
 
 - backward_api : acosh_grad
   forward : acosh (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : acosh_grad
 
 - backward_api : atanh_grad
   forward : atanh (Tensor x) -> Tensor(out)
-  args : (Tensor out, Tensor out_grad)
+  args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out]
+    param : [x]
   kernel :
     func : atanh_grad
 
@@ -293,6 +293,122 @@
   kernel :
     func : sigmoid_grad
 
+- backward_api : tan_grad
+  forward : tan (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : tan_grad
+
+- backward_api : tanh_shrink_grad
+  forward : tanh_shrink (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : tanh_shrink_grad
+
+- backward_api : silu_grad
+  forward : silu (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : silu_grad
+
+- backward_api : logsigmoid_grad
+  forward : logsigmoid (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : logsigmoid_grad
+
+- backward_api : leaky_relu_grad
+  forward : leaky_relu (Tensor x, float alpha) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float alpha)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : leaky_relu_grad
+
+- backward_api : thresholded_relu_grad
+  forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float threshold)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : thresholded_relu_grad
+
+- backward_api : soft_shrink_grad
+  forward : soft_shrink (Tensor x, float lambda) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float lambda)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : soft_shrink_grad
+
+- backward_api : hard_shrink_grad
+  forward : hard_shrink (Tensor x, float threshold) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float threshold)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hard_shrink_grad
+
+- backward_api : elu_grad
+  forward : elu (Tensor x, float alpha) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad, float alpha)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : elu_grad
+
+- backward_api : brelu_grad
+  forward : brelu (Tensor x, float t_min, float t_max) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float t_min, float t_max)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : brelu_grad
+
+- backward_api : hard_sigmoid_grad
+  forward : hard_sigmoid (Tensor x, float slope, float offset) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad, float slope, float offset)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out]
+  kernel :
+    func : hard_sigmoid_grad
+
 - backward_api : argsort_grad
   forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices)
   args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending)
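
Note for reviewers (not part of the patch): several hunks above share one dispatch
pattern -- in eager dygraph mode, try the code-generated "final_state_" kernel
first and fall back to the legacy C++ op otherwise. A minimal standalone sketch of
that pattern follows; call_activation is a hypothetical helper written only for
illustration, while _C_ops, in_dygraph_mode, and _in_eager_mode are the Paddle
internals the diff itself imports.

    import paddle
    from paddle import _C_ops
    from paddle.fluid.framework import in_dygraph_mode, _in_eager_mode

    def call_activation(op_type, x):
        """Hypothetical helper mirroring generate_activation_fn's dispatch."""
        if in_dygraph_mode() and _in_eager_mode():
            # Pass a None default so a missing final-state kernel falls
            # through to the legacy op instead of raising AttributeError.
            op = getattr(_C_ops, "final_state_" + op_type, None)
            if op is not None:
                return op(x)
        # Legacy dygraph op (and the only path when eager mode is off).
        return getattr(_C_ops, op_type)(x)

    # Usage sketch: in eager mode this resolves to _C_ops.final_state_tan,
    # which the new api.yaml "tan" entry above causes to be generated.
    # y = call_activation("tan", paddle.to_tensor([0.1, 0.2]))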