未验证 提交 aaa7cbd5 编写于 作者: Y yaoxuefeng 提交者: GitHub

modify trace api test=develop (#25397)

上级 f9ac5fb9
...@@ -30,8 +30,8 @@ class TraceOp : public framework::OperatorWithKernel { ...@@ -30,8 +30,8 @@ class TraceOp : public framework::OperatorWithKernel {
ctx->HasOutput("Out"), true, ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output of TraceOp is not found.")); platform::errors::NotFound("Output of TraceOp is not found."));
int dim1 = ctx->Attrs().Get<int>("dim1"); int dim1 = ctx->Attrs().Get<int>("axis1");
int dim2 = ctx->Attrs().Get<int>("dim2"); int dim2 = ctx->Attrs().Get<int>("axis2");
auto x_dims = ctx->GetInputDim("Input"); auto x_dims = ctx->GetInputDim("Input");
...@@ -84,15 +84,15 @@ class TraceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -84,15 +84,15 @@ class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
)DOC") )DOC")
.SetDefault(0); .SetDefault(0);
AddAttr<int>( AddAttr<int>(
"dim1", "axis1",
R"DOC((int, default 0), the first dim of the 2-D planes from which the diagonals should be taken. R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
Can be both positive and negative. Default: 0. Can be either positive or negative. Default: 0.
)DOC") )DOC")
.SetDefault(-2); .SetDefault(-2);
AddAttr<int>( AddAttr<int>(
"dim2", "axis2",
R"DOC((int, default 1), the second dim of the 2-D planes from which the diagonals should be taken. R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
Can be both positive and negative. Default: 1. Can be either positive or negative. Default: 1.
)DOC") )DOC")
.SetDefault(-1); .SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -33,8 +33,8 @@ class TraceCUDAKernel : public framework::OpKernel<T> { ...@@ -33,8 +33,8 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
const int64_t offset = context.Attr<int>("offset"); const int64_t offset = context.Attr<int>("offset");
const int64_t dim1 = context.Attr<int>("dim1"); const int64_t dim1 = context.Attr<int>("axis1");
const int64_t dim2 = context.Attr<int>("dim2"); const int64_t dim2 = context.Attr<int>("axis2");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
const framework::Tensor diag = const framework::Tensor diag =
......
...@@ -174,8 +174,8 @@ class TraceKernel : public framework::OpKernel<T> { ...@@ -174,8 +174,8 @@ class TraceKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
const int64_t offset = context.Attr<int>("offset"); const int64_t offset = context.Attr<int>("offset");
const int64_t dim1 = context.Attr<int>("dim1"); const int64_t dim1 = context.Attr<int>("axis1");
const int64_t dim2 = context.Attr<int>("dim2"); const int64_t dim2 = context.Attr<int>("axis2");
auto output_dims = out->dims(); auto output_dims = out->dims();
...@@ -205,8 +205,8 @@ class TraceGradKernel : public framework::OpKernel<T> { ...@@ -205,8 +205,8 @@ class TraceGradKernel : public framework::OpKernel<T> {
context.Output<framework::Tensor>(framework::GradVarName("Input")); context.Output<framework::Tensor>(framework::GradVarName("Input"));
int64_t offset = context.Attr<int>("offset"); int64_t offset = context.Attr<int>("offset");
int64_t dim1 = context.Attr<int>("dim1"); int64_t dim1 = context.Attr<int>("axis1");
int64_t dim2 = context.Attr<int>("dim2"); int64_t dim2 = context.Attr<int>("axis2");
auto input_dims = d_x->dims(); auto input_dims = d_x->dims();
auto input_stride = framework::stride(input_dims); auto input_stride = framework::stride(input_dims);
......
...@@ -33,7 +33,7 @@ class TestComplexTraceLayer(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestComplexTraceLayer(unittest.TestCase):
for place in self._places: for place in self._places:
with dg.guard(place): with dg.guard(place):
var_x = dg.to_variable(input) var_x = dg.to_variable(input)
result = cpx.trace(var_x, offset=1, dim1=0, dim2=2).numpy() result = cpx.trace(var_x, offset=1, axis1=0, axis2=2).numpy()
target = np.trace(input, offset=1, axis1=0, axis2=2) target = np.trace(input, offset=1, axis1=0, axis2=2)
self.assertTrue(np.allclose(result, target)) self.assertTrue(np.allclose(result, target))
......
...@@ -38,7 +38,7 @@ class TestTraceOp(OpTest): ...@@ -38,7 +38,7 @@ class TestTraceOp(OpTest):
def init_config(self): def init_config(self):
self.case = np.random.randn(20, 6).astype('float64') self.case = np.random.randn(20, 6).astype('float64')
self.inputs = {'Input': self.case} self.inputs = {'Input': self.case}
self.attrs = {'offset': 0, 'dim1': 0, 'dim2': 1} self.attrs = {'offset': 0, 'axis1': 0, 'axis2': 1}
self.target = np.trace(self.inputs['Input']) self.target = np.trace(self.inputs['Input'])
...@@ -46,24 +46,24 @@ class TestTraceOpCase1(TestTraceOp): ...@@ -46,24 +46,24 @@ class TestTraceOpCase1(TestTraceOp):
def init_config(self): def init_config(self):
self.case = np.random.randn(2, 20, 2, 3).astype('float32') self.case = np.random.randn(2, 20, 2, 3).astype('float32')
self.inputs = {'Input': self.case} self.inputs = {'Input': self.case}
self.attrs = {'offset': 1, 'dim1': 0, 'dim2': 2} self.attrs = {'offset': 1, 'axis1': 0, 'axis2': 2}
self.target = np.trace( self.target = np.trace(
self.inputs['Input'], self.inputs['Input'],
offset=self.attrs['offset'], offset=self.attrs['offset'],
axis1=self.attrs['dim1'], axis1=self.attrs['axis1'],
axis2=self.attrs['dim2']) axis2=self.attrs['axis2'])
class TestTraceOpCase2(TestTraceOp): class TestTraceOpCase2(TestTraceOp):
def init_config(self): def init_config(self):
self.case = np.random.randn(2, 20, 2, 3).astype('float32') self.case = np.random.randn(2, 20, 2, 3).astype('float32')
self.inputs = {'Input': self.case} self.inputs = {'Input': self.case}
self.attrs = {'offset': -5, 'dim1': 1, 'dim2': -1} self.attrs = {'offset': -5, 'axis1': 1, 'axis2': -1}
self.target = np.trace( self.target = np.trace(
self.inputs['Input'], self.inputs['Input'],
offset=self.attrs['offset'], offset=self.attrs['offset'],
axis1=self.attrs['dim1'], axis1=self.attrs['axis1'],
axis2=self.attrs['dim2']) axis2=self.attrs['axis2'])
class TestTraceAPICase(unittest.TestCase): class TestTraceAPICase(unittest.TestCase):
...@@ -71,7 +71,7 @@ class TestTraceAPICase(unittest.TestCase): ...@@ -71,7 +71,7 @@ class TestTraceAPICase(unittest.TestCase):
case = np.random.randn(2, 20, 2, 3).astype('float32') case = np.random.randn(2, 20, 2, 3).astype('float32')
data1 = fluid.data(name='data1', shape=[2, 20, 2, 3], dtype='float32') data1 = fluid.data(name='data1', shape=[2, 20, 2, 3], dtype='float32')
out1 = tensor.trace(data1) out1 = tensor.trace(data1)
out2 = tensor.trace(data1, offset=-5, dim1=1, dim2=-1) out2 = tensor.trace(data1, offset=-5, axis1=1, axis2=-1)
place = core.CPUPlace() place = core.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
......
...@@ -236,39 +236,38 @@ def elementwise_div(x, y, axis=-1, name=None): ...@@ -236,39 +236,38 @@ def elementwise_div(x, y, axis=-1, name=None):
name=name) name=name)
def trace(input, offset=0, dim1=0, dim2=1, name=None): def trace(x, offset=0, axis1=0, axis2=1, name=None):
""" """
The layer to compute the trace for a complex number tensor. input :attr:`input` must be a ComplexVariable. The layer to compute the trace for a complex number tensor. x :attr:`x` must be a ComplexVariable.
See the detailed description for the function and other arguments See the detailed description for the function and other arguments
in :ref:`api_tensor_math_trace` . in :ref:`api_tensor_math_trace` .
Args: Args:
input(ComplexVariable): The input ComplexVariable. Must be at least 2-dimensional. x(ComplexVariable): The input ComplexVariable x. Must be at least 2-dimensional.
The supported data types include complex64 and complex128. The supported data types include complex64 and complex128.
offset(int, optional): Which diagonals in input tensor will be taken. Default: 0 (main diagonals). offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals).
dim1(int, optional): The first dimension with respect to take diagonal. Default: 0. axis1(int, optional): The first axis with respect to take diagonal. Default: 0.
dim2(int, optional): The second dimension with respect to take diagonal. Default: 1. axis2(int, optional): The second axis with respect to take diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns: Returns:
ComplexVariable: The trace result of input tensor, it's data type is the same as input data type. ComplexVariable: The trace result of input tensor x, it's data type is the same as input data type.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.fluid.dygraph as dg
import numpy as np import numpy as np
case1 = np.random.randn(3, 10, 10).astype('float64') + 1j * np.random.randn(3, 10, 10).astype('float64') case1 = np.random.randn(3, 10, 10).astype('float64') + 1j * np.random.randn(3, 10, 10).astype('float64')
with dg.guard(): paddle.enable_imperative()
case1 = dg.to_variable(case1) case1 = paddle.imperative.to_variable(case1)
data1 = paddle.complex.trace(case1, offset=1, dim1=1, dim2=2) # data1.shape = [3] data1 = paddle.complex.trace(case1, offset=1, axis1=1, axis2=2) # data1.shape = [3]
""" """
complex_variable_exists([input], "trace") complex_variable_exists([x], "trace")
real = math.trace(input.real, offset, dim1, dim2, name) real = math.trace(x.real, offset, axis1, axis2, name)
imag = math.trace(input.imag, offset, dim1, dim2, name) imag = math.trace(x.imag, offset, axis1, axis2, name)
return ComplexVariable(real, imag) return ComplexVariable(real, imag)
......
...@@ -1572,30 +1572,30 @@ def clamp(input, min=None, max=None, output=None, name=None): ...@@ -1572,30 +1572,30 @@ def clamp(input, min=None, max=None, output=None, name=None):
return output return output
def trace(input, offset=0, dim1=0, dim2=1, out=None, name=None): def trace(x, offset=0, axis1=0, axis2=1, name=None):
""" """
:alias_main: paddle.trace :alias_main: paddle.trace
:alias: paddle.trace,paddle.tensor.trace,paddle.tensor.math.trace :alias: paddle.trace,paddle.tensor.trace,paddle.tensor.math.trace
This OP computes the sum along diagonals of the input tensor. This OP computes the sum along diagonals of the input tensor x.
If ``input`` is 2D, returns the sum of diagonal. If ``x`` is 2D, returns the sum of diagonal.
If ``input`` has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from If ``x`` has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from
the 2D planes specified by dim1 and dim2. By default, the 2D planes formed by the first and second dimensions the 2D planes specified by axis1 and axis2. By default, the 2D planes formed by the first and second axes
of the input tensor. of the input tensor x.
The argument ``offset`` determines where diagonals are taken from input tensor: The argument ``offset`` determines where diagonals are taken from input tensor x:
- If offset = 0, it is the main diagonal. - If offset = 0, it is the main diagonal.
- If offset > 0, it is above the main diagonal. - If offset > 0, it is above the main diagonal.
- If offset < 0, it is below the main diagonal. - If offset < 0, it is below the main diagonal.
Args: Args:
input(Variable): The input tensor. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64. x(Variable): The input tensor x. Must be at least 2-dimensional. The input data type should be float32, float64, int32, int64.
offset(int, optional): Which diagonals in input tensor will be taken. Default: 0 (main diagonals). offset(int, optional): Which diagonals in input tensor x will be taken. Default: 0 (main diagonals).
dim1(int, optional): The first dimension with respect to take diagonal. Default: 0. axis1(int, optional): The first axis with respect to take diagonal. Default: 0.
dim2(int, optional): The second dimension with respect to take diagonal. Default: 1. axis2(int, optional): The second axis with respect to take diagonal. Default: 1.
name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.
Returns: Returns:
...@@ -1605,66 +1605,63 @@ def trace(input, offset=0, dim1=0, dim2=1, out=None, name=None): ...@@ -1605,66 +1605,63 @@ def trace(input, offset=0, dim1=0, dim2=1, out=None, name=None):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.fluid.dygraph as dg
import numpy as np import numpy as np
case1 = np.random.randn(2, 3).astype('float32') case1 = np.random.randn(2, 3).astype('float32')
case2 = np.random.randn(3, 10, 10).astype('float32') case2 = np.random.randn(3, 10, 10).astype('float32')
case3 = np.random.randn(3, 10, 5, 10).astype('float32') case3 = np.random.randn(3, 10, 5, 10).astype('float32')
with dg.guard(): paddle.enable_imperative()
case1 = dg.to_variable(case1)
case2 = dg.to_variable(case2) case1 = paddle.imperative.to_variable(case1)
case3 = dg.to_variable(case3) case2 = paddle.imperative.to_variable(case2)
case3 = paddle.imperative.to_variable(case3)
data1 = paddle.trace(case1) # data1.shape = [1] data1 = paddle.trace(case1) # data1.shape = [1]
data2 = paddle.trace(case2, offset=1, dim1=1, dim2=2) # data2.shape = [3] data2 = paddle.trace(case2, offset=1, axis1=1, axis2=2) # data2.shape = [3]
data3 = paddle.trace(case3, offset=-3, dim1=1, dim2=-1) # data2.shape = [3, 5] data3 = paddle.trace(case3, offset=-3, axis1=1, axis2=-1) # data2.shape = [3, 5]
""" """
inputs = {'Input': [input]} inputs = {'Input': [x]}
attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} attrs = {'offset': offset, 'axis1': axis1, 'axis2': axis2}
def __check_input(input, offset, dim1, dim2): def __check_input(input, offset, dim1, dim2):
check_dtype(input.dtype, 'Input', check_dtype(x.dtype, 'Input',
['int32', 'int64', 'float16', 'float32', 'float64'], ['int32', 'int64', 'float16', 'float32', 'float64'],
'trace') 'trace')
input_shape = list(input.shape) input_shape = list(x.shape)
assert len(input_shape) >= 2, \ assert len(input_shape) >= 2, \
"The input must be at least 2-dimensional, " \ "The x must be at least 2-dimensional, " \
"But received Input's dimensional: %s.\n" % \ "But received Input x's dimensional: %s.\n" % \
len(input_shape) len(input_shape)
dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 axis1_ = axis1 if axis1 >= 0 else len(input_shape) + axis1
dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 axis2_ = axis2 if axis2 >= 0 else len(input_shape) + axis2
assert dim1_ < len(input_shape), \ assert axis1_ < len(input_shape), \
"The argument dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \ "The argument axis1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, dim1) % (-(len(input_shape)), len(input_shape) - 1, axis1)
assert dim2_ < len(input_shape), \ assert axis2_ < len(input_shape), \
"The argument dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \ "The argument axis2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \
% (-(len(input_shape)), len(input_shape) - 1, dim2) % (-(len(input_shape)), len(input_shape) - 1, axis2)
assert dim1_ != dim2_, \ assert axis1_ != axis2_, \
"dim1 and dim2 cannot be the same dimension." \ "axis1 and axis2 cannot be the same axis." \
"But received dim1 = %d, dim2 = %d\n"%(dim1, dim2) "But received axis1 = %d, axis2 = %d\n"%(axis1, axis2)
if not in_dygraph_mode(): if not in_dygraph_mode():
__check_input(input, offset, dim1, dim2) __check_input(input, offset, axis1, axis2)
helper = LayerHelper('trace', **locals()) helper = LayerHelper('trace', **locals())
if out is None: out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=input.dtype)
else:
check_variable_and_dtype(out, 'out', ['float16', 'float32', 'float64', 'int32', 'int64'], 'trace')
helper.append_op( helper.append_op(
type='trace', type='trace',
inputs={'Input': [input]}, inputs={'Input': [x]},
attrs={'offset': offset, attrs={'offset': offset,
'dim1': dim1, 'axis1': axis1,
'dim2': dim2}, 'axis2': axis2},
outputs={'Out': [out]}) outputs={'Out': [out]})
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册