Unverified Commit 348a36b5 authored by Kang Zhao and committed by GitHub

feat: add composite rule of roll grad (#52532)

* feat: add relu composite rule

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, add python api of relu

* feat: add relu composite rule, commit hook

* fix: maximum type error & ban cinn test

* fix: maximum input sequence bugs

* resolve conflicts

* fix: code style bugs

* add: relu fp16 test

* feat: add rsqrt composite rule

* feat: add rsqrt composite rule

* resolve conflicts of composite rule

* fix: delete check eager

* feat: add roll grad composite rule

* fix minus shift

* fix test roll op
Parent 49bbd466
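The rule itself is a one-liner once you notice that `roll` only permutes elements: its Jacobian is a permutation matrix, so the vector-Jacobian product is the inverse permutation, i.e. a roll of the upstream gradient by the negated shifts. In terms of the patch below (notation ours, not from the commit), for `y = roll(x, shifts, axis)`:

$$\frac{\partial L}{\partial x} = \operatorname{roll}\!\left(\frac{\partial L}{\partial y},\ -\text{shifts},\ \text{axis}\right)$$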
@@ -1529,5 +1529,22 @@ void gelu_grad(const Tensor& x,
}
}
}
template <typename T>
void roll_grad(const Tensor& x,
               const Tensor& out_grad,
               const IntArray& shifts,
               const std::vector<int64_t>& axis,
               Tensor* x_grad) {
  if (x_grad) {
    // roll only permutes elements, so its VJP is the inverse permutation:
    // roll the upstream gradient back by the negated shifts.
    auto shifts_ = shifts.GetData();
    int64_t nums = shifts_.size();
    for (int64_t i = 0; i < nums; i++) {
      shifts_[i] = -shifts_[i];
    }
    auto x_grad_output = roll<T>(out_grad, shifts_, axis);
    set_output<T>(x_grad_output, x_grad);
  }
}
} // namespace prim
} // namespace paddle
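A quick way to sanity-check the rule above outside of Paddle: `roll` is linear, so its adjoint (the VJP) must satisfy `<roll(x), g> == <x, roll(g, -shifts)>` for any input `x` and upstream gradient `g`. A minimal NumPy sketch of that adjoint check (the shapes and shift values are arbitrary choices, not from the patch):

```python
import numpy as np

x = np.random.randn(3, 4)   # forward input
g = np.random.randn(3, 4)   # upstream gradient dL/dy
shifts, axis = (1, -2), (0, 1)

# <roll(x), g> must equal <x, roll(g, -shifts)> if the VJP is correct.
lhs = np.sum(np.roll(x, shifts, axis) * g)
rhs = np.sum(x * np.roll(g, tuple(-s for s in shifts), axis))
assert np.isclose(lhs, rhs)
```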
@@ -1286,6 +1286,7 @@
  kernel :
    func : roll_grad
    data_type : x
  composite : roll_grad(x, out_grad, shifts, axis, x_grad)
  no_need_buffer : x
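The `composite` entry wires the `roll_grad` backward op to the C++ rule above, so prim mode can decompose the backward pass into a plain forward `roll`. A hypothetical Python mirror of what that decomposition amounts to (the function name is ours and this is not Paddle's actual prim machinery; only `paddle.roll` is real API):

```python
import paddle

def roll_grad_composite(out_grad, shifts, axis):
    # Hypothetical mirror of the C++ composite rule:
    # the gradient of roll is roll by the negated shifts.
    neg_shifts = [-s for s in shifts]
    return paddle.roll(out_grad, neg_shifts, axis)
```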
- backward_op : round_grad
@@ -1201,7 +1201,8 @@ set(TEST_CINN_OPS
    test_layer_norm_op
    test_cast_op
    test_dropout_op
    test_group_norm_op)
    test_group_norm_op
    test_roll_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
  if(WITH_CINN)
@@ -26,6 +26,8 @@ class TestRollOp(OpTest):
    def setUp(self):
        self.python_api = paddle.roll
        self.op_type = "roll"
        self.public_python_api = paddle.roll
        self.prim_op_type = "prim"
        self.init_dtype_type()
        self.attrs = {'shifts': self.shifts, 'axis': self.axis}
        bf16_ut = self.dtype == np.uint16
@@ -46,10 +48,13 @@
        self.axis = [0, -2]

    def test_check_output(self):
        self.check_output()
        self.check_output(check_prim=True)

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')
        self.check_grad(['X'], 'Out', check_prim=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_prim=True)
class TestRollOpCase2(TestRollOp):
@@ -106,10 +111,10 @@ class TestRollBF16OP(TestRollOp):
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        self.check_output_with_place(self.place)
        self.check_output_with_place(self.place, check_prim=True)

    def test_check_grad_normal(self):
        self.check_grad_with_place(self.place, ['X'], 'Out')
        self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
@unittest.skipIf(
@@ -126,10 +131,10 @@ class TestRollBF16OpCase2(TestRollOp):
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        self.check_output_with_place(self.place)
        self.check_output_with_place(self.place, check_prim=True)

    def test_check_grad_normal(self):
        self.check_grad_with_place(self.place, ['X'], 'Out')
        self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
@unittest.skipIf(
@@ -146,10 +151,10 @@ class TestRollBF16OpCase3(TestRollOp):
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        self.check_output_with_place(self.place)
        self.check_output_with_place(self.place, check_prim=True)

    def test_check_grad_normal(self):
        self.check_grad_with_place(self.place, ['X'], 'Out')
        self.check_grad_with_place(self.place, ['X'], 'Out', check_prim=True)
class TestRollAPI(unittest.TestCase):
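Beyond the OpTest `check_prim` assertions, the rule can be sanity-checked end to end in eager mode; a minimal sketch (shapes and shift values are arbitrary, and `w` is treated as a constant weight — the same identity is what `check_prim` asserts through the composite path):

```python
import paddle

x = paddle.arange(12, dtype='float32').reshape([3, 4])
x.stop_gradient = False
w = paddle.randn([3, 4])

y = paddle.roll(x, shifts=[1, -2], axis=[0, 1])
loss = (y * w).sum()

# dL/dx should be w rolled back by the negated shifts.
(g,) = paddle.grad(loss, [x])
expected = paddle.roll(w, shifts=[-1, 2], axis=[0, 1])
assert paddle.allclose(g, expected).item()
```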