Unverified commit 4e09e402, authored by Charles-hit, committed by GitHub

(cherry-pick) Fix split infershape in static mode and add convert rules for fill_any_like op (#46079)

* Fix split bug in static mode (#45906)

* fix split bug in static mode

* modify code style

* modify code style

* add unit test for split

* add convert rules for fill_any_like op in paddle science (#45985)

* add convert rules for fill_any_like op in paddle science

* add unit test for fill_any_like op in paddle science

* modify fill_any_like convert rule

* modify fill_any_like convert rule dtype
Parent: e25e9471
...
@@ -3205,10 +3205,14 @@ void SplitInferMeta(const MetaTensor& x,
   // fill out dims with -1
   if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1 ||
       (axis_value >= 0 && x.dims().at(axis_value) <= 0)) {
-    std::vector<phi::DDim> out_dims(
-        sections_data.size(),
-        phi::make_ddim(std::vector<int>(x.dims().size(), -1)));
+    std::vector<phi::DDim> out_dims;
+    if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1) {
+      out_dims = std::vector<phi::DDim>(
+          sections_data.size(),
+          phi::make_ddim(std::vector<int>(x.dims().size(), -1)));
+    } else {
+      out_dims = std::vector<phi::DDim>(sections_data.size(), x.dims());
+    }
     for (size_t i = 0; i < sections_data.size(); ++i) {
       if (axis_value != 0) {
         // Only pass LoD when not spliting along the first dim.
...
@@ -3293,9 +3297,13 @@ void SplitWithNumInferMeta(const MetaTensor& x,
   int axis_value = GetSplitAxisValue(x, axis, config);
   // fill out dims with -1
   if (axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) {
-    std::vector<phi::DDim> out_dims(
-        num, phi::make_ddim(std::vector<int>(x.dims().size(), -1)));
+    std::vector<phi::DDim> out_dims;
+    if (axis_value == -1) {
+      out_dims = std::vector<phi::DDim>(
+          num, phi::make_ddim(std::vector<int>(x.dims().size(), -1)));
+    } else {
+      out_dims = std::vector<phi::DDim>(num, x.dims());
+    }
     for (int i = 0; i < num; ++i) {
       if (axis_value != 0) {
         // Only pass LoD when not spliting along the first dim.
...
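The point of both hunks is the new else branch: when only the split-axis extent is unknown, the outputs now inherit x.dims() instead of having every dimension reset to -1. A minimal static-mode sketch of the intended effect, assuming a build that includes this patch (the commented shapes describe the expected behavior, not captured output):

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    # The batch dim is unknown (-1) and the split runs along it.
    x = paddle.static.data(name='x', shape=[-1, 10], dtype='float64')
    x0, x1 = paddle.split(x, num_or_sections=[1, 1], axis=0)
    # Before the fix, infershape filled every output dim with -1,
    # yielding (-1, -1); with it, the known trailing dim survives: (-1, 10).
    print(x0.shape, x1.shape)
```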
...
@@ -18,6 +18,7 @@ import paddle
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.layers.utils import flatten
 from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose
+import paddle.fluid.core as core

 paddle.enable_static()
@@ -343,6 +344,46 @@ class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim):
         self.out_map = {0: self.output['Out']}


+class TestFillAnyLikeOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'fill_any_like'
+        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['fill_any_like', 'fill_constant_p']
+        self.out_map = {0: self.output['Out']}
+
+
+class TestFillAnyLikeOrig2Prim2(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'fill_any_like'
+        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {'dtype': paddle.float32, 'value': 5}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['fill_any_like', 'fill_constant_p']
+        self.out_map = {0: self.output['Out']}
+
+
 class TestSumOrig2Prim(TestElementWiseAddOrig2Prim):

     def init_data(self):
...
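At the Python API level, fill_any_like is the op behind paddle.full_like, so a quick way to reproduce what these tests exercise in a static program is the sketch below (an illustration, not part of the patch; the op list contents are an expectation, not captured output):

```python
import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
    # paddle.full_like lowers to the fill_any_like op; passing a dtype
    # exercises the second test case above, omitting it the first.
    y = paddle.full_like(x, fill_value=5, dtype='float32')

# expect 'fill_any_like' to appear among the program's ops
print([op.type for op in main.global_block().ops])
```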
...
@@ -441,6 +441,21 @@ class API_TestSplit5(unittest.TestCase):
         np.testing.assert_allclose(ex_out, re, rtol=1e-05)


+class API_TestSplit6(unittest.TestCase):
+
+    def test_out(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            data = fluid.layers.data('data', shape=[-1, 10], dtype='float64')
+            x0, x1 = paddle.split(data, num_or_sections=[1, 1], axis=0)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            input1 = np.random.random([2, 10]).astype('float64')
+            r0, r1 = exe.run(feed={"data": input1}, fetch_list=[x0, x1])
+            ex_x0, ex_x1 = np.split(input1, (1, ), axis=0)
+            np.testing.assert_allclose(ex_x0, r0, rtol=1e-05)
+            np.testing.assert_allclose(ex_x1, r1, rtol=1e-05)
+
+
 class API_TestDygraphFluidSplit(unittest.TestCase):

     def test_out1(self):
...
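One detail worth noting in API_TestSplit6: np.split takes split indices while paddle.split takes section sizes, so (1,) and [1, 1] describe the same partition. A standalone sketch of that equivalence (illustrative only, not part of the patch):

```python
import numpy as np

x = np.random.random([2, 10])
# np.split splits at index 1; paddle.split's num_or_sections=[1, 1]
# asks for section sizes 1 and 1 along axis 0 -- the same two
# (1, 10) chunks either way.
a, b = np.split(x, (1,), axis=0)
assert a.shape == b.shape == (1, 10)
```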
...
@@ -26,6 +26,8 @@ from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
                       lookup_orig2prim, lookup_prim2orig, lookup_transpose,
                       op_position_inputs, op_position_output)
 from .utils import INT_DTYPE_2_STRING, get_input_var_list, get_output_var_list
+from paddle.fluid.data_feeder import convert_dtype
+from paddle.fluid.framework import convert_np_dtype_to_dtype_


 def _orig2prim(op, *args):
@@ -63,6 +65,7 @@ elementwise_sub
 elementwise_mul
 tanh
 fill_zeros_like
+fill_any_like
 sum
 index_select
 scale
@@ -187,6 +190,16 @@ def fill_zeros_like_orig2prim(op, x):
     return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)


+@REGISTER_ORIG2PRIM('fill_any_like')
+def fill_any_like_orig2prim(op, x):
+    if op.attr('dtype') == -1:
+        return fill_const(value=op.attr('value'), shape=x.shape, dtype=x.dtype)
+    return fill_const(value=op.attr('value'),
+                      shape=x.shape,
+                      dtype=convert_np_dtype_to_dtype_(
+                          convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')])))
+
+
 @REGISTER_ORIG2PRIM('sum')
 def sum_orig2prim(op, xs):
     x0 = xs[0]
...
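The rule's dtype dispatch mirrors the two test cases added above: fill_any_like carries dtype == -1 when the output should inherit X's dtype, and a framework dtype enum otherwise. A rough standalone sketch of that dispatch, with a hypothetical dtype_table standing in for INT_DTYPE_2_STRING and assumed enum values:

```python
# Hypothetical stand-in for the framework's dtype-enum-to-name table
# (INT_DTYPE_2_STRING in the real code); the enum values are assumptions.
dtype_table = {2: 'int32', 3: 'int64', 5: 'float32', 6: 'float64'}

def resolve_fill_dtype(attr_dtype, x_dtype):
    # dtype == -1 on fill_any_like means "inherit the input's dtype".
    if attr_dtype == -1:
        return x_dtype
    return dtype_table[attr_dtype]

assert resolve_fill_dtype(-1, 'int64') == 'int64'   # first test case: attrs = {}
assert resolve_fill_dtype(5, 'int64') == 'float32'  # second: dtype=paddle.float32
```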