diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 348902c656cec1ea1eeaccc90feefd56d307111d..91fbfba382447d0c7019efb00ad04ea67e743df8 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -597,10 +597,57 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc); REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc); REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); +class ActivationOpDoubleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + if (ctx->HasOutput("DOut")) { + ctx->ShareDim("Out", "DOut"); + ctx->ShareLoD("Out", "DOut"); + } + if (ctx->HasOutput("DDOut")) { + ctx->ShareDim("Out", "DDOut"); + ctx->ShareLoD("Out", "DDOut"); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this, "Out"); + } +}; + +// +// ReluGrad: dx = dy if y >= 0 else 0 +// ReluGradGrad: ddy = ddx if y >= 0 else 0 +// +class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker { + public: + using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { + auto* op = new ::paddle::framework::OpDesc(); + op->SetType("relu_grad_grad"); + // input1: Out + op->SetInput("Out", Input("Out")); + // X@GRAD@GRAD: ddx + op->SetInput("DDX", OutputGrad(framework::GradVarName("X"))); + op->SetAttrMap(Attrs()); + // Out@GRAD@GRAD: ddy + op->SetOutput("DOut", InputGrad("Out")); + op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); + return std::unique_ptr<::paddle::framework::OpDesc>(op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; #define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \ REGISTER_OPERATOR( \ @@ -632,3 +679,23 @@ namespace ops = paddle::operators; FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL); + +REGISTER_OPERATOR( + relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType, + ops::ActivationGradOpDescMaker::FwdDeps()>, + paddle::framework::SingleOpInplaceInToOut); +REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, + paddle::framework::SingleOpInplaceInToOut, + ops::ReluDoubleGradMaker); +REGISTER_OPERATOR(relu_grad_grad, ops::ActivationOpDoubleGrad); + +REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); + +REGISTER_OP_CPU_KERNEL( + relu_grad_grad, + ops::ActivationDoubleGradKernel>, + ops::ActivationDoubleGradKernel>, + ops::ActivationDoubleGradKernel>); diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 9c7a8d8971cba4090db1bbc32c7eabf2285e7eff..20f3f3605805906424856cb345f962477f31dec2 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -32,3 +32,14 @@ namespace plat = paddle::platform; ops::grad_functor>); FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); + +REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor); + +REGISTER_OP_CUDA_KERNEL( + relu_grad_grad, + ops::ActivationDoubleGradKernel>, + ops::ActivationDoubleGradKernel>, + ops::ActivationDoubleGradKernel>); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 1732f61582f79365d6872e15b9df1ee8f053903c..8259a392b2d44f70edd47d4132a74674856bf1a6 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -1198,6 +1198,126 @@ struct SwishGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +/* + * in arguments: x, out, ddx + * out arguments: ddout, dout, dx + */ +template +inline void ExtractActivationDoubleGradTensor( + const framework::ExecutionContext& ctx, const framework::Tensor** X, + const framework::Tensor** Out, const framework::Tensor** ddX, + framework::Tensor** dX, framework::Tensor** dOut, + framework::Tensor** ddOut) { + auto out_var = ctx.InputVar("Out"); + auto ddx_var = ctx.InputVar("DDX"); + auto ddo_var = ctx.OutputVar("DDOut"); + auto do_var = ctx.OutputVar("DOut"); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot get input Variable Out, variable name = %s", + ctx.op().Input("Out")); + PADDLE_ENFORCE(ddx_var != nullptr, + "Cannot get input Variable %s, variable name = %s", "DDX", + ctx.op().Input("DDX")); + if (CanBeUsedBySelectedRows.count(ctx.op().Type())) { + *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); + *ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var); + if (ddo_var) { + *ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + ddo_var); + } + if (do_var) { + *dOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + do_var); + } + } else { + *Out = ctx.Input("Out"); + *ddX = ctx.Input("DDX"); + if (ddo_var) { + *ddOut = ctx.Output("DDOut"); + } + if (do_var) { + *dOut = ctx.Output("DOut"); + } + } + PADDLE_ENFORCE(*ddX != nullptr, + "Cannot get output tensor %s, variable name = %s", "DDX", + ctx.op().Output("DDX")); + + if (static_cast(kDepValue) & static_cast(kDepX)) { + auto x_var = ctx.InputVar("X"); + PADDLE_ENFORCE(x_var != nullptr, + "Cannot get input tensor X, variable name = %s", + ctx.op().Input("X")); + auto dx_var = ctx.OutputVar("DX"); + if (CanBeUsedBySelectedRows.count(ctx.op().Type())) { + *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); + if (dx_var) { + *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + dx_var); + } + } else { + *X = ctx.Input("X"); + if (dx_var) { + *dX = ctx.Output("DX"); + } + } + } else { + VLOG(10) << " Inplace activation of Op : " << ctx.op().Type(); + *X = *ddX; + } +} + +template +class ActivationDoubleGradKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor *X, *Out, *ddX; + X = Out = ddX = nullptr; + framework::Tensor *ddOut, *dOut, *dX; + ddOut = dOut = dX = nullptr; + + ExtractActivationDoubleGradTensor(ctx, &X, &Out, &ddX, + &dX, &dOut, &ddOut); + + if (ddOut) ddOut->mutable_data(ctx.GetPlace()); + if (dOut) dOut->mutable_data(ctx.GetPlace()); + if (dX) dX->mutable_data(Out->dims(), ctx.GetPlace()); + + auto& place = ctx.template device_context(); + + Functor functor; + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + functor(place, X, Out, ddX, ddOut, dOut, dX); + } +}; + +template +struct ReluGradGradFunctor : public BaseActivationFunctor { + template + void operator()(const Device& dev, const framework::Tensor* X, + const framework::Tensor* Out, const framework::Tensor* ddX, + framework::Tensor* ddOut, framework::Tensor* dOut, + framework::Tensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = framework::EigenVector::Flatten(detail::Ref(ddX)); + auto out = framework::EigenVector::Flatten(detail::Ref(Out)); + if (ddOut) { + auto ddout = framework::EigenVector::Flatten(detail::Ref(ddOut)); + ddout.device(*d) = ddx * (out > static_cast(0)).template cast(); + } + if (dOut) { + auto dout = framework::EigenVector::Flatten(detail::Ref(dOut)); + dout.device(*d) = dout.constant(static_cast(0)); + } + } + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; + } // namespace operators } // namespace paddle @@ -1205,7 +1325,6 @@ struct SwishGradFunctor : public BaseActivationFunctor { __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(exp, Exp, ExpFunctor, ExpGradFunctor); \ - __macro(relu, Relu, ReluFunctor, ReluGradFunctor); \ __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \ __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \ __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 9fd53a74bf51929f9e115fdc94f2f85f8e2fbdda..9400eaadaa65b63f52513b43f76b3f06b731460d 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -611,7 +611,7 @@ def _find_op_path_(block, outputs, inputs, no_grad_set): if inputs: for op in op_path: for name in op.desc.input_arg_names(): - if name not in input_names: + if name not in input_names and block.vars[name].stop_gradient: no_grad_set.add(name) return op_path diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ca57de3e927b26f243d262119b128fb2d1cb2f95..4821a2667586ce323ac50c8134811b6f60fbea33 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -29,7 +29,7 @@ list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957 list(REMOVE_ITEM TEST_OPS op_test) # op_test is a helper python file, not a test -list(REMOVE_ITEM TEST_OPS decorators) # decorators is a helper python file, not a test +list(REMOVE_ITEM TEST_OPS decorator_helper) # decorator_helper is a helper python file, not a test if(APPLE) if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_desc_clone) diff --git a/python/paddle/fluid/tests/unittests/decorators.py b/python/paddle/fluid/tests/unittests/decorator_helper.py similarity index 100% rename from python/paddle/fluid/tests/unittests/decorators.py rename to python/paddle/fluid/tests/unittests/decorator_helper.py diff --git a/python/paddle/fluid/tests/unittests/gradient_checker.py b/python/paddle/fluid/tests/unittests/gradient_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..14a828f28ee8141140b15afdfa7aa6f894a11b1a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/gradient_checker.py @@ -0,0 +1,351 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import six +import collections +import numpy as np +from itertools import product + +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.executor import Executor +from paddle.fluid.backward import calc_gradient +from paddle.fluid.backward import _append_grad_suffix_, _as_list + + +def _product(t): + if isinstance(t, int): + return t + else: + return np.product(t) + + +def dtype_to_np_dtype(dtype): + if dtype == core.VarDesc.VarType.FP32: + return np.float32 + elif dtype == core.VarDesc.VarType.FP64: + return np.float64 + elif dtype == core.VarDesc.VarType.FP16: + return np.float16 + else: + raise ValueError("Not supported data type " + str(dtype)) + + +def _get_item(t, i, np_dtype): + if np_dtype == np.float16: + np_t = np.array(t).astype(np.float16) + np_t = np_t.flatten() + return np_t[i] + elif np_dtype == np.float32: + return t._get_float_element(i) + elif np_dtype == np.float64: + return t._get_double_element(i) + else: + raise ValueError("Not supported data type " + str(np_dtype)) + + +def _set_item(t, i, e, np_dtype): + if np_dtype == np.float16: + np_t = np.array(t).astype(np.float16) + shape = np_t.shape + np_t = np_t.flatten() + np_t[i] = e + np_t = np_t.reshape(shape).view(np.uint16) + t.set(np_t, place) + elif np_dtype == np.float32: + t._set_float_element(i, e) + elif np_dtype == np.float64: + t._set_double_element(i, e) + else: + raise ValueError("Not supported data type " + str(np_dtype)) + + +def set_var_in_scope(scope, place, name, value, recursive_seq_len=None): + t = scope.var(name).get_tensor() + t.set(value, place) + if recursive_seq_len: + t.set_recursive_sequence_lengths(recursive_seq_len) + return t + + +def make_jacobian(x, y_size, np_dtype): + if isinstance(x, fluid.framework.Variable): + return np.zeros((_product(x.shape), y_size), dtype=np_dtype) + elif isinstance(x, collections.Sequence): + jacobians = list( + filter(lambda t: t is not None, (make_jacobian( + item, y_size, np_dtype) for item in x))) + return jacobians + else: + None + + +def _compute_numerical_jacobian(program, x, y, place, scope, delta): + """Computes the numeric Jacobian for dy/dx. + + Computes the numeric Jacobian by slightly perturbing the inputs and + measuring the differences on the output. + + Args: + program (Program): the network program. + x (Variable): the input variables. + y (list[Variable]): the output variables. + place (fluid.CPUPlace or fluid.CUDAPlace): the device. + scope (Scope): the scope used to run program. + delta: the amount of perturbation we give to the input + + Returns: + A list of 2-D numpy array, the list length is len(y). + Each 2-D numpy array represents the Jacobian for dy_i/dx. + It has "x_size" rows and "y_size" columns + where "x_size" is the number of elements in x and + "y_size" is the number of elements in each y_i. + """ + if not isinstance(x, fluid.framework.Variable): + raise TypeError('x is not Variable') + + # To compute the jacobian, treat x and y as one-dimensional vectors. + y = _as_list(y) + exe = fluid.Executor(place) + + def run(): + y_res = exe.run(program, scope=scope, fetch_list=y) + return [yi.flatten() for yi in y_res] + + x_name = x.name + x_shape = x.shape + x_size = _product(x_shape) + x_t = scope.find_var(x_name).get_tensor() + + np_type = dtype_to_np_dtype(x.dtype) + jacobian = [make_jacobian(x, _product(yi.shape), np_type) for yi in y] + + for i in six.moves.xrange(x_size): + orig = _get_item(x_t, i, np_type) + x_pos = orig + delta + _set_item(x_t, i, x_pos, np_type) + y_pos = run() + + x_neg = orig - delta + _set_item(x_t, i, x_neg, np_type) + y_neg = run() + + _set_item(x_t, i, orig, np_type) + + for j in six.moves.xrange(len(y)): + jacobian[j][i, :] = (y_pos[j] - y_neg[j]) / delta / 2. + + return jacobian + + +def _compute_analytical_jacobian(program, x, y, place, scope): + """Computes the analytical Jacobian for dy/dx. + + Args: + program (Program): a Program with forward pass. + x (Variable|list[Variable]): a variable or list of variable + y (Variable): the target variable. + place (fluid.CPUPlace or fluid.CUDAPlace): the device. + scope (Scope): the scope used to run program. + + Returns: + A list of 2-D numpy array. The list length is len(x). + Each 2-D numpy array represents the Jacobian for dy/dx_i. + It has "xi_size" rows and "dy_size" columns + where "x_size" is the number of elements in x_i and + "dy_size" is the number of elements in y. + """ + if not isinstance(y, fluid.framework.Variable): + raise TypeError('y is not Variable') + + dy_name = _append_grad_suffix_(y.name) + + np_type = dtype_to_np_dtype(y.dtype) + # create dy Variable in Program + dy = program.global_block().create_var( + name=dy_name, shape=y.shape, dtype=np_type, persistable=True) + # append backward + dx = calc_gradient(y, x, dy) + + # init dy tensor in scope + value = np.zeros(y.shape, dtype=np_type) + dy_t = set_var_in_scope(scope, place, dy_name, value) + + exe = fluid.Executor(place) + + y_size = _product(y.shape) + + x = _as_list(x) + jacobian = make_jacobian(x, y_size, np_type) + + dx = _as_list(dx) + for i in six.moves.xrange(y_size): + _set_item(dy_t, i, 1, np_type) + + dx_res = exe.run(program, scope=scope, fetch_list=dx) + + for j in six.moves.xrange(len(x)): + jacobian[j][:, i] = dx_res[j].flatten() + _set_item(dy_t, i, 0, np_type) + + return jacobian + + +def grad_check(x, + y, + x_init=None, + place=None, + program=None, + eps=1e-6, + atol=1e-5, + rtol=1e-3, + raise_exception=True): + """ + Check numerical and analytical gradients for dy/dx. + Each Jacobian gradients is a 2-D array with shape [xi_size, yi_size]. + + Args: + x (Variable|list[Variable]): input variables to the program. + y (Variable|list[Variable]): output variables to the program. + x_init (numpy.array|list[numpy.array]|None): the init value for input x. + place (fluid.CPUPlace or fluid.CUDAPlace): the device. + program (Program|None): a Program with forward pass. + If None, use fluid.default_main_program(). + eps (float): perturbation for finite differences. + atol (float): absolute tolerance. + rtol (float): relative tolerance. + raise_exception (bool): whether to raise an exception if + the check fails. Default is True. + Returns: + True if all differences satisfy numpy.allclose condition. + """ + + def fail_test(msg): + if raise_exception: + raise RuntimeError(msg) + return False + + # check input arguments + x = _as_list(x) + y = _as_list(y) + for v in x: + v.stop_gradient = False + v.persistable = True + if place is None: + place = fluid.CPUPlace() + if program is None: + program = fluid.default_main_program() + + # init variable in strtup program + scope = fluid.executor.global_scope() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + x_init = _as_list(x_init) + # init inputs if x_init is not None + if x_init: + if len(x_init) != len(x): + raise ValueError('len(x_init) (=%d) is not the same' + ' as len(x) (= %d)' % (len(x_init), len(x))) + # init variable in main program + for var, arr in zip(x, x_init): + assert var.shape == arr.shape + feeds = {k.name: v for k, v in zip(x, x_init)} + exe.run(program, feed=feeds, scope=scope) + + # [x_idx, y_idx] + numerical = [ + _compute_numerical_jacobian(program, xi, y, place, scope, eps) + for xi in x + ] + + # [y_idx, x_idx] + analytical = [ + _compute_analytical_jacobian(program, x, yi, place, scope) for yi in y + ] + + for i, (x_idx, + y_idx) in enumerate(product(*[range(len(x)), range(len(y))])): + a = analytical[y_idx][x_idx] + n = numerical[x_idx][y_idx] + if not np.allclose(a, n, rtol, atol): + msg = 'Jacobian mismatch for output %s ' \ + 'with respect to input %s on %s,\n' \ + 'numerical:%s\nanalytical:%s\n' \ + % (y[y_idx].name, x[x_idx].name, str(place), n, a) + return fail_test(msg) + return True + + +def double_grad_check(x, + y, + x_init=None, + y_grads=None, + place=None, + program=None, + eps=1e-6, + atol=1e-5, + rtol=1e-3, + raise_exception=True): + """ + Check gradients of gradients. This function will append backward to the + program before second order gradient check. + + Args: + x (Variable|list[Variable]): input variables to the program. + y (Variable|list[Variable]): output variables to the program. + x_init (numpy.array|list[numpy.array]|None): the init value for input x. + y_grads (numpy.array|list[numpy.array]|None): the gradients with respect to y. + place (fluid.CPUPlace or fluid.CUDAPlace): the device. + program (Program|None): a Program with forward pass. + If None, use fluid.default_main_program(). + eps (float): perturbation for finite differences. + atol (float): absolute tolerance. + rtol (float): relative tolerance. + raise_exception (bool): whether to raise an exception if + the check fails. Default is True. + Returns: + True if all differences satisfy numpy.allclose condition. + """ + # check input arguments + x = _as_list(x) + for v in x: + v.stop_gradient = False + v.persistable = True + y = _as_list(y) + + if program is None: + program = fluid.default_main_program() + + if y_grads is None: + scope = fluid.executor.global_scope() + y_grads = [] + for yi in y: + dyi_name = _append_grad_suffix_(yi.name) + np_type = dtype_to_np_dtype(yi.dtype) + dy = program.global_block().create_var( + name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True) + dy.stop_gradient = False + v = np.random.random(size=yi.shape).astype(np_type) + set_var_in_scope(scope, place, dyi_name, v) + y_grads.append(dy) + else: + y_grads = _as_list(y_grads) + + # append first order grads + target_grads = calc_gradient(y, x, y_grads) + grad_check(x, target_grads, x_init, place, program, eps, atol, rtol) diff --git a/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py b/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py index 9d635f36fe83d041bb57df0759da1481f66bbaa2..5328f73b31513745a4ddd51044bea7b3f59eaf5f 100644 --- a/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py +++ b/python/paddle/fluid/tests/unittests/test_dynrnn_gradient_check.py @@ -19,7 +19,7 @@ import random import collections import paddle.fluid as fluid import unittest -from decorators import * +from decorator_helper import * class Memory(object): diff --git a/python/paddle/fluid/tests/unittests/test_get_places_op.py b/python/paddle/fluid/tests/unittests/test_get_places_op.py index 441666a97b16a320692d6a15363f61156e52242b..e6be3a3a3e5b6ae7570d2ebdf2836e48345f5734 100644 --- a/python/paddle/fluid/tests/unittests/test_get_places_op.py +++ b/python/paddle/fluid/tests/unittests/test_get_places_op.py @@ -16,12 +16,12 @@ from __future__ import print_function import paddle.fluid as fluid from paddle.fluid.layers.device import get_places -import decorators +from decorator_helper import prog_scope import unittest class TestGetPlaces(unittest.TestCase): - @decorators.prog_scope() + @prog_scope() def test_get_places(self): places = get_places() cpu = fluid.CPUPlace() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 91f8bc5fd0a510dcc05cb7ba2397cad52be16af5..46f025c33bc9cc3a7197a4e87475b4d9c132b4ed 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -17,7 +17,7 @@ import unittest import contextlib import numpy as np -import decorators +from decorator_helper import prog_scope import inspect from six.moves import filter @@ -1171,7 +1171,7 @@ class TestBook(LayerTest): fluid.default_startup_program()): get_places(device_count=1) - @decorators.prog_scope() + @prog_scope() def make_nce(self): window_size = 5 words = [] diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index b25d40a3a15e259878222ee5482cd842543b63d6..f6cdb17def9e472414bf1213d8756f6d2977adfa 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -15,13 +15,13 @@ from __future__ import print_function import unittest -import decorators +from decorator_helper import prog_scope import paddle.fluid as fluid import numpy class TestMathOpPatches(unittest.TestCase): - @decorators.prog_scope() + @prog_scope() def test_add_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = a + 10 @@ -41,7 +41,7 @@ class TestMathOpPatches(unittest.TestCase): d_expected = ab_np + numpy.concatenate([a_np, a_np], axis=1) self.assertTrue(numpy.allclose(d_expected, d_np)) - @decorators.prog_scope() + @prog_scope() def test_radd_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = 10 + a @@ -53,7 +53,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(a_np + 10, b_np)) - @decorators.prog_scope() + @prog_scope() def test_sub_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = a - 10 @@ -65,7 +65,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(a_np - 10, b_np)) - @decorators.prog_scope() + @prog_scope() def test_radd_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = 10 - a @@ -77,7 +77,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(10 - a_np, b_np)) - @decorators.prog_scope() + @prog_scope() def test_mul_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = a * 10 @@ -89,7 +89,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(a_np * 10, b_np)) - @decorators.prog_scope() + @prog_scope() def test_rmul_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = 10 * a @@ -101,7 +101,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(10 * a_np, b_np)) - @decorators.prog_scope() + @prog_scope() def test_div_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = a / 10 @@ -113,7 +113,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(a_np / 10, b_np)) - @decorators.prog_scope() + @prog_scope() def test_rdiv_scalar(self): a = fluid.layers.data(name="a", shape=[1]) b = 10 / a @@ -126,7 +126,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(10 / a_np, b_np)) - @decorators.prog_scope() + @prog_scope() def test_div_two_tensor(self): a = fluid.layers.data(name="a", shape=[1]) b = fluid.layers.data(name="b", shape=[1]) @@ -141,7 +141,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[c]) self.assertTrue(numpy.allclose(a_np / b_np, c_np)) - @decorators.prog_scope() + @prog_scope() def test_mul_two_tensor(self): a = fluid.layers.data(name="a", shape=[1]) b = fluid.layers.data(name="b", shape=[1]) @@ -156,7 +156,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[c]) self.assertTrue(numpy.allclose(a_np * b_np, c_np)) - @decorators.prog_scope() + @prog_scope() def test_add_two_tensor(self): a = fluid.layers.data(name="a", shape=[1]) b = fluid.layers.data(name="b", shape=[1]) @@ -171,7 +171,7 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[c]) self.assertTrue(numpy.allclose(a_np + b_np, c_np)) - @decorators.prog_scope() + @prog_scope() def test_sub_two_tensor(self): a = fluid.layers.data(name="a", shape=[1]) b = fluid.layers.data(name="b", shape=[1]) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f26386c92e11b1486fdc03f1fab0c16528014d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -0,0 +1,72 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +import gradient_checker + +from decorator_helper import prog_scope + + +class TestMulGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + prog = fluid.Program() + with fluid.program_guard(prog): + x = layers.create_parameter(dtype="float64", shape=[2, 8], name='x') + y = layers.create_parameter(dtype="float64", shape=[8, 4], name='y') + z = layers.mul(x=x, y=y) + gradient_checker.grad_check([x, y], z, place=place) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestReluDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + # the shape of input variable shoule be clearly specified, not inlcude -1. + shape = [2, 8] + eps = 0.005 + dtype = np.float64 + + x = layers.data('x', shape, False, dtype) + x.persistable = True + y = layers.relu(x) + x_arr = np.random.uniform(-1, 1, shape).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.02 + + gradient_checker.double_grad_check( + [x], y, x_init=x_arr, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_registry.py b/python/paddle/fluid/tests/unittests/test_registry.py index 7381bb61eb4630cb67bc306fde211704e9580af4..39cf64465ab1ed618ef4e63e1b9d7787d419f3d8 100644 --- a/python/paddle/fluid/tests/unittests/test_registry.py +++ b/python/paddle/fluid/tests/unittests/test_registry.py @@ -17,11 +17,11 @@ import unittest import paddle.fluid as fluid import numpy as np -import decorators +from decorator_helper import prog_scope class TestRegistry(unittest.TestCase): - @decorators.prog_scope() + @prog_scope() def test_registry_layer(self): x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32') output = fluid.layers.mean(x)