Unverified commit 95c100c1, authored by TeslaZhao, committed by GitHub

op:transpose_op supports bool type (#35886) (#35926)

* Add pass compatibility checks for conv_transpose_bias_mkldnn_fuse_pass

* Fix a bug in the strided_slice op where the axes parameter could access memory out of bounds

* Fix a bug in the transpose op where the perm parameter could access memory out of bounds

* op:transpose_op supports bool type
Parent e8e77ebe
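As a quick sanity check of what this commit enables, here is a minimal dynamic-graph sketch, assuming a Paddle build that includes this commit (tensor names are illustrative):

```python
import numpy as np
import paddle

# A bool tensor now dispatches to the newly registered bool transpose kernel.
x_np = np.random.random((3, 4)) > 0.5        # bool ndarray
x = paddle.to_tensor(x_np)                   # bool Tensor
y = paddle.transpose(x, perm=[1, 0])

assert y.dtype == paddle.bool
assert (y.numpy() == x_np.transpose(1, 0)).all()
```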
...
@@ -44,6 +44,7 @@ template struct SetConstant<platform::CUDADeviceContext,
                             platform::complex<double>>;
 
 #define DEFINE_GPU_TRANS(RANK)                                                \
+  template struct Transpose<platform::CUDADeviceContext, bool, RANK>;         \
   template struct Transpose<platform::CUDADeviceContext, float, RANK>;        \
   template struct Transpose<platform::CUDADeviceContext, double, RANK>;       \
   template struct Transpose<platform::CUDADeviceContext, float16, RANK>;      \
...
...
@@ -350,7 +350,8 @@ REGISTER_OPERATOR(
 REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
-    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext,
                          paddle::platform::complex<float>>,
@@ -358,6 +359,7 @@ REGISTER_OP_CPU_KERNEL(
                          paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
@@ -373,7 +375,8 @@ REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad,
                   ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);
 
 REGISTER_OP_CPU_KERNEL(
-    transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
@@ -383,6 +386,7 @@ REGISTER_OP_CPU_KERNEL(
                              paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
...
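Note that both `transpose` and `transpose2` receive the new bool kernels: `transpose2` is the variant the Python API actually emits (it carries an extra `XShape` output for the grad op). A static-graph sketch to see this, assuming a build with this commit:

```python
import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name='x', shape=[4, 2], dtype='bool')
    y = paddle.transpose(x, perm=[1, 0])

# The API appends a transpose2 op, not transpose.
print([op.type for op in prog.global_block().ops])  # expect ['transpose2']
```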
...
@@ -83,6 +83,7 @@ namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     transpose,
+    ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
@@ -92,6 +93,7 @@ REGISTER_OP_CUDA_KERNEL(
                                 paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
+    ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
@@ -103,6 +105,7 @@ REGISTER_OP_CUDA_KERNEL(
 
 REGISTER_OP_CUDA_KERNEL(
     transpose2,
+    ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
@@ -114,6 +117,7 @@ REGISTER_OP_CUDA_KERNEL(
                                 paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
+    ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
...
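With both the CPU and CUDA registrations above in place, the same bool transpose runs on either device. A minimal sketch, assuming a build with this commit (the GPU branch additionally assumes a CUDA build):

```python
import paddle

paddle.set_device('cpu')
x = paddle.ones([2, 3], dtype='bool')
print(paddle.transpose(x, perm=[1, 0]).shape)  # [3, 2], via the CPU bool kernel

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    x = paddle.ones([2, 3], dtype='bool')
    print(paddle.transpose(x, perm=[1, 0]).shape)  # [3, 2], via the CUDA bool kernel
```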
...
@@ -5499,12 +5499,12 @@ def transpose(x, perm, name=None):
         perm[i]-th dimension of `input`.
 
     Args:
-        x (Tensor): The input Tensor. It is a N-D Tensor of data types float32, float64, int32.
+        x (Tensor): The input Tensor. It is a N-D Tensor of data types bool, float32, float64, int32.
         perm (list|tuple): Permute the input according to the data of perm.
         name (str): The name of this layer. It is optional.
 
     Returns:
-        Tensor: A transposed n-D Tensor, with data type being float32, float64, int32, int64.
+        Tensor: A transposed n-D Tensor, with data type being bool, float32, float64, int32, int64.
 
     For Example:
@@ -5546,7 +5546,7 @@ def transpose(x, perm, name=None):
         return out
 
     check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
         'transpose')
     check_type(perm, 'perm', (list, tuple), 'transpose')
     if isinstance(perm, tuple):
...
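Because 'bool' is now in the check_variable_and_dtype allow-list, a bool input also passes the static-graph type check instead of raising TypeError. A minimal sketch, assuming a build with this commit:

```python
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[10, 5, 3], dtype='bool')
    # Before this commit this raised TypeError; with 'bool' allowed it builds fine.
    out = paddle.transpose(x, perm=[1, 0, 2])
```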
...
@@ -113,6 +113,99 @@ class TestCase9(TestTransposeOp):
         self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
 
 
+class TestTransposeOpBool(TestTransposeOp):
+    def test_check_grad(self):
+        pass
+
+
+class TestTransposeOpBool1D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (100, )
+        self.axis = (0, )
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool2D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (3, 40)
+        self.axis = (1, 0)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool3D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (3, 4, 10)
+        self.axis = (0, 2, 1)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool4D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5)
+        self.axis = (0, 2, 3, 1)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool5D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6)
+        self.axis = (4, 2, 3, 1, 0)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool6D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (2, 3, 4, 5, 6, 1)
+        self.axis = (4, 2, 3, 1, 0, 5)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool7D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (2, 3, 2, 3, 2, 4, 3)
+        self.axis = (0, 1, 3, 2, 4, 5, 6)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
+class TestTransposeOpBool8D(TestTransposeOpBool):
+    def initTestCase(self):
+        self.shape = (2, 3, 2, 3, 2, 4, 3, 3)
+        self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
+        self.inputs = {'X': np.random.random(self.shape).astype("bool")}
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("bool"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+
 class TestTransposeOpError(unittest.TestCase):
     def test_errors(self):
         paddle.enable_static()
@@ -126,9 +219,9 @@ class TestTransposeOpError(unittest.TestCase):
         self.assertRaises(TypeError, test_x_Variable_check)
 
         def test_x_dtype_check():
-            # the Input(x)'s dtype must be one of [float16, float32, float64, int32, int64]
+            # the Input(x)'s dtype must be one of [bool, float16, float32, float64, int32, int64]
             x1 = fluid.layers.data(
-                name='x1', shape=[10, 5, 3], dtype='bool')
+                name='x1', shape=[10, 5, 3], dtype='int8')
             fluid.layers.transpose(x1, perm=[1, 0, 2])
 
         self.assertRaises(TypeError, test_x_dtype_check)
...
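What each TestTransposeOpBool*D case asserts reduces to a numpy equivalence; gradient checking is skipped because bool has no meaningful gradient, hence the no-op test_check_grad. A standalone numpy-only sketch of the 4-D case (shape and perm copied from the test above):

```python
import numpy as np

shape, axis = (2, 3, 4, 5), (0, 2, 3, 1)
x = np.random.random(shape).astype("bool")
out = x.transpose(axis)

assert out.dtype == np.bool_
assert out.shape == tuple(shape[a] for a in axis)
# Transposing again with the inverse permutation restores the original array.
inv = np.argsort(axis)
assert (out.transpose(inv) == x).all()
```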