未验证 提交 bbe8f7bd 编写于 作者: L LutaoChu 提交者: GitHub

update cross op parameters for API 2.0

* update cross op parameters
上级 1ab60544
......@@ -79,7 +79,7 @@ class TestCrossAPI(unittest.TestCase):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
y = fluid.layers.data(name='y', shape=[-1, 3])
z = paddle.cross(x, y, dim=1)
z = paddle.cross(x, y, axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'y': self.data_y},
......@@ -103,6 +103,14 @@ class TestCrossAPI(unittest.TestCase):
[-1.0, -1.0, -1.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
# case 3:
with program_guard(Program(), Program()):
x = fluid.data(name="x", shape=[-1, 3], dtype="float32")
y = fluid.data(name='y', shape=[-1, 3], dtype='float32')
y_1 = paddle.cross(x, y, name='result')
self.assertEqual(('result' in y_1.name), True)
def test_dygraph_api(self):
self.input_data()
# case 1:
......@@ -119,7 +127,7 @@ class TestCrossAPI(unittest.TestCase):
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
y = fluid.dygraph.to_variable(self.data_y)
z = paddle.cross(x, y, dim=1)
z = paddle.cross(x, y, axis=1)
np_z = z.numpy()
expect_out = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
......
......@@ -583,66 +583,69 @@ def t(input, name=None):
return out
def cross(input, other, dim=None):
def cross(x, y, axis=None, name=None):
"""
:alias_main: paddle.cross
:alias: paddle.cross,paddle.tensor.cross,paddle.tensor.linalg.cross
Returns the cross product of vectors in dimension `dim` of the `input` and `other` tensor.
Inputs must have the same shape, and the size of their dim-th dimension should be equla to 3.
If `dim` is not given, it defaults to the first dimension found with the size 3.
Computes the cross product between two tensors along an axis.
Inputs must have the same shape, and the length of their axes should be equal to 3.
If `axis` is not given, it defaults to the first axis found with the length 3.
Args:
input (Variable): The first input tensor variable.
other (Variable): The second input tensor variable.
dim (int): The dimension to take the cross-product in.
x (Variable): The first input tensor variable.
y (Variable): The second input tensor variable.
axis (int, optional): The axis along which to compute the cross product. It defaults to the first axis found with the length 3.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: A Tensor with same data type as `input`.
Variable: A Tensor with same data type as `x`.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.imperative import to_variable
import numpy as np
paddle.enable_imperative()
data_x = np.array([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
data_y = np.array([[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data_x)
y = fluid.dygraph.to_variable(data_y)
out_z1 = paddle.cross(x, y)
print(out_z1.numpy())
#[[-1. -1. -1.]
# [ 2. 2. 2.]
# [-1. -1. -1.]]
out_z2 = paddle.cross(x, y, dim=1)
print(out_z2.numpy())
#[[0. 0. 0.]
# [0. 0. 0.]
# [0. 0. 0.]]
x = to_variable(data_x)
y = to_variable(data_y)
z1 = paddle.cross(x, y)
print(z1.numpy())
# [[-1. -1. -1.]
# [ 2. 2. 2.]
# [-1. -1. -1.]]
z2 = paddle.cross(x, y, axis=1)
print(z2.numpy())
# [[0. 0. 0.]
# [0. 0. 0.]
# [0. 0. 0.]]
"""
helper = LayerHelper("cross", **locals())
if in_dygraph_mode():
if dim:
return core.ops.cross(input, other, 'dim', dim)
if axis:
return core.ops.cross(x, y, 'dim', axis)
else:
return core.ops.cross(input, other)
return core.ops.cross(x, y)
out = helper.create_variable_for_type_inference(input.dtype)
helper = LayerHelper("cross", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
attrs = dict()
if dim:
attrs['dim'] = dim
attrs['dim'] = axis
helper.append_op(
type='cross',
inputs={'X': input,
'Y': other},
inputs={'X': x,
'Y': y},
outputs={'Out': out},
attrs=attrs)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册