未验证 提交 caea126c 编写于 作者: Z zyfncg 提交者: GitHub

Support custom implement for C++ API (#39521)

* Support custom implement for C++ API

* rename api_invoke_impl to api_custom_impl

* remove manual_api

* delete mutable_data in copy_to api

* fix problem of copy_to

* add unittest for infer_meta_fn_factory

* fix split cofig in yaml

* fix split cofig in yaml

* modify sum api yaml

* add copy_to wrapped infermeta

* rollback copy impl
上级 de8f2748
add_subdirectory(lib)
cc_library(phi_api SRCS all.cc DEPS phi_function_api phi_bw_function_api manual_api sparse_api)
cc_library(phi_api SRCS all.cc DEPS phi_function_api phi_bw_function_api sparse_api)
......@@ -26,7 +26,6 @@ limitations under the License. */
// new pten apis
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/api/include/tensor.h"
......
......@@ -3,11 +3,11 @@ add_subdirectory(utils)
cc_library(ext_compat_utils SRCS ext_compat_utils.cc DEPS place)
if (WITH_GPU)
nv_library(phi_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor phi_api_utils ext_compat_utils phi_enforce manual_api)
nv_library(phi_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor phi_api_utils ext_compat_utils phi_enforce)
elseif (WITH_ROCM)
hip_library(phi_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor phi_api_utils ext_compat_utils phi_enforce manual_api)
hip_library(phi_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor phi_api_utils ext_compat_utils phi_enforce)
else()
cc_library(phi_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor phi_api_utils ext_compat_utils phi_enforce manual_api)
cc_library(phi_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor phi_api_utils ext_compat_utils phi_enforce)
endif()
set(api_gen_base ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_base.py)
......@@ -83,17 +83,16 @@ add_custom_command(
DEPENDS ${api_yaml_file} ${wrapped_infermeta_gen_file} ${api_gen_base}
VERBATIM)
cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor_raw)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi)
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory)
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
cc_library(manual_api SRCS manual_api.cc DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform)
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform)
cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor)
cc_library(sparse_api SRCS sparse_api.cc DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform)
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform api_custom_impl)
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch phi_data_transform)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch backward_infermeta phi_data_transform phi_function_api api_custom_impl)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi)
cc_library(sparse_api SRCS sparse_api.cc DEPS phi_tensor phi kernel_dispatch phi_data_transform)
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor phi kernel_dispatch phi_data_transform wrapped_infermeta)
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor phi kernel_dispatch phi_data_transform)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor phi kernel_dispatch backward_infermeta phi_data_transform phi_function_api)
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,11 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/include/manual_api.h"
#include <memory>
#include "glog/logging.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_utils.h"
......@@ -25,23 +21,17 @@ limitations under the License. */
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
PD_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(split, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT);
#endif
#ifdef PADDLE_WITH_XPU
PD_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif
#include "glog/logging.h"
namespace paddle {
namespace experimental {
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
......@@ -79,28 +69,15 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
return out;
}
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
DataType kernel_data_type = kernel_key.dtype();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
......@@ -144,7 +121,6 @@ PADDLE_API std::vector<Tensor> split(const Tensor& x,
return out;
}
} // namespace experimental
} // namespace paddle
PD_REGISTER_API(Utils);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -19,22 +19,15 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
/**
* This file stores some special APIs that are implemented manually
* or difficult to automatically generated.
*/
namespace paddle {
namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking);
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking);
// TODO(chentianyu03): Split API has extra logic to calculate the outputs size,
// api_gen do not support
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);
std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);
} // namespace experimental
} // namespace paddle
......@@ -18,5 +18,4 @@ limitations under the License. */
#include "paddle/phi/api/lib/api_registry.h"
PD_DECLARE_API(Math);
PD_DECLARE_API(Utils);
PD_DECLARE_API(SparseApi);
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <vector>
#include "glog/logging.h"
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/storage.h"
......@@ -299,72 +298,7 @@ gpuStream_t Tensor::stream() const {
}
#endif
/* Part 5: Data Transform methods */
template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) const {
LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version "
"2.3, and will be removed in version 2.4, please use "
"`copy_to` method without template argument instead. "
"reason: copying a Tensor to another device does not need "
"to specify the data type template argument.";
return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
}
template PADDLE_API Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<double>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PADDLE_API Tensor Tensor::copy_to<phi::dtype::complex<float>>(
const PlaceType &target_place) const;
template PADDLE_API Tensor Tensor::copy_to<phi::dtype::complex<double>>(
const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<phi::dtype::float16>(const PlaceType &target_place) const;
Tensor Tensor::copy_to(Backend backend, bool blocking) const {
return experimental::copy_to(*this, backend, blocking);
}
void Tensor::copy_(const Tensor &src, bool blocking) {
if (!src.is_initialized()) {
return;
}
VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name();
if (defined()) {
PADDLE_ENFORCE_EQ(dtype(),
src.dtype(),
platform::errors::PreconditionNotMet(
"Tensor %s has different data type with Tensor %s, "
"Tensor Copy cannot be performed!",
name(),
src.name()));
PADDLE_ENFORCE_EQ(impl()->type_info().id(),
src.impl()->type_info().id(),
platform::errors::PreconditionNotMet(
"Tensor %s has different type with Tensor %s, Tensor "
"Copy cannot be performed!",
name(),
src.name()));
}
auto copy_tensor =
src.copy_to(phi::TransToPtenBackend(src.inner_place()), blocking);
set_impl(copy_tensor.impl());
}
/* Part 6: Status utils methods */
/* Part 5: Status utils methods */
bool Tensor::defined() const { return impl_ != nullptr; }
......@@ -376,7 +310,7 @@ bool Tensor::is_initialized() const {
void Tensor::reset() { impl_.reset(); }
/* Part 7: Operator overloading */
/* Part 6: Operator overloading */
Tensor &Tensor::operator=(const Tensor &x) & {
impl_ = x.impl_;
......
......@@ -14,15 +14,83 @@ limitations under the License. */
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_base.h"
namespace paddle {
namespace experimental {
// declare cast api
Tensor cast(const Tensor &x, DataType out_dtype);
Tensor copy_to(const Tensor &x, Backend backend, bool blocking);
Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type);
}
Tensor Tensor::copy_to(Backend backend, bool blocking) const {
return experimental::copy_to(*this, backend, blocking);
}
template <typename T>
Tensor Tensor::copy_to(const PlaceType &target_place) const {
LOG(WARNING) << "The Tensor's `copy_to` method is deprecated since version "
"2.3, and will be removed in version 2.4, please use "
"`copy_to` method without template argument instead. "
"reason: copying a Tensor to another device does not need "
"to specify the data type template argument.";
return copy_to(ConvertExtPlaceToBackend(target_place), /*blocking=*/false);
}
template PADDLE_API Tensor
Tensor::copy_to<float>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<double>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int64_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int32_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<uint8_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int8_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<int16_t>(const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<bool>(const PlaceType &target_place) const;
template PADDLE_API Tensor Tensor::copy_to<phi::dtype::complex<float>>(
const PlaceType &target_place) const;
template PADDLE_API Tensor Tensor::copy_to<phi::dtype::complex<double>>(
const PlaceType &target_place) const;
template PADDLE_API Tensor
Tensor::copy_to<phi::dtype::float16>(const PlaceType &target_place) const;
void Tensor::copy_(const Tensor &src, bool blocking) {
if (!src.is_initialized()) {
return;
}
VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name();
if (defined()) {
PADDLE_ENFORCE_EQ(dtype(),
src.dtype(),
platform::errors::PreconditionNotMet(
"Tensor %s has different data type with Tensor %s, "
"Tensor Copy cannot be performed!",
name(),
src.name()));
PADDLE_ENFORCE_EQ(impl()->type_info().id(),
src.impl()->type_info().id(),
platform::errors::PreconditionNotMet(
"Tensor %s has different type with Tensor %s, Tensor "
"Copy cannot be performed!",
name(),
src.name()));
}
auto copy_tensor =
src.copy_to(phi::TransToPtenBackend(src.inner_place()), blocking);
set_impl(copy_tensor.impl());
}
} // namespace experimental
} // namespace paddle
......@@ -18,7 +18,8 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
# NOTE: Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS softmax_kernel softmax_grad_kernel)
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
......
if(WITH_ROCM)
hip_test(test_phi_tensor SRCS test_pten_tensor.cc DEPS phi_tensor phi_function_api manual_api glog)
hip_test(test_phi_tensor SRCS test_pten_tensor.cc DEPS phi_tensor phi_function_api glog)
else()
cc_test(test_phi_tensor SRCS test_pten_tensor.cc DEPS phi_tensor phi_function_api manual_api glog)
cc_test(test_phi_tensor SRCS test_pten_tensor.cc DEPS phi_tensor phi_function_api glog)
endif()
cc_test(test_phi_exception SRCS test_pten_exception.cc DEPS gtest)
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <memory>
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
......
......@@ -17,7 +17,6 @@
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/phi/kernels/split_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/include/manual_api.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
......
......@@ -34,6 +34,11 @@
kernel :
func : conj
- api : copy_to
args : (Tensor x, Backend backend, bool blocking)
output : Tensor
invoke : copy_to_impl(x, backend, blocking)
- api : divide
args : (Tensor x, Tensor y)
output : Tensor
......@@ -162,6 +167,11 @@
kernel :
func : sign
- api : split
args : (Tensor x, ScalarArray num_or_sections, Scalar axis)
output : Tensor[]
invoke : split_impl(x, num_or_sections, axis)
- api : subtract
args : (Tensor x, Tensor y)
output : Tensor
......@@ -177,7 +187,6 @@
func : SumInferMeta
kernel :
func : sum
param : [x, axis, dtype, keep_dim]
data_type : x
- api : zeros_like
......
......@@ -102,6 +102,7 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
......
......@@ -142,6 +142,7 @@ def source_include(header_file_path):
#include "glog/logging.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册