未验证 提交 fe214af2 编写于 作者: Y YuanRisheng 提交者: GitHub

[Phi] Support construct Scalar by using Non-CPU Tensosr (#41528)

* support construct scalar using non-cpu tensor

* fix bugs when run unittest

* fix compile bugs

* fix bugs when run ci

* fix compile bugs

* fix bugs when move copy

* perfect unit test

* perfect unittest

* update according to comment

* add target dependency
上级 b2390438
...@@ -192,13 +192,13 @@ add_subdirectory(profiler) ...@@ -192,13 +192,13 @@ add_subdirectory(profiler)
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
if(WITH_GPU) if(WITH_GPU)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce dynload_cuda new_profiler) nv_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce dynload_cuda new_profiler stats)
nv_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place) nv_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
elseif(WITH_ROCM) elseif(WITH_ROCM)
hip_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce new_profiler) hip_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce new_profiler stats)
hip_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place) hip_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
else() else()
cc_library(profiler SRCS profiler.cc DEPS os_info device_tracer enforce new_profiler) cc_library(profiler SRCS profiler.cc DEPS os_info device_tracer enforce new_profiler stats)
cc_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info place) cc_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info place)
endif() endif()
......
...@@ -23,7 +23,7 @@ add_subdirectory(tools) ...@@ -23,7 +23,7 @@ add_subdirectory(tools)
add_subdirectory(tests) add_subdirectory(tests)
# make an unity target for compile deps # make an unity target for compile deps
set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor) set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor api_scalar)
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
set(PHI_DEPS ${PHI_DEPS} ${phi_kernels}) set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
......
...@@ -164,7 +164,7 @@ cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_conte ...@@ -164,7 +164,7 @@ cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_conte
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor) cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform) cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform) cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform)
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform tensor_copy)
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl) cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils) cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils)
...@@ -173,3 +173,5 @@ cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw p ...@@ -173,3 +173,5 @@ cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw p
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api) cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api)
cc_library(strings_api SRCS ${strings_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils) cc_library(strings_api SRCS ${strings_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils)
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api strings_api) cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api strings_api)
cc_library(tensor_copy SRCS tensor_copy.cc DEPS phi_tensor_raw copy_kernel kernel_dispatch api_gen_utils)
cc_library(api_scalar SRCS scalar.cc DEPS tensor_copy)
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/common/type_traits.h" #include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
...@@ -424,35 +425,8 @@ std::vector<std::vector<Tensor>> conv2d_grad_impl( ...@@ -424,35 +425,8 @@ std::vector<std::vector<Tensor>> conv2d_grad_impl(
} }
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);
VLOG(6) << "copy API kernel key: " << kernel_key;
VLOG(6) << "copy API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto dense_x = TensorToDenseTensor(x);
Tensor out; Tensor out;
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out); copy(x, place, blocking, &out);
phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(*dense_x, &meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
phi::Place,
bool,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
return out; return out;
} }
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace experimental {
template <>
ScalarBase<Tensor>::ScalarBase(const Tensor& tensor_in)
: dtype_(tensor_in.dtype()) { // NOLINT
PADDLE_ENFORCE_EQ(tensor_in.numel(),
1,
phi::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, but "
"now Tensor has `%d` elements",
tensor_in.numel()));
if (tensor_in.place() == PlaceType::kGPU) {
Tensor dst_tensor;
copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
GetDataFromTensor(dst_tensor);
} else if (tensor_in.place() == PlaceType::kCPU) {
GetDataFromTensor(tensor_in);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Now, it is not supported to construct Scalar using tensor that its "
"PlaceType is (%d)",
static_cast<int>(tensor_in.place())));
}
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace experimental {
void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) {
auto kernel_key_set = ParseKernelKeyByInputArgs(src);
kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);
VLOG(6) << "copy API kernel key: " << kernel_key;
VLOG(6) << "copy API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto dense_x = TensorToDenseTensor(src);
auto kernel_out = SetKernelOutput(kernel_key.backend(), dst);
phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(*dense_x, &meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
phi::Place,
bool,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace experimental {
void copy(const Tensor& src, Place place, bool blocking, Tensor* dst);
} // namespace experimental
} // namespace paddle
cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits scalar string_tensor) tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits string_tensor scalar)
cc_library(phi_place SRCS place.cc) cc_library(phi_place SRCS place.cc)
cc_library(scalar SRCS scalar.cc DEPS phi_enforce) cc_library(scalar SRCS scalar.cc DEPS phi_enforce tensor)
...@@ -14,21 +14,32 @@ limitations under the License. */ ...@@ -14,21 +14,32 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
// NOTE(xiongkun): why we put definition here? // The Tensor must have one dim
// test_custom_op can't include enforce.h, because enforce.h includes gflags. template <>
// so we decouple the include dependence of enforce.h by link. ScalarBase<phi::DenseTensor>::ScalarBase(const phi::DenseTensor& tensor_in)
void ThrowTensorConvertError(int num) { : dtype_(tensor_in.dtype()) { // NOLINT
PADDLE_ENFORCE_EQ(num, PADDLE_ENFORCE_EQ(tensor_in.numel(),
1, 1,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, but " "The Scalar only supports Tensor with 1 element, but "
"now Tensor has `%d` elements", "now Tensor has `%d` elements",
num)); tensor_in.numel()));
auto cpu_place = phi::CPUPlace();
if (!paddle::platform::is_same_place(tensor_in.place(), cpu_place)) {
phi::DenseTensor tensor;
framework::TensorCopySync(tensor_in, cpu_place, &tensor);
GetDataFromTensor(tensor);
} else {
GetDataFromTensor(tensor_in);
}
} }
} // namespace experimental } // namespace experimental
......
...@@ -23,8 +23,6 @@ limitations under the License. */ ...@@ -23,8 +23,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
void ThrowTensorConvertError(int);
template <typename T> template <typename T>
class ScalarBase { class ScalarBase {
public: public:
...@@ -105,50 +103,7 @@ class ScalarBase { ...@@ -105,50 +103,7 @@ class ScalarBase {
} }
// The Tensor must have one dim // The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT ScalarBase(const T& tensor_in); // NOLINT
is_from_tensor_ = true;
ThrowTensorConvertError(tensor.numel());
switch (dtype_) {
case DataType::FLOAT32:
data_.f32 = tensor.template data<float>()[0];
break;
case DataType::FLOAT64:
data_.f64 = tensor.template data<double>()[0];
break;
case DataType::FLOAT16:
data_.f16 = tensor.template data<float16>()[0];
break;
case DataType::BFLOAT16:
data_.bf16 = tensor.template data<bfloat16>()[0];
break;
case DataType::INT32:
data_.i32 = tensor.template data<int32_t>()[0];
break;
case DataType::INT64:
data_.i64 = tensor.template data<int64_t>()[0];
break;
case DataType::INT16:
data_.i16 = tensor.template data<int16_t>()[0];
break;
case DataType::INT8:
data_.i8 = tensor.template data<int8_t>()[0];
break;
case DataType::UINT8:
data_.ui8 = tensor.template data<uint8_t>()[0];
break;
case DataType::BOOL:
data_.b = tensor.template data<bool>()[0];
break;
case DataType::COMPLEX64:
data_.c64 = tensor.template data<complex64>()[0];
break;
case DataType::COMPLEX128:
data_.c128 = tensor.template data<complex128>()[0];
break;
default:
PD_THROW("Invalid tensor data type `", dtype_, "`.");
}
}
template <typename OtherT> template <typename OtherT>
ScalarBase(const ScalarBase<OtherT>& other) { ScalarBase(const ScalarBase<OtherT>& other) {
...@@ -200,6 +155,49 @@ class ScalarBase { ...@@ -200,6 +155,49 @@ class ScalarBase {
private: private:
template <typename T1, typename T2> template <typename T1, typename T2>
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst); friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
void GetDataFromTensor(const T& tensor) {
is_from_tensor_ = true;
switch (dtype_) {
case DataType::FLOAT32:
data_.f32 = tensor.template data<float>()[0];
break;
case DataType::FLOAT64:
data_.f64 = tensor.template data<double>()[0];
break;
case DataType::FLOAT16:
data_.f16 = tensor.template data<float16>()[0];
break;
case DataType::BFLOAT16:
data_.bf16 = tensor.template data<bfloat16>()[0];
break;
case DataType::INT32:
data_.i32 = tensor.template data<int32_t>()[0];
break;
case DataType::INT64:
data_.i64 = tensor.template data<int64_t>()[0];
break;
case DataType::INT16:
data_.i16 = tensor.template data<int16_t>()[0];
break;
case DataType::INT8:
data_.i8 = tensor.template data<int8_t>()[0];
break;
case DataType::UINT8:
data_.ui8 = tensor.template data<uint8_t>()[0];
break;
case DataType::BOOL:
data_.b = tensor.template data<bool>()[0];
break;
case DataType::COMPLEX64:
data_.c64 = tensor.template data<complex64>()[0];
break;
case DataType::COMPLEX128:
data_.c128 = tensor.template data<complex128>()[0];
break;
default:
PD_THROW("Invalid tensor data type `", dtype_, "`.");
}
}
private: private:
bool is_from_tensor_{false}; bool is_from_tensor_{false};
......
...@@ -23,7 +23,7 @@ cc_library(string_tensor SRCS string_tensor.cc DEPS convert_utils tensor_meta te ...@@ -23,7 +23,7 @@ cc_library(string_tensor SRCS string_tensor.cc DEPS convert_utils tensor_meta te
cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy) cc_library(selected_rows SRCS selected_rows_impl.cc selected_rows.cc DEPS tensor_base dense_tensor phi_enforce ddim memcpy)
cc_library(phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows) cc_library(phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows)
cc_library(custom_kernel SRCS custom_kernel.cc DEPS kernel_factory) cc_library(custom_kernel SRCS custom_kernel.cc DEPS kernel_factory)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/selected_rows.h"
namespace phi {
SelectedRows::SelectedRows(const std::vector<int64_t>& rows,
const int64_t& height)
: impl_(std::make_shared<phi::SelectedRowsImpl>(rows, height)) {}
SelectedRows::SelectedRows()
: impl_(std::make_shared<phi::SelectedRowsImpl>()) {}
} // namespace phi
...@@ -42,10 +42,9 @@ class SelectedRows : public TensorBase, ...@@ -42,10 +42,9 @@ class SelectedRows : public TensorBase,
* *
*/ */
public: public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height) SelectedRows(const std::vector<int64_t>& rows, const int64_t& height);
: impl_(std::make_shared<phi::SelectedRowsImpl>(rows, height)) {}
SelectedRows() : impl_(std::make_shared<phi::SelectedRowsImpl>()) {} SelectedRows();
const DenseTensor& value() const { return impl_->value(); } const DenseTensor& value() const { return impl_->value(); }
......
...@@ -51,7 +51,7 @@ TypeInfo<BaseT> TypeRegistry<BaseT>::RegisterType(const std::string& type) { ...@@ -51,7 +51,7 @@ TypeInfo<BaseT> TypeRegistry<BaseT>::RegisterType(const std::string& type) {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
assert(name_to_id_.find(type) == name_to_id_.end()); assert(name_to_id_.find(type) == name_to_id_.end());
assert(names_.size() < std::numeric_limits<int8_t>::max()); assert(names_.size() < std::numeric_limits<int8_t>::max());
int8_t id = names_.size(); int8_t id = static_cast<int8_t>(names_.size());
names_.emplace_back(type); names_.emplace_back(type);
name_to_id_[type] = id; name_to_id_[type] = id;
return TypeInfo<BaseT>(id); return TypeInfo<BaseT>(id);
......
...@@ -11,14 +11,14 @@ cc_test(test_mean_api SRCS test_mean_api.cc DEPS ${COMMON_API_TEST_DEPS}) ...@@ -11,14 +11,14 @@ cc_test(test_mean_api SRCS test_mean_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS} api_scalar)
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_to_api SRCS test_to_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_to_api SRCS test_to_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_slice_api SRCS test_slice_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_slice_api SRCS test_slice_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_sum_api SRCS test_sum_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_sum_api SRCS test_sum_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_scale_api SRCS test_scale_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_scale_api SRCS test_scale_api.cc DEPS ${COMMON_API_TEST_DEPS} api_scalar)
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_conj_api SRCS test_conj_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_conj_api SRCS test_conj_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_concat_api SRCS test_concat_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_concat_api SRCS test_concat_api.cc DEPS ${COMMON_API_TEST_DEPS})
......
...@@ -2,3 +2,9 @@ cc_test(phi_test_backend SRCS test_backend.cc DEPS gtest) ...@@ -2,3 +2,9 @@ cc_test(phi_test_backend SRCS test_backend.cc DEPS gtest)
cc_test(phi_test_data_layout SRCS test_data_layout.cc DEPS gtest) cc_test(phi_test_data_layout SRCS test_data_layout.cc DEPS gtest)
cc_test(phi_test_data_type SRCS test_data_type.cc DEPS gtest) cc_test(phi_test_data_type SRCS test_data_type.cc DEPS gtest)
cc_test(phi_test_place SRCS test_place.cc DEPS phi_place) cc_test(phi_test_place SRCS test_place.cc DEPS phi_place)
if (WITH_GPU)
nv_test(phi_test_scalar SRCS test_scalar.cu DEPS scalar api_scalar)
endif()
if(WITH_ROCM)
hip_test(phi_test_scalar SRCS test_scalar.cu DEPS scalar api_scalar)
endif()
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <map> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT);
namespace phi {
namespace tests {
using DDim = phi::DDim;
using float16 = phi::dtype::float16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
__global__ void FillTensor(float* data) { data[0] = 1; }
TEST(Scalar, ConstructFromDenseTensor1) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor dense_x(
alloc.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT16, phi::make_ddim({1}), phi::DataLayout::NCHW));
phi::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<float16>(&dense_x);
dense_x_data[0] = 1;
phi::Scalar scalar_test(dense_x);
ASSERT_NEAR(1, scalar_test.to<float16>(), 1e-6);
}
TEST(Scalar, ConstructFromDenseTensor2) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor dense_x(
alloc.get(),
phi::DenseTensorMeta(
phi::DataType::INT16, phi::make_ddim({1}), phi::DataLayout::NCHW));
phi::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<int16_t>(&dense_x);
dense_x_data[0] = 1;
phi::Scalar scalar_test(dense_x);
ASSERT_EQ(1, scalar_test.to<int16_t>());
}
TEST(Scalar, ConstructFromDenseTensor3) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor dense_x(
alloc.get(),
phi::DenseTensorMeta(
phi::DataType::INT8, phi::make_ddim({1}), phi::DataLayout::NCHW));
phi::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<int8_t>(&dense_x);
dense_x_data[0] = 1;
phi::Scalar scalar_test(dense_x);
ASSERT_EQ(1, scalar_test.to<int8_t>());
}
TEST(Scalar, ConstructFromDenseTensor4) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor dense_x(
alloc.get(),
phi::DenseTensorMeta(
phi::DataType::BOOL, phi::make_ddim({1}), phi::DataLayout::NCHW));
phi::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<bool>(&dense_x);
dense_x_data[0] = true;
phi::Scalar scalar_test(dense_x);
ASSERT_EQ(true, scalar_test.to<bool>());
}
TEST(Scalar, ConstructFromDenseTensor5) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor dense_x(alloc.get(),
phi::DenseTensorMeta(phi::DataType::COMPLEX64,
phi::make_ddim({1}),
phi::DataLayout::NCHW));
phi::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<complex64>(&dense_x);
dense_x_data[0] = 1;
phi::Scalar scalar_test(dense_x);
complex64 expected_value(1, 0);
EXPECT_TRUE(expected_value == scalar_test.to<complex64>());
}
TEST(Scalar, ConstructFromDenseTensor6) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::CPUPlace());
phi::DenseTensor dense_x(alloc.get(),
phi::DenseTensorMeta(phi::DataType::COMPLEX128,
phi::make_ddim({1}),
phi::DataLayout::NCHW));
phi::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<complex128>(&dense_x);
dense_x_data[0] = 1;
phi::Scalar scalar_test(dense_x);
complex128 expected_value(1, 0);
EXPECT_TRUE(expected_value == scalar_test.to<complex128>());
}
TEST(Scalar, ConstructFromDenseTensor7) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::GPUPlace());
phi::DenseTensor dense_x(
alloc.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT32, phi::make_ddim({1}), phi::DataLayout::NCHW));
phi::GPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::GPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<float>(&dense_x);
FillTensor<<<1, 1, 0, dev_ctx.stream()>>>(dense_x_data);
dev_ctx.Wait();
phi::Scalar scalar_test(dense_x);
ASSERT_NEAR(1, scalar_test.to<float>(), 1e-6);
}
TEST(Scalar, ConstructFromTensor) {
// 1. create tensor
const auto alloc =
std::make_unique<paddle::experimental::DefaultAllocator>(phi::GPUPlace());
auto dense_x = std::make_shared<phi::DenseTensor>(
alloc.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT32, phi::make_ddim({1}), phi::DataLayout::NCHW));
phi::GPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(phi::GPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<float>(dense_x.get());
FillTensor<<<1, 1, 0, dev_ctx.stream()>>>(dense_x_data);
dev_ctx.Wait();
paddle::experimental::Tensor x(dense_x);
paddle::experimental::Scalar scalar_test(x);
ASSERT_NEAR(1, scalar_test.to<float>(), 1e-6);
}
} // namespace tests
} // namespace phi
cc_test(test_custom_kernel SRCS test_custom_kernel.cc DEPS custom_kernel) cc_test(test_custom_kernel SRCS test_custom_kernel.cc DEPS custom_kernel scalar)
cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor)
cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_type_info SRCS test_type_info.cc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册