Unverified commit 8eecd852, authored by zyfncg, committed by GitHub

[PHI] Support constructing IntArray from a non-CPU Tensor (#41764)

* support construct scalar using non-cpu tensor

* fix bugs when run unittest

* fix compile bugs

* fix bugs when run ci

* fix compile bugs

* fix bugs when move copy

* perfect unit test

* perfect unittest

* update according to comment

* int_array supports constructed by gpu tensor

* add some test

* polish code

* adjust full api

* add unittest

* add unittest
Co-authored-by: NYuanRisheng <yuanrisheng@baidu.com>
Parent: aeb33958
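What this commit enables, in a hedged sketch (the function and values are illustrative, mirroring the new unit tests added below; assumes a CUDA build with the GPU `full` kernel registered): a shape tensor living on the GPU can now be passed anywhere an `IntArray` is expected, because the constructor copies it to the host before reading its values.

```cpp
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/int_array.h"

void Demo() {  // illustrative, not part of the commit
  // Create a 2-element INT64 shape tensor on the GPU holding {3, 3}.
  paddle::experimental::Tensor shape = paddle::experimental::full(
      {2}, 3, phi::DataType::INT64, phi::GPUPlace());
  // Previously this required a CPU tensor; now the IntArray constructor
  // performs a blocking device-to-host copy before reading the values.
  paddle::experimental::Tensor out = paddle::experimental::full(shape, 1);
  // out is a 3x3 tensor filled with ones.
}
```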
@@ -23,7 +23,7 @@ add_subdirectory(tools)
 add_subdirectory(tests)
 # make an unity target for compile deps
-set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor api_scalar)
+set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor api_scalar api_int_array)
 get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
 set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
......
@@ -176,3 +176,4 @@ cc_library(strings_api SRCS ${strings_api_source_file} DEPS phi_tensor_raw phi k
 cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api strings_api)
 cc_library(tensor_copy SRCS tensor_copy.cc DEPS phi_tensor_raw copy_kernel kernel_dispatch api_gen_utils)
 cc_library(api_scalar SRCS scalar.cc DEPS tensor_copy)
+cc_library(api_int_array SRCS int_array.cc DEPS tensor_copy)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace experimental {

template <>
IntArrayBase<Tensor>::IntArrayBase(const Tensor& tensor) {  // NOLINT
  is_from_tensor_ = true;
  if (tensor.place().GetType() == phi::AllocationType::CPU) {
    AssignDataFromTensor(tensor);
  } else {
    Tensor tensor_tmp;
    copy(tensor, phi::CPUPlace(), true, &tensor_tmp);
    AssignDataFromTensor(tensor_tmp);
  }
}

template <>
IntArrayBase<Tensor>::IntArrayBase(const std::vector<Tensor>& tensor_list) {
  is_from_tensor_ = true;

  for (size_t i = 0; i < tensor_list.size(); ++i) {
    DataType data_type = tensor_list[i].dtype();
    switch (data_type) {
      case DataType::INT32:
        if (tensor_list[i].place().GetType() == AllocationType::CPU) {
          array_.push_back(*tensor_list[i].template data<int32_t>());
        } else {
          Tensor tensor_tmp;
          copy(tensor_list[i], phi::CPUPlace(), true, &tensor_tmp);
          array_.push_back(*tensor_tmp.template data<int32_t>());
        }
        break;
      case DataType::INT64:
        if (tensor_list[i].place().GetType() == AllocationType::CPU) {
          array_.push_back(*tensor_list[i].template data<int64_t>());
        } else {
          Tensor tensor_tmp;
          copy(tensor_list[i], phi::CPUPlace(), true, &tensor_tmp);
          array_.push_back(*tensor_tmp.template data<int64_t>());
        }
        break;
      default:
        PD_THROW(
            "Data type error. Currently, the data type of IntArrayBase "
            "only supports Tensor with int32 and int64, "
            "but now received `",
            data_type,
            "`.");
    }
  }
}

}  // namespace experimental
}  // namespace paddle
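A hedged usage sketch of the list constructor above (values illustrative; assumes a CUDA build): each entry must hold exactly one element, int32 and int64 entries may be mixed, and non-CPU entries are synced to the host through the blocking `copy(...)` call.

```cpp
void ListDemo() {  // illustrative, not part of the commit
  using paddle::experimental::full;
  auto d0 = full({1}, 3, phi::DataType::INT64, phi::GPUPlace());
  auto d1 = full({1}, 4, phi::DataType::INT32, phi::CPUPlace());
  // One value is read from each tensor; the GPU entry is copied to the
  // host first, so the resulting IntArray holds {3, 4}.
  paddle::experimental::IntArray dims({d0, d1});
}
```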
 cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS
-    tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits string_tensor scalar)
+    tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits string_tensor int_array scalar)
@@ -67,16 +67,9 @@ phi::IntArray MakePhiIntArray(const paddle::framework::Tensor& src) {
 }

 phi::IntArray MakePhiIntArrayFromVar(const framework::Variable& variable) {
-  auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);
   if (variable.IsType<framework::LoDTensor>()) {
     const auto& tensor = variable.Get<framework::LoDTensor>();
-    if (!platform::is_same_place(tensor.place(), expected_place)) {
-      framework::LoDTensor tmp_tensor;
-      framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
-      return MakePhiIntArray(tmp_tensor);
-    } else {
-      return MakePhiIntArray(tensor);
-    }
+    return MakePhiIntArray(tensor);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Unsupport casting input `%s` type to IntArray when call pt "
......
 cc_library(phi_place SRCS place.cc)
 cc_library(scalar SRCS scalar.cc DEPS phi_enforce tensor)
+cc_library(int_array SRCS int_array.cc DEPS phi_enforce tensor)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace experimental {

template <>
IntArrayBase<phi::DenseTensor>::IntArrayBase(
    const phi::DenseTensor& tensor) {  // NOLINT
  is_from_tensor_ = true;
  if (tensor.place().GetType() == AllocationType::CPU) {
    AssignDataFromTensor(tensor);
  } else {
    phi::DenseTensor tensor_tmp;
    paddle::framework::TensorCopySync(tensor, CPUPlace(), &tensor_tmp);
    AssignDataFromTensor(tensor_tmp);
  }
}

template <>
IntArrayBase<phi::DenseTensor>::IntArrayBase(
    const std::vector<phi::DenseTensor>& tensor_list) {
  is_from_tensor_ = true;

  for (size_t i = 0; i < tensor_list.size(); ++i) {
    DataType data_type = tensor_list[i].dtype();
    switch (data_type) {
      case DataType::INT32:
        if (tensor_list[i].place().GetType() == AllocationType::CPU) {
          array_.push_back(*tensor_list[i].template data<int32_t>());
        } else {
          phi::DenseTensor tensor_tmp;
          paddle::framework::TensorCopySync(
              tensor_list[i], CPUPlace(), &tensor_tmp);
          array_.push_back(*tensor_tmp.template data<int32_t>());
        }
        break;
      case DataType::INT64:
        if (tensor_list[i].place().GetType() == AllocationType::CPU) {
          array_.push_back(*tensor_list[i].template data<int64_t>());
        } else {
          phi::DenseTensor tensor_tmp;
          paddle::framework::TensorCopySync(
              tensor_list[i], CPUPlace(), &tensor_tmp);
          array_.push_back(*tensor_tmp.template data<int64_t>());
        }
        break;
      default:
        PD_THROW(
            "Data type error. Currently, the data type of IntArrayBase "
            "only supports Tensor with int32 and int64, "
            "but now received `",
            data_type,
            "`.");
    }
  }
}

}  // namespace experimental
}  // namespace paddle
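This is the same logic one layer down, specialized for `phi::DenseTensor`; since this target sits below the C++ API it uses `paddle::framework::TensorCopySync` directly instead of the api-level `copy()`. A hedged sketch of direct use (mirrors the `ConstructFromGPUDenseTensor` test added below; assumes a CUDA build):

```cpp
void DenseTensorDemo() {  // illustrative, not part of the commit
  auto& pool = paddle::experimental::DeviceContextPool::Instance();
  const auto* dev_ctx =
      static_cast<const phi::GPUContext*>(pool.Get(phi::GPUPlace()));
  // A device-side shape tensor holding {3, 3}...
  phi::DenseTensor shape = phi::Full<int>(*dev_ctx, {2}, 3);
  // ...is accepted directly; the constructor syncs it to host memory.
  phi::IntArray int_array(shape);
}
```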
@@ -48,50 +48,10 @@ class IntArrayBase {
   void SetFromTensor(bool val) { is_from_tensor_ = val; }

   // The Tensor must have one dim
-  IntArrayBase(const T& tensor) {  // NOLINT
-    is_from_tensor_ = true;
-    size_t n = tensor.numel();
-    array_.reserve(n);
-    switch (tensor.dtype()) {
-      case DataType::INT32:
-        AssignData(tensor.template data<int32_t>(), n);
-        break;
-      case DataType::INT64:
-        AssignData(tensor.template data<int64_t>(), n);
-        break;
-      default:
-        PD_THROW(
-            "Data type error. Currently, The data type of IntArrayBase "
-            "only supports Tensor with int32 and int64, "
-            "but now received `",
-            tensor.dtype(),
-            "`.");
-    }
-  }
+  IntArrayBase(const T& tensor);  // NOLINT

   // The Tensor in vec must have only one element
-  IntArrayBase(const std::vector<T>& tensor_list) {  // NOLINT
-    is_from_tensor_ = true;
-    for (size_t i = 0; i < tensor_list.size(); ++i) {
-      DataType data_type = tensor_list[i].dtype();
-      switch (data_type) {
-        case DataType::INT32:
-          array_.push_back(*tensor_list[i].template data<int32_t>());
-          break;
-        case DataType::INT64:
-          array_.push_back(*tensor_list[i].template data<int64_t>());
-          break;
-        default:
-          PD_THROW(
-              "Data type error. Currently, The data type of IntArrayBase "
-              "only supports Tensor with int32 and int64, "
-              "but now received `",
-              data_type,
-              "`.");
-      }
-    }
-  }
+  IntArrayBase(const std::vector<T>& tensor_list);  // NOLINT

   template <typename OtherT>
   IntArrayBase(const IntArrayBase<OtherT>& other) : array_(other.GetData()) {}

@@ -114,6 +74,26 @@ class IntArrayBase {
     }
   }

+  void AssignDataFromTensor(const T& tensor) {
+    size_t n = tensor.numel();
+    array_.reserve(n);
+    switch (tensor.dtype()) {
+      case DataType::INT32:
+        AssignData(tensor.template data<int32_t>(), n);
+        break;
+      case DataType::INT64:
+        AssignData(tensor.template data<int64_t>(), n);
+        break;
+      default:
+        PD_THROW(
+            "Data type error. Currently, the data type of IntArrayBase "
+            "only supports Tensor with int32 and int64, "
+            "but now received `",
+            tensor.dtype(),
+            "`.");
+    }
+  }
+
  private:
   // TODO(zhangyunfei) Replace std::vector with a more efficient container
   // structure.
......
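The shape of this refactor: both tensor-taking constructors become declarations, the shared dtype dispatch moves into the new `AssignDataFromTensor` helper, and the definitions live in the two `int_array.cc` files above, which are free to depend on device-copy utilities the header cannot include. A generic sketch of the pattern (names hypothetical, not Paddle code):

```cpp
// Header: declare only, keeping the header free of heavy dependencies.
template <typename T>
class ArrayLike {
 public:
  explicit ArrayLike(const T& tensor);  // defined per specialization
};

// some_tensor.cc: the specialization may use copy utilities freely.
// template <>
// ArrayLike<SomeTensor>::ArrayLike(const SomeTensor& t) { /* sync + read */ }
```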
@@ -2,6 +2,7 @@ cc_test(phi_test_backend SRCS test_backend.cc DEPS gtest)
 cc_test(phi_test_data_layout SRCS test_data_layout.cc DEPS gtest)
 cc_test(phi_test_data_type SRCS test_data_type.cc DEPS gtest)
 cc_test(phi_test_place SRCS test_place.cc DEPS phi_place)
+cc_test(phi_test_int_array SRCS test_int_array.cc DEPS int_array api_int_array phi phi_api)
 if (WITH_GPU)
   nv_test(phi_test_scalar SRCS test_scalar.cu DEPS scalar api_scalar)
 endif()
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "gtest/gtest.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT);
#endif
namespace phi {
namespace tests {

TEST(IntArray, ConstructFromCPUDenseTensor) {
  auto& pool = paddle::experimental::DeviceContextPool::Instance();
  const auto* dev_ctx =
      static_cast<const phi::CPUContext*>(pool.Get(CPUPlace()));
  phi::DenseTensor shape = Full<int>(*dev_ctx, {2}, 3);
  phi::DenseTensor out = Full<int>(*dev_ctx, shape, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
}

TEST(IntArray, ConstructFromCPUDenseTensorVector) {
  auto& pool = paddle::experimental::DeviceContextPool::Instance();
  const auto* dev_ctx =
      static_cast<const phi::CPUContext*>(pool.Get(CPUPlace()));
  phi::DenseTensor shape0 = Full<int>(*dev_ctx, {1}, 3);
  phi::DenseTensor shape1 = Full<int64_t>(*dev_ctx, {1}, 3);
  std::vector<phi::DenseTensor> shape{shape0, shape1};
  phi::DenseTensor out = Full<int>(*dev_ctx, shape, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
}

TEST(IntArray, ConstructFromCPUTensor) {
  auto shape = paddle::experimental::full({2}, 3, DataType::INT64);
  auto out = paddle::experimental::full(shape, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
}

TEST(IntArray, ConstructFromCPUTensorVector) {
  auto shape0 = paddle::experimental::full({2}, 3, DataType::INT64);
  auto shape1 = paddle::experimental::full({2}, 3, DataType::INT32);

  std::vector<paddle::experimental::Tensor> shape{shape0, shape0};
  auto out = paddle::experimental::full(shape, 1);

  std::vector<paddle::experimental::Tensor> shape_new{shape0, shape1};
  auto out1 = paddle::experimental::full(shape_new, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);

  ASSERT_EQ(out1.dims().size(), 2);
  ASSERT_EQ(out1.dims()[0], 3);
  ASSERT_EQ(out1.dims()[1], 3);
  ASSERT_EQ(out1.numel(), 9);
}

TEST(IntArray, ThrowException) {
  auto shape = paddle::experimental::full({2}, 3, DataType::FLOAT32);
  auto create_int_array = [&shape]() -> paddle::experimental::IntArray {
    paddle::experimental::IntArray int_array{shape};
    return int_array;
  };
  ASSERT_ANY_THROW(create_int_array());
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(IntArray, ConstructFromGPUDenseTensor) {
  auto& pool = paddle::experimental::DeviceContextPool::Instance();
  const auto* dev_ctx =
      static_cast<const phi::GPUContext*>(pool.Get(GPUPlace()));
  phi::DenseTensor shape = Full<int>(*dev_ctx, {2}, 3);
  phi::DenseTensor out = Full<int>(*dev_ctx, shape, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
}

TEST(IntArray, ConstructFromGPUDenseTensorVector) {
  auto& pool = paddle::experimental::DeviceContextPool::Instance();
  const auto* dev_ctx =
      static_cast<const phi::GPUContext*>(pool.Get(GPUPlace()));
  phi::DenseTensor shape0 = Full<int>(*dev_ctx, {1}, 3);
  phi::DenseTensor shape1 = Full<int64_t>(*dev_ctx, {1}, 3);
  std::vector<phi::DenseTensor> shape{shape0, shape1};
  phi::DenseTensor out = Full<int>(*dev_ctx, shape, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
}

TEST(IntArray, ConstructFromGPUTensor) {
  auto shape = paddle::experimental::full({2}, 3, DataType::INT64, GPUPlace());
  auto out = paddle::experimental::full(shape, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);
}

TEST(IntArray, ConstructFromGPUTensorVector) {
  auto shape0 = paddle::experimental::full({2}, 3, DataType::INT64, GPUPlace());
  auto shape1 = paddle::experimental::full({2}, 3, DataType::INT32, GPUPlace());

  std::vector<paddle::experimental::Tensor> shape{shape0, shape0};
  auto out = paddle::experimental::full(shape, 1);

  std::vector<paddle::experimental::Tensor> shape_new{shape0, shape1};
  auto out1 = paddle::experimental::full(shape_new, 1);

  ASSERT_EQ(out.dims().size(), 2);
  ASSERT_EQ(out.dims()[0], 3);
  ASSERT_EQ(out.dims()[1], 3);
  ASSERT_EQ(out.numel(), 9);

  ASSERT_EQ(out1.dims().size(), 2);
  ASSERT_EQ(out1.dims()[0], 3);
  ASSERT_EQ(out1.dims()[1], 3);
  ASSERT_EQ(out1.numel(), 9);
}
#endif

}  // namespace tests
}  // namespace phi
@@ -760,8 +760,14 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
         place = _current_expected_place()
         if force_cpu:
             place = core.CPUPlace()
-        shape = utils.convert_shape_to_list(shape)
+        if isinstance(shape, (list, tuple)):
+            for item in shape:
+                if not isinstance(item, Variable):
+                    shape = list(
+                        map(lambda x: x.numpy().flat[0] if isinstance(x, Variable) else x,
+                            shape))
+                    break
+
         if not isinstance(dtype, core.VarDesc.VarType):
             dtype = convert_np_dtype_to_dtype_(dtype)
         out = _C_ops.final_state_full(shape, float(value), dtype, place)
......
@@ -232,28 +232,33 @@ class TestEmptyAPI(unittest.TestCase):
                 name="shape_tensor_int32", shape=[2], dtype="int32")
             shape_tensor_int64 = fluid.data(
                 name="shape_tensor_int64", shape=[2], dtype="int64")
+            shape_tensor_unknown = fluid.data(
+                name="shape_tensor_unknown", shape=[-1], dtype="int64")

             out_1 = paddle.empty(shape=[200, 3], dtype=dtype)
             out_2 = paddle.empty(shape=shape_tensor_int32, dtype=dtype)
             out_3 = paddle.empty(shape=shape_tensor_int64, dtype=dtype)
             out_4 = paddle.empty(shape=[200, positive_2_int32], dtype=dtype)
             out_5 = paddle.empty(shape=[200, positive_2_int64], dtype=dtype)
+            out_6 = paddle.empty(shape=shape_tensor_unknown, dtype=dtype)

             place = paddle.CPUPlace()
             exe = paddle.static.Executor(place)
-            res_1, res_2, res_3, res_4, res_5 = exe.run(
+            res_1, res_2, res_3, res_4, res_5, res_6 = exe.run(
                 fluid.default_main_program(),
                 feed={
                     "shape_tensor_int32": np.array([200, 3]).astype("int32"),
                     "shape_tensor_int64": np.array([200, 3]).astype("int64"),
+                    "shape_tensor_unknown": np.array([200, 3]).astype("int64"),
                 },
-                fetch_list=[out_1, out_2, out_3, out_4, out_5])
+                fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])

             self.__check_out__(res_1, dtype)
             self.__check_out__(res_2, dtype)
             self.__check_out__(res_3, dtype)
             self.__check_out__(res_4, dtype)
             self.__check_out__(res_5, dtype)
+            self.__check_out__(res_6, dtype)

 class TestEmptyError(unittest.TestCase):
......
@@ -80,8 +80,10 @@ class TestFullAPI(unittest.TestCase):
         with fluid.dygraph.base.guard():
             with _test_eager_guard():
                 positive_2_int32 = fluid.layers.fill_constant([1], "int32", 2)
                 positive_2_int64 = fluid.layers.fill_constant([1], "int64", 2)
+                positive_4_int64 = fluid.layers.fill_constant([1], "int64", 4,
+                                                              True)

                 out_1 = paddle.full(
                     shape=[1, 2], dtype="float32", fill_value=1.1)

@@ -108,8 +110,19 @@ class TestFullAPI(unittest.TestCase):
                     shape=[1], dtype=np.float32, value=1.1)
                 out_7 = paddle.full(
                     shape=[1, 2], dtype=np.float32, fill_value=val)

+                out_8 = paddle.full(
+                    shape=positive_2_int32, dtype="float32", fill_value=1.1)
+
+                out_9 = paddle.full(
+                    shape=[
+                        positive_2_int32, positive_2_int64, positive_4_int64
+                    ],
+                    dtype="float32",
+                    fill_value=1.1)
+
                 # test for numpy.float64 as fill_value
-                out_8 = paddle.full_like(
+                out_10 = paddle.full_like(
                     out_7, dtype=np.float32, fill_value=np.abs(1.1))

                 assert np.array_equal(

@@ -133,8 +146,12 @@ class TestFullAPI(unittest.TestCase):
                 assert np.array_equal(
                     out_7, np.full(
                         [1, 2], 1.1, dtype="float32"))
+                assert np.array_equal(out_8, np.full([2], 1.1, dtype="float32"))
+                assert np.array_equal(
+                    out_9, np.full(
+                        [2, 2, 4], 1.1, dtype="float32"))
                 assert np.array_equal(
-                    out_8, np.full(
+                    out_10, np.full(
                         [1, 2], 1.1, dtype="float32"))
......
@@ -1530,8 +1530,6 @@ def roll(x, shifts, axis=None, name=None):
         axis = []

     if in_dygraph_mode():
-        if isinstance(shifts, paddle.Tensor):
-            shifts = shifts.cpu()
         return _C_ops.final_state_roll(x, shifts, axis)

     if _in_legacy_dygraph():
......