Unverified · Commit 1a145aab · authored by cyber-pioneer, committed by GitHub

add cos double and triple grad operator (#47796)

Parent 42c8d51a
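The kernels registered below enable second- and third-order autodiff through paddle.cos. A minimal sketch of what this allows (assuming an eager-mode Paddle build with cos_double_grad and cos_triple_grad registered, as in this commit):

import paddle

x = paddle.to_tensor([0.3, 1.2, 2.5], dtype='float64', stop_gradient=False)
loss = paddle.cos(x).sum()

(g1,) = paddle.grad(loss, x, create_graph=True)      # -sin(x), via cos_grad
(g2,) = paddle.grad(g1.sum(), x, create_graph=True)  # -cos(x), via cos_double_grad
(g3,) = paddle.grad(g2.sum(), x)                     # sin(x),  via cos_triple_grad

print(g2.numpy())  # should match -paddle.cos(x)
print(g3.numpy())  # should match paddle.sin(x)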
......@@ -38,6 +38,8 @@ ops_to_fill_zero_for_empty_grads = set(
"tanh_triple_grad",
"sin_double_grad",
"sin_triple_grad",
"cos_double_grad",
"cos_triple_grad",
"subtract_double_grad",
"divide_double_grad",
"log_double_grad",
......
......@@ -172,6 +172,18 @@
kernel :
func : cholesky_solve_grad
- backward_op : cos_double_grad
forward : cos_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
output : Tensor(x_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, x]
kernel :
func : cos_double_grad
backward : cos_triple_grad
inplace : (grad_x_grad -> grad_out_grad)
- backward_op : cos_grad
forward : cos (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......@@ -181,8 +193,20 @@
param : [x]
kernel :
func : cos_grad
backward : cos_double_grad
inplace : (out_grad -> x_grad)
- backward_op : cos_triple_grad
forward : cos_double_grad (Tensor x, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_x), Tensor(grad_out_grad)
args : (Tensor x, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_x_grad, Tensor grad_out_grad_grad)
output : Tensor(x_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, x, grad_x_grad_forward]
kernel :
func : cos_triple_grad
inplace : (grad_x_grad_forward -> grad_out_forward_grad)
- backward_op : cosh_grad
forward : cosh (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -229,7 +229,7 @@
attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = platform::GetDefaultConvWorkspaceSizeLimitMB()]
- op : cos
backward : cos_grad
backward : cos_grad, cos_double_grad, cos_triple_grad
inputs :
x : X
outputs :
......
......@@ -88,6 +88,14 @@ void SinDoubleGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void CosDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void TanhDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
......@@ -118,6 +126,17 @@ void SinTripleGradKernel(const Context& dev_ctx,
DenseTensor* d_dout,
DenseTensor* d_ddx);
template <typename T, typename Context>
void CosTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& d_dx_new,
const DenseTensor& d_ddout,
DenseTensor* d_x_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
template <typename T, typename Context>
void LeakyReluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -336,6 +336,7 @@ PD_REGISTER_KERNEL(square_double_grad,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_KERNEL(sin_double_grad,
CPU,
ALL_LAYOUT,
......@@ -345,6 +346,7 @@ PD_REGISTER_KERNEL(sin_double_grad,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_KERNEL(sin_triple_grad,
CPU,
ALL_LAYOUT,
......@@ -354,6 +356,27 @@ PD_REGISTER_KERNEL(sin_triple_grad,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_KERNEL(cos_double_grad,
CPU,
ALL_LAYOUT,
phi::CosDoubleGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_KERNEL(cos_triple_grad,
CPU,
ALL_LAYOUT,
phi::CosTripleGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
......
......@@ -117,23 +117,22 @@ struct SinDoubleGradFunctor : public BaseActivationFunctor<T> {
DenseTensor* dX,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
    auto d2d1x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "d2d1x", "SinDoubleGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "x", "SinDoubleGrad"));
    // calculate d2x first, so d2d1y can inplace d2d1x
    auto d2x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Output", "d2x", "SinDoubleGrad"));
    auto d1y = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Output", "d1y", "SinDoubleGrad"));
    d2x.device(*d) = -d2d1x * x.unaryExpr(Sine<T>()) * d1y;
    // calculate d2d1y
    auto d2d1y = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "SinDoubleGrad"));
    d2d1y.device(*d) = d2d1x * x.unaryExpr(Cosine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
......@@ -221,6 +220,22 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};
// 1st reverse grad
// y = cos(x)
// x --> y
// d1x = d1y * -sin(x)
//
// 2nd reverse grad
// x, d1y --> d1x
// d2x = -cos(x) * d1y * d2d1x
// d2d1y = -sin(x) * d2d1x
//
// 3rd reverse grad
// x, d1y, d2d1x --> d2x, d2d1y
// d3x = sin(x) * d1y * d2d1x * d3d2x - cos(x) * d2d1x * d3d2d1y
// d3d1y = -cos(x) * d2d1x * d3d2x
// d3d2d1x = -cos(x) * d1y * d3d2x - sin(x) * d3d2d1y
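// Naming convention used above and in the functors below: d1y is the
// incoming grad of y (dL/dy); d2d1x is the incoming grad of d1x in the
// 2nd-order pass; d3d2x and d3d2d1y are the incoming grads of d2x and
// d2d1y in the 3rd-order pass.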
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
......@@ -236,6 +251,80 @@ struct CosGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cos''(x) = -cos(x)
template <typename T>
struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* dOut,
const DenseTensor* ddX,
DenseTensor* dX,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosDoubleGrad"));
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "x", "CosDoubleGrad"));
// calculate d2x first, so d2d1y can inplace d2d1x
auto d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "d2x", "CosDoubleGrad"));
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "d1y", "CosDoubleGrad"));
d2x.device(*d) = -d2d1x * x.unaryExpr(Cosine<T>()) * d1y;
// calculate d2d1y
auto d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "d2d1y", "CosDoubleGrad"));
d2d1y.device(*d) = -d2d1x * x.unaryExpr(Sine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CosTripleGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* ddX,
const DenseTensor* dOut,
const DenseTensor* d_DDOut,
const DenseTensor* d_dx_New,
DenseTensor* d_d_Out,
DenseTensor* d_x_New,
DenseTensor* d_DDx) const {
auto* d = dev.eigen_device();
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "x", "CosTripleGrad"));
auto d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "d2d1x", "CosTripleGrad"));
auto d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "d1y", "CosTripleGrad"));
auto d3d2d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDOut, "Input", "d3d2d1y", "CosTripleGrad"));
auto d3d2x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_dx_New, "Input", "d3d2x", "CosTripleGrad"));
auto d3x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_x_New, "Output", "d3x", "CosTripleGrad"));
d3x.device(*d) = x.unaryExpr(Sine<T>()) * d1y * d2d1x * d3d2x -
x.unaryExpr(Cosine<T>()) * d2d1x * d3d2d1y;
auto d3d1y = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_d_Out, "Output", "d3d1y", "CosTripleGrad"));
d3d1y.device(*d) = -x.unaryExpr(Cosine<T>()) * d2d1x * d3d2x;
auto d3d2d1x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(d_DDx, "Output", "d3d2d1x", "CosTripleGrad"));
d3d2d1x.device(*d) = -x.unaryExpr(Cosine<T>()) * d1y * d3d2x -
x.unaryExpr(Sine<T>()) * d3d2d1y;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
......
......@@ -437,6 +437,26 @@ PD_REGISTER_KERNEL(sin_triple_grad,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(cos_double_grad,
GPU,
ALL_LAYOUT,
phi::CosDoubleGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(cos_triple_grad,
GPU,
ALL_LAYOUT,
phi::CosTripleGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
......
......@@ -646,4 +646,56 @@ void SinTripleGradKernel(const Context& dev_ctx,
d_ddx); // output
}
template <typename T, typename Context>
void CosDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::CosDoubleGradFunctor<T> functor;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}
template <typename T, typename Context>
void CosTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& d_dx_new,
const DenseTensor& d_ddout,
DenseTensor* d_x_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
if (d_dout) {
d_dout->Resize(x.dims());
dev_ctx.template Alloc<T>(d_dout);
}
if (d_x_new) {
d_x_new->Resize(x.dims());
dev_ctx.template Alloc<T>(d_x_new);
}
if (d_ddx) {
d_ddx->Resize(ddx.dims());
dev_ctx.template Alloc<T>(d_ddx);
}
funcs::CosTripleGradFunctor<T> functor;
functor(dev_ctx,
&x,
&ddx,
&dout,
&d_ddout,
&d_dx_new, // input
d_dout,
d_x_new,
d_ddx); // output
}
} // namespace phi
......@@ -503,6 +503,38 @@ class TestSinDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestCosDoubleGradCheck(unittest.TestCase):
def cos_wrapper(self, x):
return paddle.cos(x[0])
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 0.0005
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = paddle.cos(x)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
x_arr[np.abs(x_arr) < 0.005] = 0.002
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.double_grad_check_for_dygraph(
self.cos_wrapper, [x], y, x_init=x_arr, place=place
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestPowDoubleGradCheck1(unittest.TestCase):
def pow_wrapper(self, x):
return paddle.pow(x[0], 2)
......@@ -690,5 +722,37 @@ class TestPowTripleGradCheck3(unittest.TestCase):
self.func(p)
class TestCosTripleGradCheck(unittest.TestCase):
def cos_wrapper(self, x):
return paddle.cos(x[0])
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 0.0005
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = layers.cos(x)
x_arr = np.random.random(shape).astype(dtype)
x_arr[np.abs(x_arr) < 0.005] = 0.002
gradient_checker.triple_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(
self.cos_wrapper, [x], y, x_init=x_arr, place=place
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
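The closed forms the new functors rely on, cos''(x) = -cos(x) and cos'''(x) = sin(x), can also be sanity-checked against central finite differences with plain NumPy. A standalone sketch, independent of the Paddle kernels:

import numpy as np

x = np.linspace(-2.0, 2.0, 7)
h = 1e-3

# central difference for the second derivative
d2_numeric = (np.cos(x + h) - 2 * np.cos(x) + np.cos(x - h)) / h**2
# central difference for the third derivative
d3_numeric = (np.cos(x + 2 * h) - 2 * np.cos(x + h)
              + 2 * np.cos(x - h) - np.cos(x - 2 * h)) / (2 * h**3)

assert np.allclose(d2_numeric, -np.cos(x), atol=1e-5)
assert np.allclose(d3_numeric, np.sin(x), atol=1e-5)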
if __name__ == "__main__":
unittest.main()