diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 4e026dab662c04b6edaf605b0f9f0375c9a7d418..3f9d5fb97973f6e542e3012f7807a4d1ca04621b 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -176,24 +176,20 @@ def declarative(function=None, input_spec=None): Examples: .. code-block:: python - import paddle.fluid as fluid - import numpy as np - from paddle.fluid.dygraph.jit import declarative - - fluid.enable_dygraph() - - @declarative - def func(x): - x = fluid.dygraph.to_variable(x) - if fluid.layers.mean(x) < 0: - x_v = x - 1 - else: - x_v = x + 1 - return x_v - - x = np.ones([1, 2]) - x_v = func(x) - print(x_v.numpy()) # [[2. 2.]] + import paddle + from paddle.jit import to_static + + @to_static + def func(x): + if paddle.mean(x) < 0: + x_v = x - 1 + else: + x_v = x + 1 + return x_v + + x = paddle.ones([1, 2], dtype='float32') + x_v = func(x) + print(x_v) # [[2. 2.]] """