未验证 提交 d4d3d7ed 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support input 0D Tensor for xpu kernel, test=kunlun (#47849)

上级 8a339d24
......@@ -169,39 +169,37 @@ struct XPULogGradFunctor : public funcs::BaseActivationFunctor<T> {
const DenseTensor* dOut,
DenseTensor* dX) const {
const T* x_data = nullptr;
const T* y_grad = nullptr;
const T* dout_data = nullptr;
if (x != nullptr) x_data = x->data<T>();
if (dOut != nullptr) y_grad = dOut->data<T>();
T* x_grad = dX->data<T>();
const auto x_dims = x->dims();
auto xshape = vectorize<int>(x_dims);
int len = x->dims()[x_dims.size() - 1];
std::vector<int> yshape(1, len);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
T* y_data = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(y_data);
T* tmp_grad = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(tmp_grad);
int r =
xpu::constant<T>(dev_ctx.x_context(), y_data, len, static_cast<T>(1.0));
if (dOut != nullptr) dout_data = dOut->data<T>();
T* dx_data = dev_ctx.template Alloc<T>(dX);
int r = xpu::constant<T>(
dev_ctx.x_context(), dx_data, x->numel(), static_cast<T>(1.0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
auto x_dims = vectorize<int>(x->dims());
// use [1] to replace [], because xpu not support []
if (x_dims.size() == 0) {
x_dims = std::vector<int>({1});
}
// dx.device(d) = dout * (static_cast<T>(1) / x);
r = xpu::broadcast_div(dev_ctx.x_context(),
reinterpret_cast<const float*>(y_data),
reinterpret_cast<const float*>(dx_data),
reinterpret_cast<const float*>(x_data),
reinterpret_cast<float*>(tmp_grad),
yshape,
xshape);
reinterpret_cast<float*>(dx_data),
x_dims,
x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_div");
r = xpu::broadcast_mul(dev_ctx.x_context(),
reinterpret_cast<const float*>(y_grad),
reinterpret_cast<const float*>(tmp_grad),
reinterpret_cast<float*>(x_grad),
xshape,
xshape);
reinterpret_cast<const float*>(dx_data),
reinterpret_cast<const float*>(dout_data),
reinterpret_cast<float*>(dx_data),
x_dims,
x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
}
};
......
......@@ -213,9 +213,14 @@ void PowKernel(const Context& dev_ctx,
static_cast<void*>(&pow_factor),
sizeof(T));
// broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
// std::vector<int>& xshape, const std::vector<int>& yshape);
auto x_dims = vectorize<int>(x.dims());
// use [1] to replace [], because xpu not support []
if (x_dims.size() == 0) {
x_dims = std::vector<int>({1});
}
// broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
// std::vector<int>& xshape, const std::vector<int>& yshape);
int r =
xpu::broadcast_pow(xpu_context, x_data, factor_data, y_data, x_dims, {1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow");
......
......@@ -84,6 +84,17 @@ void XPUElementwise(const XPUContext& dev_ctx,
int ret = xpu::SUCCESS;
// For [2, 3] + [] --> [2, 3] + [1, 1]
// For [] + [2, 3] --> [1, 1] + [2, 3]
// For [] + [], Use [1] + [1] to replace [], because xpu not support []
if (x_dims_vec.size() == 0) {
x_dims_vec = std::vector<int>({1});
}
if (y_dims_vec.size() == 0) {
y_dims_vec = std::vector<int>({1});
}
ret = func(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
......@@ -165,6 +176,15 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx,
dy_data = dev_ctx.template Alloc<T>(dy);
}
// use [1] to replace [], because xpu not support []
if (x_dims_vec.size() == 0) {
x_dims_vec = std::vector<int>({1});
}
if (y_dims_vec.size() == 0) {
y_dims_vec = std::vector<int>({1});
}
int ret = func(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
......
......@@ -75,6 +75,14 @@ void ReduceMaxGradKernel(const Context& dev_ctx,
XPU_SUCCESS,
errors::ResourceExhausted("XPU has no enough memory"));
// use [1] to replace [], because xpu not support []
if (xdims.size() == 0) {
xdims = std::vector<int>({1});
}
if (ydims.size() == 0) {
ydims = std::vector<int>({1});
}
// step 1. brocast out and out_grad
int r =
xpu::broadcast<T>(dev_ctx.x_context(), out_data, brocast1, ydims, xdims);
......
......@@ -38,14 +38,8 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
auto reduce_dims = dims_arr.GetData();
std::vector<int> xdims;
for (int i = 0; i < x.dims().size(); i++) {
xdims.push_back(x.dims()[i]);
}
std::vector<int> ydims;
for (int i = 0; i < out_grad.dims().size(); i++) {
ydims.push_back(out_grad.dims()[i]);
}
std::vector<int> xdims = vectorize<int>(x.dims());
std::vector<int> ydims = vectorize<int>(out_grad.dims());
int reduce_numel = 1;
if (reduce_all) {
......@@ -74,6 +68,14 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
dev_ctx.x_context(), x_data, x.numel(), static_cast<XPUType>(val));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
// use [1] to replace [], because xpu not support []
if (xdims.size() == 0) {
xdims = std::vector<int>({1});
}
if (ydims.size() == 0) {
ydims = std::vector<int>({1});
}
r = xpu::broadcast_mul(
dev_ctx.x_context(), x_data, dy_data, x_data, xdims, ydims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
......
......@@ -57,6 +57,14 @@ void ReduceSumGradKernel(const Context& dev_ctx,
}
}
// use [1] to replace [], because xpu not support []
if (xdims.size() == 0) {
xdims = std::vector<int>({1});
}
if (ydims.size() == 0) {
ydims = std::vector<int>({1});
}
int r = xpu::broadcast<XPUType>(
dev_ctx.x_context(), out_data, x_grad_data, ydims, xdims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
......
......@@ -31,15 +31,18 @@ void WhereKernel(const Context& ctx,
T* out_data = ctx.template Alloc<T>(out);
auto cond_dims = phi::vectorize<int>(condition.dims());
auto input_dims = phi::vectorize<int>(x.dims());
int ret = xpu::select(ctx.x_context(),
cond_data,
x_data,
y_data,
out_data,
cond_dims,
input_dims);
auto x_dims = phi::vectorize<int>(x.dims());
// use [1] to replace [], because xpu not support []
if (cond_dims.size() == 0) {
cond_dims = std::vector<int>({1});
}
if (x_dims.size() == 0) {
x_dims = std::vector<int>({1});
}
int ret = xpu::select(
ctx.x_context(), cond_data, x_data, y_data, out_data, cond_dims, x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "select");
}
......
......@@ -75,6 +75,10 @@ class XPUTestExpOP(XPUOpTestWrapper):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
class XPUTestExp_ZeroDIm(TestActivationOPBase):
def set_shape(self):
self.shape = []
support_types = get_xpu_op_support_types('exp')
for stype in support_types:
......@@ -100,6 +104,10 @@ class XPUTestSigmoidOP(XPUOpTestWrapper):
def init_config(self):
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
class XPUTestSigmoid_ZeroDIm(XPUTestSigmoid):
def init_config(self):
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
class XPUTestSigmoid2(XPUTestSigmoid):
def init_config(self):
self.x = np.random.uniform(-2, 2, [100]).astype(self.dtype)
......@@ -310,6 +318,10 @@ class XPUTestLogOP(XPUOpTestWrapper):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
class TestLogCase_ZeroDim(XPUTestLog):
def set_shape(self):
self.shape = []
class TestLogCase1(XPUTestLog):
def set_shape(self):
self.shape = [1, 11, 17]
......@@ -351,6 +363,10 @@ class XPUTestSquareOP(XPUOpTestWrapper):
def init_config(self):
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
class XPUTestSquare_ZeroDim(XPUTestSquare):
def init_config(self):
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
class XPUTestSquare2(XPUTestSquare):
def init_config(self):
self.x = np.random.uniform(-2, 2, [100]).astype(self.dtype)
......@@ -517,6 +533,10 @@ class XPUTestSoftPlusOP(XPUOpTestWrapper):
def init_config(self):
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
class XPUTestSoftPlus_ZeroDim(XPUTestSoftPlusBase):
def init_config(self):
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
class XPUTestSoftPlus2(XPUTestSoftPlusBase):
def init_config(self):
self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype)
......@@ -976,6 +996,10 @@ class XPUTestSwishOP(XPUOpTestWrapper):
def init_config(self):
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
class XPUTestSwish_ZeroDim(XPUTestSwishBase):
def init_config(self):
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
class XPUTestSwish2(XPUTestSwishBase):
def init_config(self):
self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype)
......@@ -1057,6 +1081,10 @@ class XPUTestMishOP(XPUOpTestWrapper):
def init_config(self):
self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
class XPUTestMish_ZeroDim(XPUTestMishBase):
def init_config(self):
self.x = np.random.uniform(-2, 2, []).astype(self.dtype)
class XPUTestMish2(XPUTestMishBase):
def init_config(self):
self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype)
......
......@@ -101,6 +101,24 @@ class XPUTestElementwiseAddOp(XPUOpTestWrapper):
def init_max_relative_error(self):
self.max_relative_error = 0.006
class TestElementwiseAddOp_ZeroDim1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(-1, 1, []).astype(self.dtype)
self.y = np.random.uniform(-1, 1, []).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_ZeroDim2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(-1, 1, []).astype(self.dtype)
self.y = np.random.uniform(-1, 1, [13, 17]).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_ZeroDim3(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(-1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(-1, 1, []).astype(self.dtype)
self.out = self.x + self.y
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast."
)
......
......@@ -93,6 +93,22 @@ class XPUTestElementwiseDivOp(XPUOpTestWrapper):
def init_dtype(self):
pass
class TestElementwiseDivOp_ZeroDim1(ElementwiseDivOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, []).astype(self.dtype),
'Y': np.random.uniform(-1, 1, []).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']}
class TestElementwiseDivOp_ZeroDim2(ElementwiseDivOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(-1, 1, []).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']}
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast."
)
......
......@@ -103,6 +103,30 @@ class XPUTestElementwiseMulOp(XPUOpTestWrapper):
def init_axis(self):
pass
class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, []).astype(self.dtype),
'Y': np.random.uniform(-1, 1, []).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOp_ZeroDim2(ElementwiseMulOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(-1, 1, []).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOp_ZeroDim3(ElementwiseMulOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, []).astype(self.dtype),
'Y': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast."
)
......
......@@ -80,6 +80,30 @@ class XPUTestElementwiseSubOp(XPUOpTestWrapper):
no_grad_set=set('Y'),
)
class TestElementwiseSubOp_ZeroDim1(TestElementwiseOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, []).astype(self.dtype),
'Y': np.random.uniform(-1, 1, []).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_ZeroDim2(TestElementwiseOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(-1, 1, []).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_ZeroDim3(TestElementwiseOp):
def init_input_output(self):
self.inputs = {
'X': np.random.uniform(-1, 1, []).astype(self.dtype),
'Y': np.random.uniform(-1, 1, [13, 17]).astype(self.dtype),
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast."
)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
import numpy as np
import unittest
paddle.set_device('xpu')
unary_api_list = [
paddle.nn.functional.elu,
paddle.nn.functional.gelu,
paddle.nn.functional.hardsigmoid,
paddle.nn.functional.hardswish,
paddle.nn.functional.leaky_relu,
paddle.nn.functional.log_sigmoid,
paddle.nn.functional.relu,
paddle.nn.functional.relu6,
paddle.nn.functional.sigmoid,
paddle.nn.functional.softplus,
paddle.nn.functional.softshrink,
paddle.nn.functional.softsign,
paddle.nn.functional.swish,
paddle.nn.functional.tanhshrink,
paddle.nn.functional.thresholded_relu,
paddle.stanh,
paddle.nn.functional.celu,
paddle.nn.functional.mish,
paddle.nn.functional.silu,
paddle.nn.functional.tanh,
paddle.cosh,
paddle.sinh,
paddle.abs,
paddle.acos,
paddle.asin,
paddle.atan,
paddle.ceil,
paddle.cos,
paddle.exp,
paddle.floor,
paddle.log,
paddle.log1p,
paddle.reciprocal,
paddle.round,
paddle.sin,
paddle.sqrt,
paddle.square,
paddle.tanh,
paddle.acosh,
paddle.asinh,
paddle.atanh,
paddle.expm1,
paddle.log10,
paddle.log2,
paddle.tan,
]
# Use to test zero-dim in unary API.
class TestUnaryAPI(unittest.TestCase):
def test(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for api in unary_api_list:
x = paddle.rand([])
x.stop_gradient = False
out = api(x)
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
paddle.enable_static()
reduce_api_list = [
paddle.sum,
paddle.mean,
paddle.nansum,
paddle.nanmean,
paddle.min,
paddle.max,
paddle.amin,
paddle.amax,
paddle.prod,
paddle.logsumexp,
paddle.all,
paddle.any,
]
# Use to test zero-dim of reduce API
class TestReduceAPI(unittest.TestCase):
def test(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for api in reduce_api_list:
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, []).astype('bool')
out = api(x, None)
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
else:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
paddle.enable_static()
binary_api_list = [
{'func': paddle.add, 'cls_method': '__add__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.pow, 'cls_method': '__pow__'},
]
binary_api_list_without_grad = [
{'func': paddle.equal, 'cls_method': '__eq__'},
{'func': paddle.not_equal, 'cls_method': '__ne__'},
{'func': paddle.greater_equal, 'cls_method': '__ge__'},
{'func': paddle.greater_than, 'cls_method': '__gt__'},
{'func': paddle.less_equal, 'cls_method': '__le__'},
{'func': paddle.less_than, 'cls_method': '__lt__'},
{'func': paddle.remainder, 'cls_method': '__mod__'},
paddle.mod,
paddle.floor_mod,
paddle.logical_and,
paddle.logical_or,
paddle.logical_xor,
]
binary_int_api_list_without_grad = [
paddle.bitwise_and,
paddle.bitwise_or,
paddle.bitwise_xor,
]
# Use to test zero-dim of binary API
class TestBinaryAPI(unittest.TestCase):
def test(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for api in binary_api_list + binary_api_list_without_grad:
# 1) x/y is 0D
x = paddle.rand([])
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
self.assertEqual(out.shape, [])
if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(y.grad.shape, [])
self.assertEqual(out.grad.shape, [])
# 2) x is not 0D , y is 0D
x = paddle.rand([2, 3, 4])
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
self.assertEqual(out.shape, [2, 3, 4])
if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [2, 3, 4])
self.assertEqual(y.grad.shape, [])
self.assertEqual(out.grad.shape, [2, 3, 4])
# 3) x is 0D , y is not 0D
x = paddle.rand([])
y = paddle.rand([2, 3, 4])
x.stop_gradient = False
y.stop_gradient = False
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
self.assertEqual(out.shape, [2, 3, 4])
if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(y.grad.shape, [2, 3, 4])
self.assertEqual(out.grad.shape, [2, 3, 4])
# 4) x is 0D , y is scalar
x = paddle.rand([])
y = 0.5
x.stop_gradient = False
if isinstance(api, dict):
out = getattr(paddle.Tensor, api['cls_method'])(x, y)
self.assertEqual(out.shape, [])
for api in binary_int_api_list_without_grad:
# 1) x/y is 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, [])
# 2) x is not 0D , y is 0D
x = paddle.randint(-10, 10, [3, 5])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, [3, 5])
# 3) x is 0D , y is not 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [3, 5])
out = api(x, y)
self.assertEqual(out.shape, [3, 5])
paddle.enable_static()
# Use to test zero-dim of Sundry API, which is simple and do
# not have backward, or is not need to test backward in OpTest.
class TestSundryAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.rand([])
def test_linear(self):
x = paddle.randn([3, 2])
w = paddle.full(shape=[2, 4], fill_value=0.5)
b = paddle.zeros([])
np.testing.assert_array_equal(
F.linear(x, w, b).numpy(), F.linear(x, w).numpy()
)
def test_is_floating_point(self):
self.assertTrue(paddle.is_floating_point(self.x))
def test_is_integer(self):
x = paddle.randint(0, 10, [])
self.assertTrue(paddle.is_integer(x))
def test_is_tensor(self):
self.assertTrue(paddle.is_tensor(self.x))
def test_is_empty(self):
x = paddle.rand([3, 0, 5])
self.assertTrue(paddle.is_empty(x))
def test_isfinite(self):
out = paddle.isfinite(self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_isinf(self):
x = paddle.to_tensor(np.array(float('-inf')))
out = paddle.isinf(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_isnan(self):
x = paddle.to_tensor(np.array(float('nan')))
out = paddle.isnan(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_isclose(self):
out = paddle.isclose(self.x, self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_clone(self):
out = paddle.clone(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())
def test_assign(self):
out = paddle.assign(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())
def test_item(self):
x = paddle.full([], 0.5)
self.assertEqual(x.item(), 0.5)
def test_tolist(self):
x = paddle.full([], 0.5)
self.assertEqual(x.tolist(), 0.5)
def test_numpy(self):
x = paddle.full([], 0.5)
np.testing.assert_array_equal(x.numpy(), np.array(0.5))
def test_numel(self):
out = paddle.numel(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(1))
def test_rank(self):
out = paddle.rank(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(0))
def test_shape(self):
out = paddle.shape(self.x)
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册