Unverified commit d0d739ca authored by SylarTiaNII, committed by GitHub

add abs composite backward op (#50963)

* add abs composite backward op

* add missing changes during merge

* modify according to new rules

* local UT OK

* fix typo

* codestyle

* register composite operator

* add fp16 test for abs

* replace experimental::tensor
Parent ddfd8f60
@@ -19,6 +19,9 @@
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
@@ -92,6 +95,24 @@ class AbsGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class AbsCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    paddle::Tensor input = this->GetSingleForwardInput("X");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor input_grad = this->GetSingleInputGrad("X");
+    auto dx_ptr = this->GetOutputPtr(&input_grad);
+    std::string dx_name = this->GetOutputName(input_grad);
+    VLOG(6) << "Running abs_grad composite func";
+    prim::abs_grad<prim::DescTensor>(input, out_grad, dx_ptr);
+    this->RecoverOutputName(input_grad, dx_name);
+  }
+};
+
 // AbsGrad: dx = dy if x >= 0 else -dy
 // AbsDoubleGrad: ddy = ddx if x >= 0 else -ddx
 template <typename T>
@@ -150,6 +171,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(abs,
                   ops::AbsOp,
                   ops::AbsOpMaker,
+                  ops::AbsCompositeGradOpMaker,
                   ops::AbsGradMaker<paddle::framework::OpDesc>,
                   ops::AbsGradMaker<paddle::imperative::OpBase>,
                   AbsInferShapeFunctor);
......
@@ -453,6 +453,15 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   }
 }
 
+template <typename T>
+void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto abs_tmp = abs<T>(x);
+    auto divide_tmp = divide<T>(x, abs_tmp);
+    set_output<T>(out_grad * divide_tmp, x_grad);
+  }
+}
+
 template <typename T>
 void matmul_double_grad(const Tensor& x,
                         const Tensor& y,
......
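
The composite rule above lowers abs_grad to primitives: since d|x|/dx = x/|x| = sign(x) for x != 0, the gradient is out_grad * divide(x, abs(x)), which matches the piecewise comment in the operator file (dx = dy if x >= 0 else -dy) everywhere except x == 0, where |x| is not differentiable and the composite form yields 0/0. A minimal NumPy sketch (illustrative only, not Paddle code) checking the equivalence away from zero:

import numpy as np

# Composite rule from this PR: dx = out_grad * x / |x|.
x = np.array([-2.0, -0.5, 0.5, 3.0])
out_grad = np.ones_like(x)

composite = out_grad * (x / np.abs(x))             # out_grad * divide(x, abs(x))
piecewise = np.where(x >= 0, out_grad, -out_grad)  # dx = dy if x >= 0 else -dy

assert np.allclose(composite, piecewise)
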
@@ -17,6 +17,7 @@
     param : [x]
   kernel :
     func : abs_grad
+  composite : abs_grad(x, out_grad, x_grad)
   backward : abs_double_grad
 
 - backward_op : add_double_grad
......
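
The composite entry in backward.yaml binds the abs_grad backward op to the prim::abs_grad template above, so that when the prim backward pass is enabled the gradient is decomposed into abs, divide, and multiply primitives instead of calling the fused abs_grad kernel. A rough usage sketch follows; _set_prim_backward_enabled is an internal toggle of this era, and its exact name and location may differ across Paddle versions:

import paddle
from paddle.fluid import core

paddle.enable_static()
# Assumption: this internal flag routes registered backward ops through
# their composite rules when the backward graph is built.
core._set_prim_backward_enabled(True)
# ... build and run a program containing paddle.abs; its backward now
# lowers to abs / divide / multiply primitives ...
core._set_prim_backward_enabled(False)
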
@@ -1284,6 +1284,9 @@ class TestRsqrt_ZeroDim(TestRsqrt):
 class TestAbs(TestActivation):
     def setUp(self):
         self.op_type = "abs"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.abs
+        self.enable_cinn = False
         self.init_dtype()
         self.init_shape()
@@ -1305,7 +1308,7 @@ class TestAbs(TestActivation):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', check_eager=False)
+        self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
 
 
 class TestAbs_ZeroDim(TestAbs):
@@ -3828,7 +3831,7 @@ create_test_act_fp16_class(TestTanhshrink)
 create_test_act_fp16_class(TestHardShrink)
 create_test_act_fp16_class(TestSoftshrink)
 create_test_act_fp16_class(TestSqrt)
-create_test_act_fp16_class(TestAbs)
+create_test_act_fp16_class(TestAbs, check_prim=True)
 create_test_act_fp16_class(TestCeil, grad_check=False)
 create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False)
 create_test_act_fp16_class(TestCos, grad_atol=0.85)
......
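
With check_prim=True, OpTest reruns the gradient check with the prim decomposition enabled and compares the composite backward against numeric gradients, both in fp32 and in the fp16 variant generated by create_test_act_fp16_class. Independent of the test framework, the expected gradient is easy to sanity-check in eager mode; a small sketch, assuming a Paddle 2.x build (this exercises whichever backward path is active, and either one should yield sign(x)):

import numpy as np
import paddle

# The gradient of |x| is sign(x) for nonzero inputs, which is exactly
# what the composite rule out_grad * x / |x| computes.
x = paddle.to_tensor([-2.0, -0.5, 0.5, 3.0], stop_gradient=False)
paddle.abs(x).sum().backward()

assert np.allclose(x.grad.numpy(), np.sign(x.numpy()))
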