未验证 提交 cbce0e60 编写于 作者: L Leo Chen 提交者: GitHub

avoid custom kernel deps on pten_function_api (#39661)

* pten matmul cuda kernel support bf16

* avoid custom kernel deps on pten_function_api

* Revert "pten matmul cuda kernel support bf16"

This reverts commit 5d520845b9a189375677276efb673235ed8e5ee0.

* refine code

* fix compile

* fix test_split_api
上级 f86073c4
......@@ -3,19 +3,13 @@ add_subdirectory(utils)
cc_library(ext_compat_utils SRCS ext_compat_utils.cc DEPS place)
if (WITH_GPU)
nv_library(pten_tensor SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils pten_enforce manual_api pten_function_api)
nv_library(pten_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils pten_enforce manual_api)
elseif (WITH_ROCM)
hip_library(pten_tensor SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils pten_enforce manual_api pten_function_api)
hip_library(pten_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils pten_enforce manual_api)
else()
cc_library(pten_tensor SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils pten_enforce manual_api pten_function_api)
cc_library(pten_tensor_raw SRCS tensor.cc DEPS tensor_base dense_tensor pten_api_utils ext_compat_utils pten_enforce manual_api)
endif()
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS pten_tensor pten_context kernel_factory)
cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor)
set(api_gen_base ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_base.py)
# forward api file
......@@ -81,8 +75,15 @@ add_custom_command(
DEPENDS ${api_yaml_file} ${wrapped_infermeta_gen_file} ${api_gen_base}
VERBATIM)
cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS pten_tensor_raw pten_context kernel_factory)
cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor_raw pten kernel_dispatch pten_data_transform)
cc_library(pten_tensor SRCS tensor_method.cc DEPS pten_tensor_raw pten_function_api)
cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/pten/infermeta/unary.h"
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(split, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, GPU, ALL_LAYOUT);
......
......@@ -58,9 +58,6 @@ limitations under the License. */
namespace paddle {
namespace experimental {
// declare cast api
Tensor cast(const Tensor &x, DataType out_dtype);
/////// Tensor Methods ////////
/* Part 1: Construction and destruction methods */
......@@ -363,9 +360,6 @@ void Tensor::copy_(const Tensor &src, bool blocking) {
src.copy_to(pten::TransToPtenBackend(src.inner_place()), blocking);
set_impl(copy_tensor.impl());
}
Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type);
}
/* Part 6: Status utils methods */
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/api/include/tensor.h"
namespace paddle {
namespace experimental {
// declare cast api
Tensor cast(const Tensor &x, DataType out_dtype);
Tensor Tensor::cast(DataType target_type) const {
return experimental::cast(*this, target_type);
}
} // namespace experimental
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册