Unverified commit 8220771b, authored by xysheng-baidu, committed by GitHub

Add flatten composite rule (#50672)

* Add flatten composite rule

* Get the right xshape and pass the functional test

* Add CINN unit test

* Remove CINN test; re-add it after CINN support is repaired

* Add comp test to test_flatten_contiguous_range_op.py

* Remove func test on composite_ops

* Add comments to maybe_wrap_dim func

* Remove commented-out code

* Fix the 0-D tensor case

* Add flatten split rule comment

* Fix syntax issues

* Block flatten on resnet_prim_cinn

* Remove maybe_wrap_dim func

* Use None instead of xshape

Parent commit: c77eb1fd
@@ -159,6 +159,7 @@ class TestResnet(unittest.TestCase):
         not paddle.is_compiled_with_cinn(), "paddle is not compiled with CINN"
     )
     def test_prim_cinn(self):
+        core._set_prim_forward_blacklist("flatten_contiguous_range")
         dy2st_prim_cinn = train(
             to_static=True, enable_prim=True, enable_cinn=True
         )
......
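The new blacklist line keeps flatten_contiguous_range out of composite (prim) forward lowering in this test; per the commit notes, the CINN test for flatten returns only after CINN support is repaired. A minimal sketch of the pattern, assuming core is paddle.fluid.core as imported elsewhere in the test file:

from paddle.fluid import core

# Ops on the prim forward blacklist are not decomposed into primitive
# ops; they keep their original kernels during dy2st + prim runs.
core._set_prim_forward_blacklist("flatten_contiguous_range")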
@@ -25,21 +25,25 @@ class TestFlattenOp(OpTest):
         self.python_api = paddle.flatten
         self.python_out_sig = ["Out"]
         self.op_type = "flatten_contiguous_range"
+        self.prim_op_type = "comp"
         self.start_axis = 0
         self.stop_axis = -1
         self.init_test_case()
         self.inputs = {"X": np.random.random(self.in_shape).astype("float64")}
         self.init_attrs()
+        self.enable_cinn = False
         self.outputs = {
             "Out": self.inputs["X"].reshape(self.new_shape),
             "XShape": np.random.random(self.in_shape).astype("float32"),
         }
 
     def test_check_output(self):
-        self.check_output(no_check_set=["XShape"], check_eager=True)
+        self.check_output(
+            no_check_set=["XShape"], check_eager=True, check_prim=True
+        )
 
     def test_check_grad(self):
-        self.check_grad(["X"], "Out", check_eager=True)
+        self.check_grad(["X"], "Out", check_eager=True, check_prim=True)
 
     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
......
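Setting prim_op_type = "comp" and passing check_prim=True make OpTest run the registered composite rule alongside the original kernel and compare outputs and gradients. Further axis ranges can reuse the same harness by overriding init_test_case; a hypothetical variant (the class name, shapes, and init_attrs body below are illustrative assumptions, not part of this patch):

class TestFlattenOpMiddleAxes(TestFlattenOp):
    # Flatten only axes 1..2, so (3, 2, 5, 4) -> (3, 10, 4).
    def init_test_case(self):
        self.in_shape = (3, 2, 5, 4)
        self.start_axis = 1
        self.stop_axis = 2
        self.new_shape = (3, 10, 4)

    def init_attrs(self):
        self.attrs = {
            'start_axis': self.start_axis,
            'stop_axis': self.stop_axis,
        }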
@@ -191,6 +191,32 @@ def mean_composite(x, axis, keepdim):
     return divide(sum_x, norm)
 
 
+@REGISTER_COMPOSITE('flatten_contiguous_range')
+def flatten_contiguous_range_composite(x, start_axis, stop_axis):
+    """
+    Define the composite rule of op flatten: flatten_contiguous_range -> reshape.
+    CINN does not need xshape for the backward pass, so None is returned instead.
+    shape_out, the shape argument of reshape, is derived from start_axis and stop_axis.
+    out = reshape(x, shape=shape_out), None
+    """
+    shape_in = x.shape
+    start_dim = start_axis if len(shape_in) != 0 else 0
+    end_dim = stop_axis if len(shape_in) != 0 else 0
+    assert start_dim <= end_dim
+    if len(shape_in) == 0 or start_dim == end_dim:
+        return reshape(x, shape=shape_in), None
+    slice_numel = 1
+    for i in range(start_dim, end_dim + 1):
+        slice_numel *= shape_in[i]
+    shape_out = []
+    for i in range(start_dim):
+        shape_out.append(shape_in[i])
+    shape_out.append(slice_numel)
+    for i in range(end_dim + 1, len(shape_in)):
+        shape_out.append(shape_in[i])
+    return reshape(x, shape=shape_out), None
+
+
 @REGISTER_COMPOSITE('dropout')
 def dropout_composite(x, seed_tensor, p, is_test, mode, seed, fix_seed):
     """define composite rule of op dropout.
......
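The rule's shape arithmetic is easy to check in isolation: the axes in [start_axis, stop_axis] collapse into a single dimension equal to their product. A pure-Python sketch of the same computation (flatten_shape is our name, not a Paddle API):

def flatten_shape(shape_in, start_axis, stop_axis):
    # Mirror of the composite rule's shape computation: collapse the
    # dimensions in [start_axis, stop_axis] into a single product.
    if len(shape_in) == 0 or start_axis == stop_axis:
        return list(shape_in)
    slice_numel = 1
    for i in range(start_axis, stop_axis + 1):
        slice_numel *= shape_in[i]
    return list(shape_in[:start_axis]) + [slice_numel] + list(shape_in[stop_axis + 1:])

assert flatten_shape((3, 2, 5, 4), 1, 2) == [3, 10, 4]
assert flatten_shape((3, 2, 5, 4), 0, 3) == [120]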