未验证 提交 f84ac449 编写于 作者: W wangfengsheng1999 提交者: GitHub

[Cherry-Pick]Support output 0D for...

[Cherry-Pick]Support output 0D for is_empty/as_complex/inner/dot/rank/tensordot/squeeze_/static.accuracy/static.auc/metric.accuracy (#53199)

* support output 0D for is_empty/as_complex/inner/dot/rank/tensordot/squeeze_/static.accuracy/static.auc/metric.accuracy

* test_dot_py

* test_dot_py
上级 7b4badb5
...@@ -1152,8 +1152,9 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { ...@@ -1152,8 +1152,9 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
"with input tensor Y: %s", "with input tensor Y: %s",
x_dims.to_str(), x_dims.to_str(),
y_dims.to_str())); y_dims.to_str()));
std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
x_dims[x_dims.size() - 1] = 1; std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 1);
x_dims = phi::make_ddim(x_dims_vec_cut);
out->set_dims(x_dims); out->set_dims(x_dims);
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
out->set_layout(x.layout()); out->set_layout(x.layout());
......
...@@ -479,7 +479,7 @@ void AucInferMeta(const MetaTensor& input, ...@@ -479,7 +479,7 @@ void AucInferMeta(const MetaTensor& input,
0, 0,
phi::errors::InvalidArgument("slide_steps must be natural number")); phi::errors::InvalidArgument("slide_steps must be natural number"));
auc->set_dims({1}); auc->set_dims(phi::make_ddim({}));
auc->set_dtype(DataType::INT64); auc->set_dtype(DataType::INT64);
if (slide_steps) { if (slide_steps) {
......
...@@ -66,11 +66,11 @@ void AccuracyInferMeta(const MetaTensor& out, ...@@ -66,11 +66,11 @@ void AccuracyInferMeta(const MetaTensor& out,
label_dim[0])); label_dim[0]));
} }
accuracy->set_dims({1}); accuracy->set_dims(phi::make_ddim({}));
correct->set_dims(phi::make_ddim({}));
total->set_dims(phi::make_ddim({}));
accuracy->set_dtype(out.dtype()); accuracy->set_dtype(out.dtype());
correct->set_dims({1});
correct->set_dtype(out.dtype()); correct->set_dtype(out.dtype());
total->set_dims({1});
total->set_dtype(out.dtype()); total->set_dtype(out.dtype());
accuracy->share_lod(out); accuracy->share_lod(out);
} }
......
...@@ -1839,7 +1839,7 @@ void InverseInferMeta(const MetaTensor& x, MetaTensor* out) { ...@@ -1839,7 +1839,7 @@ void InverseInferMeta(const MetaTensor& x, MetaTensor* out) {
} }
void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out) { void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(phi::make_ddim({1})); out->set_dims(phi::make_ddim({}));
out->set_dtype(DataType::BOOL); out->set_dtype(DataType::BOOL);
} }
......
...@@ -32,7 +32,7 @@ void DotKernel(const Context& dev_ctx, ...@@ -32,7 +32,7 @@ void DotKernel(const Context& dev_ctx,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
if (1 == out->dims().size()) { if (out->dims().size() == 0) {
auto eigen_out = phi::EigenScalar<T>::From(*out); auto eigen_out = phi::EigenScalar<T>::From(*out);
auto eigen_x = phi::EigenVector<T>::Flatten(x); auto eigen_x = phi::EigenVector<T>::Flatten(x);
auto eigen_y = phi::EigenVector<T>::Flatten(y); auto eigen_y = phi::EigenVector<T>::Flatten(y);
...@@ -40,7 +40,7 @@ void DotKernel(const Context& dev_ctx, ...@@ -40,7 +40,7 @@ void DotKernel(const Context& dev_ctx,
auto& dev = *dev_ctx.eigen_device(); auto& dev = *dev_ctx.eigen_device();
eigen_out.device(dev) = (eigen_x * eigen_y).sum(); eigen_out.device(dev) = (eigen_x * eigen_y).sum();
} else { } else {
auto eigen_out = phi::EigenMatrix<T>::From(*out); auto eigen_out = phi::EigenVector<T>::From(*out);
auto eigen_x = phi::EigenMatrix<T>::From(x); auto eigen_x = phi::EigenMatrix<T>::From(x);
auto eigen_y = phi::EigenMatrix<T>::From(y); auto eigen_y = phi::EigenMatrix<T>::From(y);
......
...@@ -46,7 +46,7 @@ struct DotGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> { ...@@ -46,7 +46,7 @@ struct DotGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> {
DenseTensor* tensor_dy) { DenseTensor* tensor_dy) {
VLOG(1) << "enable route"; VLOG(1) << "enable route";
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) { if (1 >= tensor_dout->dims().size()) {
auto dout = EigenVector<T>::Flatten(*tensor_dout); auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) { if (tensor_dx) {
...@@ -144,7 +144,7 @@ struct DotGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> { ...@@ -144,7 +144,7 @@ struct DotGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> {
DenseTensor* tensor_dx, DenseTensor* tensor_dx,
DenseTensor* tensor_dy) { DenseTensor* tensor_dy) {
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) { if (1 >= tensor_dout->dims().size()) {
auto dout = EigenVector<T>::Flatten(*tensor_dout); auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) { if (tensor_dx) {
auto y = EigenVector<T>::Flatten(*tensor_y); auto y = EigenVector<T>::Flatten(*tensor_y);
...@@ -236,7 +236,7 @@ struct DotDoubleGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> { ...@@ -236,7 +236,7 @@ struct DotDoubleGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> {
const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr(); const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr();
const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr(); const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) { if (1 >= tensor_dout->dims().size()) {
DenseTensor tensor_dout_help; DenseTensor tensor_dout_help;
auto& dev = *ctx.eigen_device(); auto& dev = *ctx.eigen_device();
if (tensor_dx || tensor_dy) { if (tensor_dx || tensor_dy) {
...@@ -431,7 +431,7 @@ struct DotDoubleGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> { ...@@ -431,7 +431,7 @@ struct DotDoubleGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> {
const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr(); const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr();
const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr(); const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) { if (1 >= tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device(); auto& dev = *ctx.eigen_device();
auto x = EigenVector<T>::Flatten(*tensor_x); auto x = EigenVector<T>::Flatten(*tensor_x);
auto y = EigenVector<T>::Flatten(*tensor_y); auto y = EigenVector<T>::Flatten(*tensor_y);
...@@ -621,7 +621,7 @@ struct DotTripleGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> { ...@@ -621,7 +621,7 @@ struct DotTripleGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> {
const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr(); const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr();
const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr(); const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_dout->dims().size()) { if (1 >= in_tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device(); auto& dev = *ctx.eigen_device();
DenseTensor in_tensor_x_help = Conj<T, DeviceContext>(ctx, *in_tensor_x); DenseTensor in_tensor_x_help = Conj<T, DeviceContext>(ctx, *in_tensor_x);
DenseTensor in_tensor_y_help = Conj<T, DeviceContext>(ctx, *in_tensor_y); DenseTensor in_tensor_y_help = Conj<T, DeviceContext>(ctx, *in_tensor_y);
...@@ -1015,7 +1015,7 @@ struct DotTripleGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> { ...@@ -1015,7 +1015,7 @@ struct DotTripleGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> {
const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr(); const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr();
const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr(); const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__) #if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_dout->dims().size()) { if (1 >= in_tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device(); auto& dev = *ctx.eigen_device();
bool d_dout_flag = false; bool d_dout_flag = false;
bool d_ddx_flag = false; bool d_ddx_flag = false;
......
...@@ -97,7 +97,7 @@ def check(use_cuda): ...@@ -97,7 +97,7 @@ def check(use_cuda):
step += 1 step += 1
print( print(
'iter={:.0f},cost={},acc1={}'.format( 'iter={:.0f},cost={},acc1={}'.format(
step, outs[1][0], outs[2][0] step, outs[1][0], outs[2]
) )
) )
......
...@@ -106,8 +106,7 @@ class DotOpEmptyInput(unittest.TestCase): ...@@ -106,8 +106,7 @@ class DotOpEmptyInput(unittest.TestCase):
x = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32') x = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32')
y = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32') y = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32')
pd_out = paddle.dot(x, y) pd_out = paddle.dot(x, y)
self.assertEqual(pd_out.shape, (0,))
self.assertEqual(pd_out.shape, (0, 1))
def test_3d_input_error(self): def test_3d_input_error(self):
data = np.array([], dtype=np.float32) data = np.array([], dtype=np.float32)
...@@ -127,7 +126,7 @@ class DotOpBatch(DotOp): ...@@ -127,7 +126,7 @@ class DotOpBatch(DotOp):
self.y = ( self.y = (
np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12]) np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
) )
self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1]) self.out = np.sum(self.x * self.y, axis=1)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out')
...@@ -180,7 +179,7 @@ class TestDygraph(unittest.TestCase): ...@@ -180,7 +179,7 @@ class TestDygraph(unittest.TestCase):
np.array([[2, 5], [6, 8]]).astype(np.float32) np.array([[2, 5], [6, 8]]).astype(np.float32)
) )
np.testing.assert_array_equal( np.testing.assert_array_equal(
paddle.dot(x1, y1).numpy(), np.array([[17], [58]]) paddle.dot(x1, y1).numpy(), np.array([17, 58])
) )
...@@ -211,7 +210,7 @@ class TestComplexDotOp(OpTest): ...@@ -211,7 +210,7 @@ class TestComplexDotOp(OpTest):
self.out = np.dot(self.x, self.y) self.out = np.dot(self.x, self.y)
def init_grad_input_output(self): def init_grad_input_output(self):
self.grad_out = np.ones(1, self.dtype) + 1j * np.ones(1, self.dtype) self.grad_out = np.ones([], self.dtype) + 1j * np.ones([], self.dtype)
self.grad_x = self.grad_out * np.conj(self.y) self.grad_x = self.grad_out * np.conj(self.y)
self.grad_y = self.grad_out * np.conj(self.x) self.grad_y = self.grad_out * np.conj(self.x)
...@@ -269,12 +268,10 @@ class TestComplexDotOp2D(OpTest): ...@@ -269,12 +268,10 @@ class TestComplexDotOp2D(OpTest):
self.y = np.random.random((2, 100)).astype( self.y = np.random.random((2, 100)).astype(
self.dtype self.dtype
) + 1j * np.random.random((2, 100)).astype(self.dtype) ) + 1j * np.random.random((2, 100)).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1) self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)
def init_grad_input_output(self): def init_grad_input_output(self):
self.grad_out = np.ones((2, 1), self.dtype) + 1j * np.ones( self.grad_out = np.ones((2), self.dtype) + 1j * np.ones((2), self.dtype)
(2, 1), self.dtype
)
self.grad_x = self._get_grad(self.grad_out, self.y) self.grad_x = self._get_grad(self.grad_out, self.y)
self.grad_y = self._get_grad(self.grad_out, self.x) self.grad_y = self._get_grad(self.grad_out, self.x)
...@@ -381,7 +378,7 @@ class DotFP16OpBatch(TestDotFP16Op): ...@@ -381,7 +378,7 @@ class DotFP16OpBatch(TestDotFP16Op):
self.y = ( self.y = (
np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12]) np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
) )
self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1]) self.out = np.sum(self.x * self.y, axis=1)
@unittest.skipIf( @unittest.skipIf(
...@@ -468,7 +465,7 @@ class DotBF16OpBatch(TestDotBF16Op): ...@@ -468,7 +465,7 @@ class DotBF16OpBatch(TestDotBF16Op):
self.y = ( self.y = (
np.random.uniform(1, 3, [132]).astype(np.float32).reshape([11, 12]) np.random.uniform(1, 3, [132]).astype(np.float32).reshape([11, 12])
) )
self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1]) self.out = np.sum(self.x * self.y, axis=1)
def test_check_grad_normal(self): def test_check_grad_normal(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
......
...@@ -1426,8 +1426,9 @@ class TestLayer(LayerTest): ...@@ -1426,8 +1426,9 @@ class TestLayer(LayerTest):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
# x = np.random.rand(3, 32, 32).astype("float32") # x = np.random.rand(3, 32, 32).astype("float32")
# y = np.array([[1], [0], [1]]) # y = np.array([[1], [0], [1]])
static_out = exe.run( static_out = exe.run(
feed={"input": x, "label": y}, fetch_list=result[0] feed={"input": x, "label": y}, fetch_list=result
) )
with self.dynamic_graph(force_to_use_cpu=True): with self.dynamic_graph(force_to_use_cpu=True):
......
...@@ -50,11 +50,15 @@ class TestNanInf(unittest.TestCase): ...@@ -50,11 +50,15 @@ class TestNanInf(unittest.TestCase):
assert (out + err).find(b'There are NAN or INF') != -1 assert (out + err).find(b'There are NAN or INF') != -1
def test_nan_inf_in_static_mode(self): def test_nan_inf_in_static_mode(self):
self._python_interp += " check_nan_inf_base.py" self._python_interp += (
" " + os.path.dirname(__file__) + "/check_nan_inf_base.py"
)
self.check_nan_inf() self.check_nan_inf()
def test_nan_inf_in_dynamic_mode(self): def test_nan_inf_in_dynamic_mode(self):
self._python_interp += " check_nan_inf_base_dygraph.py" self._python_interp += (
" " + os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py"
)
self.check_nan_inf() self.check_nan_inf()
......
...@@ -991,6 +991,158 @@ class TestSundryAPI(unittest.TestCase): ...@@ -991,6 +991,158 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x1.grad.shape, [5]) self.assertEqual(x1.grad.shape, [5])
def test_is_empty(self):
# 1) x is 0D
x = paddle.rand([])
out = paddle.is_empty(x)
self.assertFalse(out)
self.assertEqual(out.shape, [])
# 2) x is 1D
x = paddle.rand([5])
out = paddle.is_empty(x)
self.assertFalse(out)
self.assertEqual(out.shape, [])
# 3) x is ND
x = paddle.rand([3, 5])
out = paddle.is_empty(x)
self.assertFalse(out)
self.assertEqual(out.shape, [])
x = paddle.rand([3, 0, 5])
out = paddle.is_empty(x)
self.assertTrue(out)
self.assertEqual(out.shape, [])
def test_squeeze_(self):
# 1) x is 0D
x = paddle.rand([])
x.squeeze_(0)
self.assertEqual(x.shape, [])
# 2) x is 1D
x = paddle.rand([1])
x.squeeze_(0)
self.assertEqual(x.shape, [])
# 3)x is ND
x = paddle.rand([2, 1])
x.squeeze_(1)
self.assertEqual(x.shape, [2])
def test_as_complex(self):
x = paddle.rand([2])
x.stop_gradient = False
out = paddle.as_complex(x)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [2])
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.grad.shape, [])
def test_dot(self):
# 1) x is 1D
x = paddle.rand([2])
x.stop_gradient = False
y = paddle.rand([2])
y.stop_gradient = False
out = paddle.dot(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
# 2) x is 2D
x1 = paddle.rand([2, 2])
x1.stop_gradient = False
y1 = paddle.rand([2, 2])
y1.stop_gradient = False
out1 = paddle.dot(x1, y1)
out1.retain_grads()
out1.backward()
self.assertEqual(x1.grad.shape, [2, 2])
self.assertEqual(out1.shape, [2])
self.assertEqual(out1.grad.shape, [2])
def test_inner(self):
# 0) input is 0D
x = paddle.rand([])
x.stop_gradient = False
y = paddle.rand([])
y.stop_gradient = False
out = paddle.inner(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
# 1) input is 1D
x = paddle.rand([2])
x.stop_gradient = False
y = paddle.rand([2])
y.stop_gradient = False
out = paddle.inner(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
# 2) input is 2D
x = paddle.rand([2, 3])
x.stop_gradient = False
y = paddle.rand([3, 3])
y.stop_gradient = False
out = paddle.inner(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.shape, [2, 3])
self.assertEqual(out.grad.shape, [2, 3])
def test_tensordot(self):
# 1) input is 1D
x = paddle.arange(10, dtype='float64')
x.stop_gradient = False
y = paddle.arange(10, dtype='float64')
y.stop_gradient = False
out = paddle.tensordot(x, y, axes=1)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [10])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
# 2) input is 2D
x = paddle.arange(6, dtype='float64').reshape([2, 3])
y = paddle.arange(6, dtype='float64').reshape([2, 3])
x.stop_gradient = False
out = paddle.tensordot(x, y, axes=2)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
def test_metric_accuracy(self):
x = paddle.full(shape=[2, 4], fill_value=0.25)
y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
out = paddle.metric.accuracy(input=x, label=y, k=1)
self.assertEqual(out.shape, [])
def test_std(self): def test_std(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
...@@ -1098,10 +1250,6 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1098,10 +1250,6 @@ class TestSundryAPI(unittest.TestCase):
def test_is_tensor(self): def test_is_tensor(self):
self.assertTrue(paddle.is_tensor(self.x)) self.assertTrue(paddle.is_tensor(self.x))
def test_is_empty(self):
x = paddle.rand([3, 0, 5])
self.assertTrue(paddle.is_empty(x))
def test_isfinite(self): def test_isfinite(self):
out = paddle.isfinite(self.x) out = paddle.isfinite(self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True)) np.testing.assert_array_equal(out.numpy(), np.array(True))
...@@ -1160,7 +1308,8 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1160,7 +1308,8 @@ class TestSundryAPI(unittest.TestCase):
def test_rank(self): def test_rank(self):
# 1) x is 0D # 1) x is 0D
out = paddle.rank(self.x) x = paddle.rand([])
out = paddle.rank(x)
self.assertEqual(out.shape, []) self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(0)) np.testing.assert_array_equal(out.numpy(), np.array(0))
...@@ -2456,6 +2605,230 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2456,6 +2605,230 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (5,)) self.assertEqual(res[2].shape, (5,))
@prog_scope()
def test_is_empty(self):
# 1) x is 0D
x1 = paddle.rand([])
out1 = paddle.is_empty(x1)
# 2) x is 1D
x2 = paddle.rand([5])
out2 = paddle.is_empty(x2)
# 3) x is ND
x3 = paddle.rand([3, 5])
out3 = paddle.is_empty(x3)
x4 = paddle.rand([3, 0, 5])
out4 = paddle.is_empty(x4)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[out1, out2, out3, out4],
)
self.assertEqual(res[0].shape, ())
self.assertFalse(bool(res[0]))
self.assertEqual(res[1].shape, ())
self.assertFalse(bool(res[1]))
self.assertEqual(res[2].shape, ())
self.assertFalse(bool(res[2]))
self.assertEqual(res[3].shape, ())
self.assertTrue(bool(res[3]))
@prog_scope()
def test_as_complex(self):
x = paddle.rand([2])
x.stop_gradient = False
out = paddle.as_complex(x)
self.assertEqual(x.shape, (2,))
self.assertEqual(out.shape, ())
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[x, out, x.grad_name, out.grad_name],
)
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2,))
self.assertEqual(res[3].shape, ())
@prog_scope()
def test_dot(self):
# 1) x is 1d
x = paddle.rand([2])
x.stop_gradient = False
y = paddle.rand([2])
y.stop_gradient = False
out = paddle.dot(x, y)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[x, x.grad_name, out, out.grad_name],
)
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (2,))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
# 2) x is 2D
x1 = paddle.rand([2, 2])
x1.stop_gradient = False
y1 = paddle.rand([2, 2])
y1.stop_gradient = False
out1 = paddle.dot(x1, y1)
paddle.static.append_backward(out1.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[x1, x1.grad_name, out1, out1.grad_name],
)
self.assertEqual(res[0].shape, (2, 2))
self.assertEqual(res[1].shape, (2, 2))
self.assertEqual(res[2].shape, (2,))
self.assertEqual(res[3].shape, (2,))
@prog_scope()
def test_inner(self):
# 1) input is 1D
x1 = paddle.rand([2])
x1.stop_gradient = False
y1 = paddle.rand([2])
y1.stop_gradient = False
out1 = paddle.inner(x1, y1)
paddle.static.append_backward(out1.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
x1,
x1.grad_name,
out1,
out1.grad_name,
],
)
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, (2,))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
# 2) input is 2D
x = paddle.rand([2, 3])
x.stop_gradient = False
y = paddle.rand([2, 3])
y.stop_gradient = False
out = paddle.inner(x, y)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[
x,
x.grad_name,
out,
out.grad_name,
],
)
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (2, 3))
self.assertEqual(res[2].shape, (2, 2))
self.assertEqual(res[3].shape, (2, 2))
@prog_scope()
def test_tensordot(self):
x = paddle.full(shape=[10], fill_value=0.25, dtype='float64')
x.stop_gradient = False
y = paddle.full(shape=[10], fill_value=0.25, dtype='float64')
y.stop_gradient = False
out = paddle.tensordot(x, y, axes=1)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[x, x.grad_name, out, out.grad_name],
)
self.assertEqual(res[0].shape, (10,))
self.assertEqual(res[1].shape, (10,))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
x = paddle.arange(6, dtype='float64').reshape([2, 3])
y = paddle.arange(6, dtype='float64').reshape([2, 3])
x.stop_gradient = False
out = paddle.tensordot(x, y, axes=2)
paddle.static.append_backward(out.sum())
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[x, x.grad_name, out, out.grad_name],
)
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (2, 3))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
@prog_scope()
def test_metric_accuracy(self):
x = paddle.full(shape=[2, 4], fill_value=0.25)
y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
out = paddle.metric.accuracy(input=x, label=y, k=1)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[out],
)
self.assertEqual(res[0].shape, ())
@prog_scope()
def test_static_accuracy(self):
x = paddle.full(shape=[2, 4], fill_value=0.25)
y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
out = paddle.static.accuracy(input=x, label=y, k=1)
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[out],
)
self.assertEqual(res[0].shape, ())
@prog_scope()
def test_static_auc(self):
x = paddle.full(shape=[3, 2], fill_value=0.25)
y = paddle.full(shape=[3], fill_value=1, dtype="int64")
out = paddle.static.auc(input=x, label=y)[0]
prog = paddle.static.default_main_program()
res = self.exe.run(
prog,
fetch_list=[out],
)
self.assertEqual(res[0].shape, ())
@prog_scope() @prog_scope()
def test_std(self): def test_std(self):
x = paddle.rand([]) x = paddle.rand([])
......
...@@ -4197,9 +4197,6 @@ def tensordot(x, y, axes=2, name=None): ...@@ -4197,9 +4197,6 @@ def tensordot(x, y, axes=2, name=None):
shape_out.append(shape_y[i]) shape_out.append(shape_y[i])
not_contraction_size_y *= shape_y[i] not_contraction_size_y *= shape_y[i]
if not shape_out:
shape_out = [1]
x = x.transpose(perm=perm_x).reshape( x = x.transpose(perm=perm_x).reshape(
[not_contraction_size_x, contraction_size] [not_contraction_size_x, contraction_size]
) )
......
...@@ -2070,8 +2070,7 @@ def inner(x, y, name=None): ...@@ -2070,8 +2070,7 @@ def inner(x, y, name=None):
xshape = x.shape xshape = x.shape
yshape = y.shape yshape = y.shape
dstshape = list(xshape[:-1]) + list(yshape[:-1]) dstshape = list(xshape[:-1]) + list(yshape[:-1])
if len(dstshape) == 0:
dstshape = [1]
nx = x.reshape((-1, xshape[-1])) nx = x.reshape((-1, xshape[-1]))
ny = y.reshape((-1, yshape[-1])) ny = y.reshape((-1, yshape[-1]))
......
...@@ -281,7 +281,7 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -281,7 +281,7 @@ TEST(CustomKernel, custom_kernel_dot) {
kernel(&kernel_context); kernel(&kernel_context);
// 8.check result // 8.check result
ASSERT_EQ(dense_out->dims().size(), 2); ASSERT_EQ(dense_out->dims().size(), 1);
ASSERT_EQ(dense_out->dims()[0], 2); ASSERT_EQ(dense_out->dims()[0], 2);
ASSERT_EQ(dense_out->numel(), 2); ASSERT_EQ(dense_out->numel(), 2);
ASSERT_EQ(dense_out->dtype(), phi::DataType::UINT8); ASSERT_EQ(dense_out->dtype(), phi::DataType::UINT8);
......
...@@ -263,7 +263,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -263,7 +263,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
fetch_list=fetch_targets, fetch_list=fetch_targets,
) )
batch_time = (time.time() - start) * 1000 # in miliseconds batch_time = (time.time() - start) * 1000 # in miliseconds
batch_acc1, batch_acc5 = out[1][0], out[2][0] batch_acc1, batch_acc5 = out[1], out[2]
outputs.append(batch_acc1) outputs.append(batch_acc1)
else: else:
# Quant INT8 models do not have accuracy measuring layers # Quant INT8 models do not have accuracy measuring layers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册