Unverified commit 808be657, authored by Jiabin Yang and committed by GitHub

[New Feature] Support tanh triple grad (#36225)

* native commit for triple grad of sigmoid

* Updated unittests files

* init functional jacobian api

* Updated triple_test func

* Updated gradient_checker & test_script

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* fix dygraph grad to support high differential

* polish API docstring

* Updated gradient checker and some related files

* fix double grad strip error for high differential

* fix double grad strip error for high differential

* Add Sigmoid triple grad tests

* fix dygraph double grad dtype error when calling for high differential scenario

* Updated triple grad tests func

* Use np.random to initialize ddx

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warning in backward.py

* add tanh triple grad

* format python code

* refine code
Co-authored-by: Nveyron95 <veyron_wu@163.com>
Co-authored-by: Nlevi131 <limaolin01@baidu.com>
Parent commit: 4dda18a8
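For context (not part of this commit), a minimal dygraph sketch of how the new tanh_triple_grad path gets exercised, assuming Paddle 2.x dynamic-graph mode and nested paddle.grad calls with create_graph=True:

import paddle

x = paddle.randn([2, 3])
x.stop_gradient = False
y = paddle.tanh(x)
# First-order gradient dy/dx; keep the graph so it can be differentiated again.
dx = paddle.grad(y, x, create_graph=True)[0]
# Second-order gradient, dispatched to tanh_grad_grad.
ddx = paddle.grad(dx, x, create_graph=True)[0]
# Third-order gradient, dispatched to the new tanh_triple_grad op.
dddx = paddle.grad(ddx, x)[0]
print(dddx.shape)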
......@@ -940,6 +940,34 @@ class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
  }
};

template <typename T>
class TanhTripleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_triple_grad");
    // Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    // D_OutNew, D_DOut, D_DDx               // output
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->Input("DDX"));
    // input3: dout
    op->SetInput("DOut", this->Input("DOut"));
    // input4: d_ddout
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    // input5: d_dout_new
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));

    op->SetAttrMap(this->Attrs());

    // output: d_dOut, d_OutNew, d_ddx
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
  }
};
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
template <typename T>
......@@ -1299,7 +1327,14 @@ REGISTER_OPERATOR(tanh_grad, ops::ActivationOpGrad,
REGISTER_OPERATOR(
    tanh_grad_grad,
    ops::ActivationOpDoubleGrad<ops::TanhGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::TanhTripleGradMaker<paddle::framework::OpDesc>,
    ops::TanhTripleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
    tanh_triple_grad,
    ops::ActivationOpTripleGrad<ops::TanhTripleGradFunctor<float>::FwdDeps()>,
    ops::ActivationTripleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(tanh, Tanh, TanhFunctor, TanhGradFunctor);
REGISTER_OP_CPU_KERNEL(
......@@ -1309,6 +1344,15 @@ REGISTER_OP_CPU_KERNEL(
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);

// Register TripleGrad Kernel
REGISTER_OP_CPU_KERNEL(
    tanh_triple_grad,
    ops::TanhTripeGradKernel<plat::CPUDeviceContext,
                             ops::TanhTripleGradFunctor<float>>,
    ops::TanhTripeGradKernel<plat::CPUDeviceContext,
                             ops::TanhTripleGradFunctor<double>>,
    ops::TanhTripeGradKernel<plat::CPUDeviceContext,
                             ops::TanhTripleGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== relu register ============================= */
......
......@@ -1487,6 +1487,15 @@ REGISTER_OP_CUDA_KERNEL(
                              ops::TanhGradGradFunctor<double>>,
    ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
                              ops::TanhGradGradFunctor<plat::float16>>);

REGISTER_OP_CUDA_KERNEL(
    tanh_triple_grad,
    ops::TanhTripeGradKernel<paddle::platform::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<float>>,
    ops::TanhTripeGradKernel<paddle::platform::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<double>>,
    ops::TanhTripeGradKernel<plat::CUDADeviceContext,
                             ops::TanhTripleGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
......
......@@ -536,6 +536,61 @@ struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

/*
    Out
    DOut                            D_Dout
    DDx     -> TanhTripleGrad ->    D_DDx
    D_DDout                         d_OutNew
    D_Dout_new

    D_Dout = (-2) * Out * DDx * D_Dout_new
    D_DDx = (1 - Out^2) * D_DDOut + (-2) * Out * DOut * D_Dout_new
    D_OutNew = (-2) * Out * DDx * D_DDOut + (-2) * DOut * DDx * D_Dout_new

    Out, DDX, DOut, D_DDOut, D_DOut_New   // input
    D_OutNew, D_DOut, D_DDx               // output
*/
template <typename T>
struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* Out,
                  const framework::Tensor* ddX, const framework::Tensor* dOut,
                  const framework::Tensor* d_DDOut,
                  const framework::Tensor* d_dOut_New,
                  framework::Tensor* d_d_Out, framework::Tensor* d_Out_New,
                  framework::Tensor* d_DDx) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhTripleGrad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "TanhTripleGrad"));
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhTripleGrad"));
    auto d_ddOut = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_DDOut, "Input", "D_DDOut", "TanhTripleGrad"));
    auto d_dOutNew = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(d_dOut_New, "Input", "D_DOut_New", "TanhTripleGrad"));

    if (d_Out_New) {
      auto d_OutNew = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_Out_New, "Output", "D_OutNew", "TanhTripleGrad"));
      d_OutNew.device(*d) = (static_cast<T>(-2) * out * ddx * d_ddOut) -
                            (static_cast<T>(2) * dout * ddx * d_dOutNew);
    }
    if (d_d_Out) {
      auto d_dOut = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_d_Out, "Output", "D_DOut", "TanhTripleGrad"));
      d_dOut.device(*d) = static_cast<T>(-2) * out * ddx * d_dOutNew;
    }
    if (d_DDx) {
      auto d_ddx = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(d_DDx, "Output", "D_DDx", "TanhTripleGrad"));
      d_ddx.device(*d) = (static_cast<T>(1) - (out * out)) * d_ddOut -
                         static_cast<T>(2) * out * dout * d_dOutNew;
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
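The TanhTripleGradFunctor formulas above follow from the chain rule applied to the outputs of the existing double-grad functor, which computes DDOut = (1 - Out^2) * DDX and DOutNew = -2 * Out * DOut * DDX. A minimal SymPy sketch (an illustration assuming sympy is installed, not part of the commit) reproduces the three expressions by differentiating those double-grad outputs against each input:

import sympy as sp

out, dout, ddx, d_ddout, d_doutnew = sp.symbols('out dout ddx d_ddout d_doutnew')

# Double-grad outputs of tanh (see TanhGradGradFunctor).
ddout = (1 - out**2) * ddx         # DDOut
doutnew = -2 * out * dout * ddx    # DOutNew

# Triple grad: differentiate <d_ddout, DDOut> + <d_doutnew, DOutNew> w.r.t. each input.
obj = d_ddout * ddout + d_doutnew * doutnew
print(sp.expand(sp.diff(obj, out)))   # D_OutNew: -2*out*ddx*d_ddout - 2*dout*ddx*d_doutnew
print(sp.expand(sp.diff(obj, dout)))  # D_DOut:   -2*out*ddx*d_doutnew
print(sp.expand(sp.diff(obj, ddx)))   # D_DDx:    (1 - out**2)*d_ddout - 2*out*dout*d_doutnew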
......@@ -2137,6 +2192,63 @@ class TanhDoubleGradKernel
    functor(place, Out, ddX, dOut, dOutNew, ddOut);
  }
};

template <typename DeviceContext, typename Functor>
class TanhTripeGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor *Out, *ddX, *dOut, *d_ddOut, *d_dOutNew;
    framework::Tensor *d_OutNew, *d_dOut, *d_ddx;
    Out = ddX = dOut = d_ddOut = d_dOutNew = nullptr;
    d_OutNew = d_dOut = d_ddx = nullptr;

    // extract ddx(input), out(input), dOut(input), d_ddOut(input),
    // d_dOutNew(input)
    ddX = ctx.Input<framework::Tensor>("DDX");
    Out = ctx.Input<framework::Tensor>("Out");
    dOut = ctx.Input<framework::Tensor>("DOut");
    d_ddOut = ctx.Input<framework::Tensor>("D_DDOut");
    d_dOutNew = ctx.Input<framework::Tensor>("D_DOut_New");

    PADDLE_ENFORCE_NOT_NULL(
        ddX, platform::errors::NotFound(
                 "Cannot get input Variable ddX, variable name = %s",
                 ctx.InputName("DDX")));
    PADDLE_ENFORCE_NOT_NULL(
        Out, platform::errors::NotFound(
                 "Cannot get input Variable Out, variable name = %s",
                 ctx.InputName("Out")));
    PADDLE_ENFORCE_NOT_NULL(
        dOut, platform::errors::NotFound(
                  "Cannot get input Variable dOut, variable name = %s",
                  ctx.InputName("DOut")));
    PADDLE_ENFORCE_NOT_NULL(
        d_ddOut, platform::errors::NotFound(
                     "Cannot get input Variable d_ddOut, variable name = %s",
                     ctx.InputName("D_DDOut")));
    PADDLE_ENFORCE_NOT_NULL(
        d_dOutNew,
        platform::errors::NotFound(
            "Cannot get input Variable d_dOutNew, variable name = %s",
            ctx.InputName("D_DOutNew")));

    // set output d_OutNew, d_dOut, d_ddx
    d_dOut = ctx.Output<framework::Tensor>("D_DOut");
    d_OutNew = ctx.Output<framework::Tensor>("D_OutNew");
    d_ddx = ctx.Output<framework::Tensor>("D_DDx");

    if (d_dOut) d_dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
    if (d_OutNew) d_OutNew->mutable_data<T>(Out->dims(), ctx.GetPlace());
    if (d_ddx) d_ddx->mutable_data<T>(ddX->dims(), ctx.GetPlace());
    auto& place = ctx.template device_context<DeviceContext>();
    Functor functor;
    functor(place, Out, ddX, dOut, d_ddOut, d_dOutNew,  // input
            d_dOut, d_OutNew, d_ddx);                   // output
  }
};

template <typename DeviceContext, typename Functor>
class SquareDoubleGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......
......@@ -71,6 +71,28 @@ class TestSigmoidDoubleGradCheck(unittest.TestCase):
            self.func(p)


class TestTanhTripleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 3, 7, 9]
        eps = 0.0005
        dtype = np.float64
        x = layers.data('x', shape, False, dtype=dtype)
        x.persistable = True
        y = layers.tanh(x)
        x_arr = np.random.random(shape).astype(dtype)
        x_arr[np.abs(x_arr) < 0.005] = 0.002
        gradient_checker.triple_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestTanhDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
......