未验证 提交 2c44ee7e 编写于 作者: J Jiabin Yang 提交者: GitHub

[New Feature] Support triple grad in Paddle (#36187)

* native commit for triple grad of sigmod

* Updated unittests files

* init functional jacobian api

* Updated trible_test func

* Updated gradient_checker & test_script

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* fix dygraph grad to support high differential

* polish API docstring

* Updated gradient checker and some related files

* fix double grad strip error for high differential

* fix double grad strip error for high differential

* Add Sigmoid triple grad tests

* fix dygraph double grad dtype error when calling for high differential senario

* Updated triple grad teses func

* Use np.random to initialize ddx

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warnging in backward.py

* format python code
Co-authored-by: Nveyron95 <veyron_wu@163.com>
Co-authored-by: Nlevi131 <limaolin01@baidu.com>
上级 d7858c99
......@@ -77,12 +77,12 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
FLAGS_use_mkldnn ||
(op->HasAttr("use_mkldnn") &&
BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
op->SetInput("X", this->Input("X"));
op->SetInput("X", this->Input("X")); // x
}
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", this->Output("Out"));
op->SetInput("Out", this->Output("Out")); // out
}
}
};
......@@ -767,6 +767,10 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
if (ctx->HasOutput("DOutNew")) {
ctx->ShareDim("Out", "DOutNew");
ctx->ShareLoD("Out", "DOutNew");
}
}
}
......@@ -804,6 +808,45 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
}
};
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (ctx->HasOutput("DX")) {
ctx->ShareDim("X", "DX");
ctx->ShareLoD("X", "DX");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (ctx->HasOutput("D_DOut")) {
ctx->ShareDim("Out", "D_DOut");
ctx->ShareLoD("Out", "D_DOut");
}
if (ctx->HasOutput("D_OutNew")) {
ctx->ShareDim("Out", "D_OutNew");
ctx->ShareLoD("Out", "D_OutNew");
}
if (ctx->HasOutput("D_DDx")) {
ctx->ShareDim("DDX", "D_DDx");
ctx->ShareLoD("DDX", "D_DDx");
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
};
template <typename T>
class SigmoidDoubleGradMaker
: public ::paddle::framework::SingleGradOpMaker<T> {
......@@ -825,6 +868,36 @@ class SigmoidDoubleGradMaker
}
};
template <typename T>
class SigmoidTripleGradMaker
: public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("sigmoid_triple_grad");
// Out, DDX, DOut, D_DDOut, D_DOut_New // input
// D_OutNew, D_DOut, D_DDx // output
// input1: Out
op->SetInput("Out", this->Input("Out"));
// input2: ddx
op->SetInput("DDX", this->Input("DDX"));
// input3: dout
op->SetInput("DOut", this->Input("DOut"));
// input4: d_ddout
op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
// input5: d_dout_new
op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
op->SetAttrMap(this->Attrs());
// output: d_dOut, d_OutNew, d_ddx
op->SetOutput("D_OutNew", this->InputGrad("Out"));
op->SetOutput("D_DOut", this->InputGrad("DOut"));
op->SetOutput("D_DDx", this->InputGrad("DDX"));
}
};
template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
......@@ -995,10 +1068,12 @@ class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
};
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
{framework::GradVarName("Out"), // dout
framework::GradVarName("X")}); // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
{"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
{"DDX", "D_DOut"});
template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -1121,13 +1196,21 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(sigmoid_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>)
ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>);
// 3. Register Sigmoid DoubleGrad Operator
REGISTER_OPERATOR(
sigmoid_grad_grad,
ops::ActivationOpDoubleGrad<ops::SigmoidGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer,
ops::SigmoidTripleGradMaker<paddle::framework::OpDesc>,
ops::SigmoidTripleGradMaker<paddle::imperative::OpBase>);
// 4. Register Sigmoid TripleGrad Operator
REGISTER_OPERATOR(sigmoid_triple_grad,
ops::ActivationOpTripleGrad<
ops::SigmoidTripleGradFunctor<float>::FwdDeps()>,
ops::ActivationTripleGradOpInplaceInferer);
// Register Sigmoid/GradSigmoid Kernels
REGISTER_ACTIVATION_CPU_KERNEL(sigmoid, Sigmoid, SigmoidFunctor,
......@@ -1143,6 +1226,16 @@ REGISTER_OP_CPU_KERNEL(
ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
ops::SigmoidGradGradFunctor<plat::float16>>);
// Register TripleGrad Kernel
REGISTER_OP_CPU_KERNEL(
sigmoid_triple_grad,
ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
ops::SigmoidTripleGradFunctor<float>>,
ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
ops::SigmoidTripleGradFunctor<double>>,
ops::SigmoidTripleGradKernel<plat::CPUDeviceContext,
ops::SigmoidTripleGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== tanh register ============================= */
......
......@@ -1398,6 +1398,15 @@ REGISTER_OP_CUDA_KERNEL(
ops::SigmoidGradGradFunctor<double>>,
ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
ops::SigmoidGradGradFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
sigmoid_triple_grad,
ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<float>>,
ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<double>>,
ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== tanh register ============================ */
......
......@@ -24,12 +24,13 @@ limitations under the License. */
#define _USE_MATH_DEFINES
#endif
#include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -282,19 +283,77 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidGradGrad"));
auto dout_new = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "SquareGradGrad"));
GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "SigmoidGradGrad"));
dout_new.device(*d) =
(static_cast<T>(1) - static_cast<T>(2) * out) * dout * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SigmoidGradGrad"));
ddout.device(*d) = (static_cast<T>(1) - out) * out * ddx;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
/*
Out
DOut D_Dout
DDx -> SigmoidTripleGrad -> D_DDx
D_DDout d_OutNew
D_Dout_new
D_Dout = (1-2*Out)*DDx*D_Dout_new
D_DDx = (1-Out)*Out*D_DDout + (1-2*Out)*DOut*D_Dout_new
D_OutNew = (DDx-2*Out*DDx)*D_DDout - 2*DOut*DDx*D_Dout_new
Out, DDX, DOut, D_DDOut, D_DOut_New // input
D_OutNew, D_DOut, D_DDx // output
*/
template <typename T>
struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, const framework::Tensor* dOut,
const framework::Tensor* d_DDOut,
const framework::Tensor* d_dOut_New,
framework::Tensor* d_d_Out, framework::Tensor* d_Out_New,
framework::Tensor* d_DDx) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SigmoidTripleGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Input", "Out", "SigmoidTripleGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "DOut", "SigmoidTripleGrad"));
auto d_ddOut = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "SigmoidTripleGrad"));
auto d_dOutNew = framework::EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_dOut_New, "Input", "D_DOut_New", "SigmoidTripleGrad"));
if (d_Out_New) {
auto d_OutNew = framework::EigenVector<T>::Flatten(GET_DATA_SAFELY(
d_Out_New, "Output", "D_OutNew", "SigmoidTripleGrad"));
d_OutNew.device(*d) = (ddx - static_cast<T>(2) * out * ddx) * d_ddOut -
static_cast<T>(2) * dout * ddx * d_dOutNew;
}
if (d_d_Out) {
auto d_dOut = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "SigmoidTripleGrad"));
d_dOut.device(*d) =
(static_cast<T>(1) - static_cast<T>(2) * out) * ddx * d_dOutNew;
}
if (d_DDx) {
auto d_ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "SigmoidTripleGrad"));
d_ddx.device(*d) =
(static_cast<T>(1) - out) * out * d_ddOut +
(static_cast<T>(1) - static_cast<T>(2) * out) * dout * d_dOutNew;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
......@@ -465,13 +524,13 @@ struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhGradGrad"));
auto dout_new = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "SquareGradGrad"));
GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "TanhGradGrad"));
dout_new.device(*d) =
static_cast<T>(-1) * dout * static_cast<T>(2) * out * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "TanhGradGrad"));
ddout.device(*d) = (static_cast<T>(1) - out * out) * ddx;
}
}
......@@ -1856,7 +1915,6 @@ class SigmoidDoubleGradKernel
framework::Tensor *dOutNew, *ddOut;
Out = ddX = dOut = nullptr;
dOutNew = ddOut = nullptr;
// extract ddx(input) and out(input)
ddX = ctx.Input<framework::Tensor>("DDX");
Out = ctx.Input<framework::Tensor>("Out");
......@@ -1868,20 +1926,15 @@ class SigmoidDoubleGradKernel
Out, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
// set output ddout
ddOut = ctx.Output<framework::Tensor>("DDOut");
// extract dOut(intput)
dOut = ctx.Input<framework::Tensor>("DOut");
PADDLE_ENFORCE_NOT_NULL(
dOut, platform::errors::NotFound(
"Cannot get input Variable dOut, variable name = %s",
ctx.InputName("DOut")));
// set output dout_new
dOutNew = ctx.Output<framework::Tensor>("DOutNew");
if (dOutNew) dOutNew->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
......@@ -1890,6 +1943,64 @@ class SigmoidDoubleGradKernel
}
};
// Out, DDX, DOut, D_DDOut, D_DOut_New // input
// D_OutNew, D_DOut, D_DDx // output
template <typename DeviceContext, typename Functor>
class SigmoidTripleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *ddX, *dOut, *d_ddOut, *d_dOutNew;
framework::Tensor *d_OutNew, *d_dOut, *d_ddx;
Out = ddX = dOut = d_ddOut = d_dOutNew = nullptr;
d_OutNew = d_dOut = d_ddx = nullptr;
// extract ddx(input), out(input), dOut(input), d_ddOut(input),
// d_dOutNew(input)
ddX = ctx.Input<framework::Tensor>("DDX");
Out = ctx.Input<framework::Tensor>("Out");
dOut = ctx.Input<framework::Tensor>("DOut");
d_ddOut = ctx.Input<framework::Tensor>("D_DDOut");
d_dOutNew = ctx.Input<framework::Tensor>("D_DOut_New");
PADDLE_ENFORCE_NOT_NULL(
ddX, platform::errors::NotFound(
"Cannot get input Variable ddX, variable name = %s",
ctx.InputName("DDX")));
PADDLE_ENFORCE_NOT_NULL(
Out, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
PADDLE_ENFORCE_NOT_NULL(
dOut, platform::errors::NotFound(
"Cannot get input Variable dOut, variable name = %s",
ctx.InputName("DOut")));
PADDLE_ENFORCE_NOT_NULL(
d_ddOut, platform::errors::NotFound(
"Cannot get input Variable d_ddOut, variable name = %s",
ctx.InputName("D_DDOut")));
PADDLE_ENFORCE_NOT_NULL(
d_dOutNew,
platform::errors::NotFound(
"Cannot get input Variable d_dOutNew, variable name = %s",
ctx.InputName("D_DOutNew")));
// set output d_OutNew、d_dOut、d_ddx
d_dOut = ctx.Output<framework::Tensor>("D_DOut");
d_OutNew = ctx.Output<framework::Tensor>("D_OutNew");
d_ddx = ctx.Output<framework::Tensor>("D_DDx");
if (d_dOut) d_dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (d_OutNew) d_OutNew->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (d_ddx) d_ddx->mutable_data<T>(ddX->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, dOut, d_ddOut, d_dOutNew, // input
d_dOut, d_OutNew, d_ddx); // output
}
};
template <typename DeviceContext, typename Functor>
class TanhDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......
......@@ -27,6 +27,7 @@ from . import unique_name
from . import log_helper
import paddle.fluid
from .data_feeder import check_type
import warnings
__all__ = [
'append_backward',
'gradients',
......@@ -371,6 +372,10 @@ def _infer_var_data_type_shape_(grad_var_name, block):
grad_var.set_dtype(fwd_var.dtype())
grad_var.set_shape(fwd_var.shape())
else:
# TODO(jiabin): Maybe we should not to this to cause some unexpected error on dtype
warnings.warn(
"Set grad var: {} dtype to default FP32, since we can't find its related forward var".
format(grad_var_name))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
......@@ -408,7 +413,9 @@ def _strip_grad_suffix_(name):
"""
name = cpt.to_text(name)
pos = name.find(core.grad_var_suffix())
return name[:pos] if pos != -1 else name
new_name = name[:pos] if pos != -1 else name
new_pos = name.rfind('grad/')
return new_name[new_pos + 5:] if new_pos != -1 else new_name
def _append_grad_suffix_(name):
......
......@@ -309,7 +309,7 @@ def grad_check(x,
_compute_analytical_jacobian(prog, clone_x, clone_y, place, scope))
for i, (x_idx,
y_idx) in enumerate(product(* [range(len(x)), range(len(y))])):
y_idx) in enumerate(product(*[range(len(x)), range(len(y))])):
a = analytical[y_idx][x_idx]
n = numerical[x_idx][y_idx]
if not np.allclose(a, n, rtol, atol):
......@@ -391,3 +391,118 @@ def double_grad_check(x,
x_init += y_grads_init
grad_check(x, target_grads, x_init, place, program, eps, atol, rtol)
# TODO(jiabin): We currently support only triple grad check here, extend this to support
# higher order differenciation later.
# check triple grad and two outputs of the triple Kernel
def triple_grad_check(x,
y,
x_init=None,
y_grads=None,
x_grads_grads=None,
place=None,
program=None,
eps=1e-6,
atol=1e-5,
rtol=1e-3,
raise_exception=True):
"""
Check triple gradients. This function will append backward to the
program before third order gradient check.
Args:
x (Variable|list[Variable]): input variables to the program.
y (Variable|list[Variable]): output variables to the program.
x_init (numpy.array|list[numpy.array]|None): the init value for input x.
y_grads (numpy.array|list[numpy.array]|None): the gradients with respect to y.
x_grads_grads (numpy.array|list[numpy.array]|None): the gradients with respect to your input.
place (fluid.CPUPlace or fluid.CUDAPlace): the device.
program (Program|None): a Program with forward pass.
If None, use fluid.default_main_program().
eps (float): perturbation for finite differences.
atol (float): absolute tolerance.
rtol (float): relative tolerance.
raise_exception (bool): whether to raise an exception if
the check fails. Default is True.
Returns:
True if all differences satisfy numpy.allclose condition.
"""
# check input arguments
x = _as_list(x)
for v in x:
v.stop_gradient = False
v.persistable = True
y = _as_list(y)
if program is None:
program = fluid.default_main_program()
if y_grads is None:
scope = fluid.executor.global_scope()
y_grads = []
y_grads_init = []
for yi in y:
dyi_name = _append_grad_suffix_(yi.name)
np_type = dtype_to_np_dtype(yi.dtype)
dy = program.global_block().create_var(
name=dyi_name, shape=yi.shape, dtype=np_type, persistable=True)
dy.stop_gradient = False
v = np.random.random(size=yi.shape).astype(np_type)
set_var_in_scope(scope, place, dyi_name, v)
y_grads.append(dy)
y_grads_init.append(v)
else:
y_grads = _as_list(y_grads)
y_grads_init = [
var_to_np_array_in_scope(scope, place, v.name) for v in y_grads
]
# append first order grads
target_grads = fluid.gradients(y, x, y_grads)
if x_grads_grads is None:
scope = fluid.executor.global_scope()
x_grads_grads = []
x_grads_grads_init = []
for dxi in target_grads:
ddxi_name = _append_grad_suffix_(dxi.name)
np_type = dtype_to_np_dtype(dxi.dtype)
ddx = program.global_block().create_var(
name=ddxi_name,
shape=dxi.shape,
dtype=np_type,
persistable=True)
ddx.stop_gradient = False
v = np.random.random(size=dxi.shape).astype(np_type)
set_var_in_scope(scope, place, ddxi_name, v)
x_grads_grads.append(ddx)
x_grads_grads_init.append(v)
else:
x_grads_grads = _as_list(x_grads_grads)
x_grads_grads_init = [
var_to_np_array_in_scope(scope, place, v.name)
for v in x_grads_grads
]
# append second order grads
target_grads_grads = fluid.gradients(target_grads, x, x_grads_grads)
x += y_grads
x_init = _as_list(x_init)
x_init += y_grads_init
x += x_grads_grads
x_init += x_grads_grads_init
# x <=> [x, dout, ddx]
grad_check(
x=x,
y=target_grads_grads,
x_init=x_init,
place=place,
program=program,
eps=eps,
atol=atol,
rtol=rtol)
......@@ -26,6 +26,28 @@ import gradient_checker
from decorator_helper import prog_scope
class TestSigmoidTripleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 0.0005
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = layers.sigmoid(x)
x_arr = np.random.random(shape).astype(dtype)
x_arr[np.abs(x_arr) < 0.005] = 0.002
gradient_checker.triple_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestSigmoidDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from decorator_helper import prog_scope
import unittest
import paddle.fluid as fluid
import numpy as np
import paddle
import warnings
class TestBackwardInferVarDataTypeShape(unittest.TestCase):
def test_backward_infer_var_data_type_shape(self):
paddle.enable_static()
program = fluid.default_main_program()
dy = program.global_block().create_var(
name="Tmp@GRAD", shape=[1, 1], dtype=np.float32, persistable=True)
# invoke warning
fluid.backward._infer_var_data_type_shape_("Tmp@GRAD",
program.global_block())
res = False
with warnings.catch_warnings():
res = True
self.assertTrue(res)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册