未验证 提交 dc331231 编写于 作者: A Aganlengzi 提交者: GitHub

[CustomDevice] support scalar (#45244)

上级 a8ae87f1
......@@ -37,6 +37,12 @@ ScalarBase<Tensor>::ScalarBase(const Tensor& tensor_in)
GetDataFromTensor(dst_tensor);
} else if (tensor_in_place == phi::AllocationType::CPU) {
GetDataFromTensor(tensor_in);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (tensor_in_place == phi::AllocationType::CUSTOM) {
Tensor dst_tensor;
copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
GetDataFromTensor(dst_tensor);
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Now, it is not supported to construct Scalar using tensor that its "
......
......@@ -41,6 +41,7 @@ class TestCustomCPUPlugin(unittest.TestCase):
self._test_eager_backward_api()
self._test_eager_copy_to()
self._test_fallback_kernel()
self._test_scalar()
self._test_custom_device_dataloader()
self._test_custom_device_mnist()
......@@ -170,6 +171,13 @@ class TestCustomCPUPlugin(unittest.TestCase):
z = paddle.add(x, y)
np.testing.assert_array_equal(z, r)
def _test_scalar(self):
import paddle
data_1 = paddle.to_tensor([[[[1.0, 4.0, 5.0, 7.0], [3.0, 4.0, 5.0,
6.0]]]])
k_t = paddle.to_tensor([3], dtype="int32")
value_1, indices_1 = paddle.topk(data_1, k=k_t)
def tearDown(self):
del os.environ['CUSTOM_DEVICE_ROOT']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册