未验证 提交 0b608393 编写于 作者: S Scotty 提交者: GitHub

【Complex op】add complex support for index_select and index_sample (#56457)

* support index_select op

* index_sample in cpu

* support index_sample in gpu

* change data_transform

* fix api gen and use skip_transform in yaml
上级 7635af09
......@@ -548,6 +548,7 @@ void GradNodeBase::HandleComplexGradToRealGrad(
for (size_t slot_id = 0; slot_id < out_grads->size(); slot_id++) {
const std::vector<paddle::Tensor>& slot_out_grads = (*out_grads)[slot_id];
for (size_t rank_id = 0; rank_id < slot_out_grads.size(); rank_id++) {
if (bwd_out_meta_[slot_id].size() == 0) continue;
const GradSlotMeta& slot_meta = bwd_out_meta_[slot_id][rank_id];
PADDLE_ENFORCE(
......
......@@ -1120,6 +1120,8 @@
func : index_sample_grad
data_type : out_grad
no_need_buffer : x
data_transform :
skip_transform : index
- backward_op : index_select_grad
forward : index_select(Tensor x, Tensor index, int axis) -> Tensor(out)
......@@ -1132,6 +1134,8 @@
func : index_select_grad
data_type : out_grad
no_need_buffer : x
data_transform :
skip_transform : index
- backward_op : index_select_strided_grad
forward : index_select_strided(Tensor x, int64_t index, int axis) -> Tensor(out)
......
......@@ -14,11 +14,23 @@
import collections
import re
from typing import List
PREFIX_TENSOR_NAME = 'input_'
PREFIX_META_TENSOR_NAME = 'meta_'
def parse_plain_list(s: str, sep=",") -> List[str]:
"""Copy from `paddle/fluid/operators/generator/parse_utils.py`"""
if sep == ",":
patten = re.compile(r',(?![^{]*\})') # support "int[] a={1,2}"
items = re.split(patten, s.strip())
items = [x.strip() for x in items]
return items
else:
return [item.strip() for item in s.strip().split(sep)]
class BaseAPI:
def __init__(self, api_item_yaml):
self.api = self.get_api_name(api_item_yaml)
......@@ -367,14 +379,13 @@ class BaseAPI:
data_transform = {'skip_transform': [], 'support_trans_dtype': []}
if 'data_transform' in api_item_yaml:
if 'skip_transform' in api_item_yaml['data_transform']:
data_transform['skip_transform'] = api_item_yaml[
'data_transform'
]['skip_transform']
data_transform['skip_transform'] = parse_plain_list(
api_item_yaml['data_transform']['skip_transform']
)
if 'support_trans_dtype' in api_item_yaml['data_transform']:
data_transform['support_trans_dtype'] = api_item_yaml[
'data_transform'
]['support_trans_dtype']
data_transform['support_trans_dtype'] = parse_plain_list(
api_item_yaml['data_transform']['support_trans_dtype']
)
return data_transform
# Override by child class
......
......@@ -1238,6 +1238,8 @@
func : index_sample
data_type : x
backward : index_sample_grad
data_transform :
skip_transform : index
- op : index_select
args : (Tensor x, Tensor index, int axis = 0)
......@@ -1248,6 +1250,8 @@
func : index_select
data_type : x
backward : index_select_grad
data_transform :
skip_transform : index
- op : index_select_strided
args : (Tensor x, int64_t index, int axis = 0)
......
......@@ -100,4 +100,6 @@ PD_REGISTER_KERNEL(index_sample_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -115,4 +115,6 @@ PD_REGISTER_KERNEL(index_sample,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -60,5 +60,7 @@ PD_REGISTER_KERNEL(index_select_grad,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
......@@ -59,5 +59,7 @@ PD_REGISTER_KERNEL(index_select,
float,
double,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
......@@ -135,4 +135,6 @@ PD_REGISTER_KERNEL(index_sample_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -108,4 +108,6 @@ PD_REGISTER_KERNEL(index_sample,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -132,5 +132,7 @@ PD_REGISTER_KERNEL(index_select_grad,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
......@@ -85,5 +85,7 @@ PD_REGISTER_KERNEL(index_select,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
int,
int64_t) {}
......@@ -319,7 +319,7 @@ def index_select(x, index, axis=0, name=None):
size as the length of ``index``; other dimensions have the same size as in the ``x`` tensor.
Args:
x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float16, float32, float64, int32, int64.
x (Tensor): The input Tensor to be operated. The data of ``x`` can be one of float16, float32, float64, int32, int64, complex64 and complex128.
index (Tensor): The 1-D Tensor containing the indices to index. The data type of ``index`` must be int32 or int64.
axis (int, optional): The dimension in which we index. Default: if None, the ``axis`` is 0.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
......@@ -353,7 +353,16 @@ def index_select(x, index, axis=0, name=None):
check_variable_and_dtype(
x,
'x',
['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'uint16',
'float16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'paddle.tensor.search.index_select',
)
check_variable_and_dtype(
......@@ -771,7 +780,7 @@ def index_sample(x, index):
Args:
x (Tensor): The source input tensor with 2-D shape. Supported data type is
int32, int64, bfloat16, float16, float32, float64.
int32, int64, bfloat16, float16, float32, float64, complex64, complex128.
index (Tensor): The index input tensor with 2-D shape, first dimension should be same with X.
Data type is int32 or int64.
......@@ -826,7 +835,16 @@ def index_sample(x, index):
check_variable_and_dtype(
x,
'x',
['uint16', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'uint16',
'float16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
'paddle.tensor.search.index_sample',
)
check_variable_and_dtype(
......
......@@ -28,6 +28,11 @@ class TestIndexSampleOp(OpTest):
self.python_api = paddle.index_sample
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
if self.x_type == np.complex64 or self.x_type == np.complex128:
xnp = (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.x_type)
indexnp = np.random.randint(
low=0, high=self.x_shape[1], size=self.index_shape
).astype(self.index_type)
......@@ -122,6 +127,28 @@ class TestCase6(TestIndexSampleOp):
self.index_type = "int64"
class TestIndexSampleComplex64(TestIndexSampleOp):
def config(self):
"""
For complex64 x type
"""
self.x_shape = (10, 128)
self.x_type = np.complex64
self.index_shape = (10, 64)
self.index_type = "int64"
class TestIndexSampleComplex128(TestIndexSampleOp):
def config(self):
"""
For complex64 x type
"""
self.x_shape = (10, 128)
self.x_type = np.complex128
self.index_shape = (10, 64)
self.index_type = "int64"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
......
......@@ -36,6 +36,11 @@ class TestIndexSelectOp(OpTest):
low=0, high=self.x_shape[self.dim], size=self.index_size
)
x_np = np.random.random(self.x_shape).astype(self.x_type)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x_np = (
np.random.random(self.x_shape)
+ 1j * np.random.random(self.x_shape)
).astype(self.x_type)
self.inputs = {'X': x_np, 'Index': index_np}
self.attrs = {'dim': self.dim}
outer_loop = np.prod(self.x_shape[: self.dim])
......@@ -60,10 +65,16 @@ class TestIndexSelectOp(OpTest):
self.index_size = 100
def test_check_output(self):
self.check_output(check_prim=True)
if self.x_type == np.complex64 or self.x_type == np.complex128:
self.check_output(check_prim=False)
else:
self.check_output(check_prim=True)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
if self.x_type == np.complex64 or self.x_type == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
else:
self.check_grad(['X'], 'Out', check_prim=True)
class TestIndexSelectOpCase2(TestIndexSelectOp):
......@@ -146,6 +157,24 @@ class TestIndexSelectBF16Op(OpTest):
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
class TestIndexSelectComplex64(TestIndexSelectOp):
def init_dtype_type(self):
self.x_type = np.complex64
self.index_type = np.int32
self.dim = -2
self.x_shape = (10, 10, 4, 10)
self.index_size = 10
class TestIndexSelectComplex128(TestIndexSelectOp):
def init_dtype_type(self):
self.x_type = np.complex128
self.index_type = np.int32
self.dim = -2
self.x_shape = (10, 10, 4, 10)
self.index_size = 10
class TestIndexSelectAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册