未验证 提交 dcff7fa8 编写于 作者: Z zyfncg 提交者: GitHub

【Pten】Support data transform in C++ API (#39263)

* add data_transform in pten api

* support GetKernelTypeForVar

* fix compile problem of bfloat16

* change error namespace

* add complex type transform unittest

* fix merge conflict
上级 0dccdee0
......@@ -4,8 +4,10 @@ paddle/fluid/API_DEV.spec
paddle/fluid/API_PR.spec
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/pten/api/*/api.*
paddle/pten/api/*/backward*
paddle/pten/api/include/api.h
paddle/pten/api/lib/api.cc
paddle/pten/api/backward/backward_api.h
paddle/pten/api/lib/backward_api.cc
paddle/pten/include/*
paddle/pten/extension.h
paddle/fluid/eager/api/generated/*
......
......@@ -15,6 +15,9 @@ cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS pten_tensor pten_context
cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor)
set(api_gen_utils ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/gen_utils.py)
# forward api file
set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
......@@ -46,7 +49,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file}
COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
DEPENDS ${api_yaml_file} ${api_gen_file}
DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_utils}
VERBATIM)
# generate backward api
......@@ -59,10 +62,11 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_header_file_tmp} ${bw_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_source_file_tmp} ${bw_api_source_file}
COMMENT "copy_if_different ${bw_api_header_file} ${bw_api_source_file}"
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file}
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_utils}
VERBATIM)
cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_function_api)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/kernels/transfer_layout_kernel.h"
#include "paddle/fluid/framework/data_device_transform.h"
namespace paddle {
namespace experimental {
// Returns true when a tensor of dtype `input` has to be cast to `target`.
//
// A cast is requested only if the two dtypes differ and either the flag
// enables dtype transformation or the target is one of the complex types
// (COMPLEX64 / COMPLEX128). Note that the complex-target case is decided here
// without consulting the flag; PrepareData checks
// transform_flag.NeedTransform() before calling this helper.
inline bool NeedTransformDataType(const DataType& input,
                                  const DataType& target,
                                  const TransformFlag& transform_flag) {
  if (input == target) {
    return false;
  }
  if (transform_flag.need_trans_data_type()) {
    return true;
  }
  const bool target_is_complex =
      target == DataType::COMPLEX64 || target == DataType::COMPLEX128;
  return target_is_complex;
}
// Returns true when a tensor living on `input` must be moved to the place
// that corresponds to the `target` backend.
//
// No move is requested when backend transformation is disabled by the flag,
// when the target is the wildcard Backend::ALL_BACKEND, or when the tensor is
// already on the target place.
inline bool NeedTransformPlace(const paddle::platform::Place& input,
                               const Backend& target,
                               const TransformFlag& transform_flag) {
  if (!transform_flag.need_trans_backend()) {
    return false;
  }
  if (target == Backend::ALL_BACKEND) {
    return false;
  }
  return !platform::is_same_place(input, pten::TransToFluidPlace(target));
}
// Returns true when a tensor with layout `input` must be converted to
// `target`.
//
// A conversion is requested only if layout transformation is enabled by the
// flag, neither side is the wildcard DataLayout::ALL_LAYOUT, and the two
// layouts differ.
inline bool NeedTransformLayout(const DataLayout& input,
                                const DataLayout& target,
                                const TransformFlag& transform_flag) {
  if (!transform_flag.need_trans_layout()) {
    return false;
  }
  const bool either_is_wildcard =
      input == DataLayout::ALL_LAYOUT || target == DataLayout::ALL_LAYOUT;
  return !either_is_wildcard && input != target;
}
// Converts `tensor` to the requested `layout` and returns the result.
//
// Only tensors placed on the CPU are handled; any other place raises
// PreconditionNotMet.
// NOTE(review): the error text mentions "CPU to GPU" although the branch is
// taken for every non-CPU tensor -- confirm the intended wording.
inline pten::DenseTensor TransDataLayout(const pten::DenseTensor& tensor,
                                         DataLayout layout) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  VLOG(3) << "DataLayoutTransform src_layout: " << tensor.layout()
          << " dst_layout: " << layout;

  if (!platform::is_cpu_place(tensor.place())) {
    PADDLE_THROW(pten::errors::PreconditionNotMet(
        "Unsupported data layout cast from CPU to GPU."));
  }

  auto* dev_ctx = static_cast<pten::CPUContext*>(pool.Get(tensor.place()));
  return pten::TransferLayout(*dev_ctx, tensor, layout);
}
// Casts `tensor` to `dtype` on the given device context.
//
// The source element type is only known at run time, so it is dispatched here
// to the matching pten::Cast<T> instantiation. Source dtypes without a case
// below raise Unimplemented.
template <typename Context>
pten::DenseTensor CastDateType(const Context& dev_ctx,
                               const pten::DenseTensor& tensor,
                               DataType dtype) {
  switch (tensor.dtype()) {
    // Boolean and integral sources.
    case DataType::BOOL:
      return pten::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return pten::Cast<uint8_t>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return pten::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return pten::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return pten::Cast<int64_t>(dev_ctx, tensor, dtype);
    // Floating-point sources.
    case DataType::BFLOAT16:
      return pten::Cast<pten::dtype::bfloat16>(dev_ctx, tensor, dtype);
    case DataType::FLOAT16:
      return pten::Cast<pten::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::FLOAT32:
      return pten::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return pten::Cast<double>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(pten::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// GPU overload of CastDateType, compiled only for CUDA / HIP builds.
//
// Dispatches on the run-time source dtype to the matching pten::Cast<T>.
// Unlike the generic template there is no BFLOAT16 case here, so a bfloat16
// source falls through to the Unimplemented error.
pten::DenseTensor CastDateType(const pten::GPUContext& dev_ctx,
                               const pten::DenseTensor& tensor,
                               DataType dtype) {
  switch (tensor.dtype()) {
    // Boolean and integral sources.
    case DataType::BOOL:
      return pten::Cast<bool>(dev_ctx, tensor, dtype);
    case DataType::UINT8:
      return pten::Cast<uint8_t>(dev_ctx, tensor, dtype);
    case DataType::INT16:
      return pten::Cast<int16_t>(dev_ctx, tensor, dtype);
    case DataType::INT32:
      return pten::Cast<int32_t>(dev_ctx, tensor, dtype);
    case DataType::INT64:
      return pten::Cast<int64_t>(dev_ctx, tensor, dtype);
    // Floating-point sources.
    case DataType::FLOAT16:
      return pten::Cast<pten::dtype::float16>(dev_ctx, tensor, dtype);
    case DataType::FLOAT32:
      return pten::Cast<float>(dev_ctx, tensor, dtype);
    case DataType::FLOAT64:
      return pten::Cast<double>(dev_ctx, tensor, dtype);
    default:
      PADDLE_THROW(pten::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          tensor.dtype()));
  }
}
#endif
// Casts `tensor` to `dtype` using the device context that matches the
// tensor's place, and returns the converted tensor.
//
// CPU tensors are always supported; GPU tensors are supported only in CUDA /
// HIP builds. Any other place raises Unimplemented.
//
// The previous version also built a placeholder output tensor that was never
// used and ended with an unreachable `return`; both are removed because every
// path below either returns the cast result or throws.
inline pten::DenseTensor TransDataType(const pten::DenseTensor& tensor,
                                       DataType dtype) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  VLOG(3) << "DataTypeTransform src_dtype: " << tensor.dtype()
          << " dst_dtype: " << dtype;

  if (platform::is_cpu_place(tensor.place())) {
    auto* dev_ctx = static_cast<pten::CPUContext*>(pool.Get(tensor.place()));
    return CastDateType(*dev_ctx, tensor, dtype);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (platform::is_gpu_place(tensor.place())) {
    auto* dev_ctx = static_cast<pten::GPUContext*>(pool.Get(tensor.place()));
    return CastDateType(*dev_ctx, tensor, dtype);
#endif
  } else {
    PADDLE_THROW(pten::errors::Unimplemented(
        "Place type is not supported when casting data type."));
  }
}
// Applies, in order, the layout, dtype and place transformations that are
// needed to make `tensor` match `target_args_def`, and returns the result.
//
// The layout and dtype decisions are based on the metadata of the original
// `tensor`; the place decision is based on the intermediate result, after the
// earlier steps have run. When nothing needs to change the returned tensor is
// a copy of `tensor`.
pten::DenseTensor TransformData(const pten::DenseTensor& tensor,
                                const pten::TensorArgDef& target_args_def,
                                const TransformFlag& transform_flag) {
  pten::DenseTensor out = tensor;

  // Step 1: layout.
  const bool layout_mismatch = NeedTransformLayout(
      tensor.layout(), target_args_def.layout, transform_flag);
  if (layout_mismatch) {
    out = TransDataLayout(out, target_args_def.layout);
  }

  // Step 2: dtype.
  const bool dtype_mismatch = NeedTransformDataType(
      tensor.dtype(), target_args_def.dtype, transform_flag);
  if (dtype_mismatch) {
    out = TransDataType(out, target_args_def.dtype);
  }

  // Step 3: place. The copy goes into a fresh tensor created on the target
  // place with the metadata of the intermediate result.
  const bool place_mismatch =
      NeedTransformPlace(out.place(), target_args_def.backend, transform_flag);
  if (place_mismatch) {
    pten::DenseTensor result(
        pten::make_intrusive<paddle::experimental::SharedStorage>(
            pten::TransToFluidPlace(target_args_def.backend)),
        {out.dtype(), out.dims(), out.layout()});
    framework::TransDataDevice(
        out, pten::TransToFluidPlace(target_args_def.backend), &result);
    out = result;
  }

  return out;
}
// Prepares one API input for a kernel call.
//
// If transformation is disabled by the flag, the input is not initialized, or
// its place / dtype / layout already satisfy `target_args_def`, the input's
// own implementation is returned (downcast to DenseTensor). Otherwise a new
// DenseTensor holding the transformed data is returned.
//
// NOTE(review): the implementation pointer is dereferenced without a null
// check, and the transform path uses static_cast -- this presumably relies on
// every input being backed by a DenseTensor; confirm against callers.
std::shared_ptr<pten::DenseTensor> PrepareData(
    const Tensor& input,
    const pten::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  const auto& tensor_in = input.impl();

  const bool must_transform =
      transform_flag.NeedTransform() && tensor_in->initialized() &&
      (NeedTransformPlace(
           tensor_in->place(), target_args_def.backend, transform_flag) ||
       NeedTransformDataType(
           tensor_in->dtype(), target_args_def.dtype, transform_flag) ||
       NeedTransformLayout(
           tensor_in->layout(), target_args_def.layout, transform_flag));

  if (!must_transform) {
    return std::dynamic_pointer_cast<pten::DenseTensor>(tensor_in);
  }

  const auto* dense_in = static_cast<pten::DenseTensor*>(tensor_in.get());
  pten::DenseTensor out =
      TransformData(*dense_in, target_args_def, transform_flag);
  return std::make_shared<pten::DenseTensor>(out);
}
// Prepares a list of API inputs for a kernel call.
//
// Each element is handled like the single-tensor overload: it is copied as-is
// when transformation is disabled, the element is not initialized, or its
// place / dtype / layout already satisfy `target_args_def`; otherwise the
// transformed tensor is stored. The results are returned in input order.
//
// NOTE(review): each implementation pointer is dereferenced without a null
// check and the pass-through path dereferences the result of a dynamic cast
// -- this presumably relies on every input being backed by a DenseTensor;
// confirm against callers.
std::unique_ptr<std::vector<pten::DenseTensor>> PrepareData(
    const std::vector<Tensor>& inputs,
    const pten::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  auto pt_tensors = std::make_unique<std::vector<pten::DenseTensor>>();
  pt_tensors->reserve(inputs.size());

  for (const auto& input : inputs) {
    const auto& tensor_in = input.impl();
    if (!transform_flag.NeedTransform() || !tensor_in->initialized() ||
        (!NeedTransformPlace(
             tensor_in->place(), target_args_def.backend, transform_flag) &&
         !NeedTransformDataType(
             tensor_in->dtype(), target_args_def.dtype, transform_flag) &&
         !NeedTransformLayout(
             tensor_in->layout(), target_args_def.layout, transform_flag))) {
      pt_tensors->emplace_back(
          *std::dynamic_pointer_cast<pten::DenseTensor>(tensor_in));
    } else {
      pt_tensors->emplace_back(
          TransformData(*(static_cast<pten::DenseTensor*>(tensor_in.get())),
                        target_args_def,
                        transform_flag));
    }
  }

  // Return the local by name: it is moved (or elided) automatically, whereas
  // an explicit std::move here would only inhibit copy elision.
  return pt_tensors;
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/core/kernel_factory.h"
namespace paddle {
namespace experimental {
// Switches that control which input transformations the API layer may apply
// before a kernel is invoked.
//
// The constructor is intentionally not `explicit`: the generated API code
// passes brace lists such as `{}`, `{true}` or `{false, true}` where a
// `const TransformFlag&` is expected.
class TransformFlag {
 public:
  TransformFlag(bool stop_transform = false,
                bool trans_dtype = false,
                bool trans_backend = true,
                bool trans_layout = true)
      : stop_transform_(stop_transform),
        trans_data_type_(trans_dtype),
        trans_backend_(trans_backend),
        trans_layout_(trans_layout) {}

  // True when at least one kind of transformation is enabled and the global
  // stop switch is off.
  bool NeedTransform() const {
    if (stop_transform_) {
      return false;
    }
    return trans_data_type_ || trans_backend_ || trans_layout_;
  }

  // Each accessor reports its own switch, always overridden by the stop
  // switch.
  bool need_trans_data_type() const {
    return stop_transform_ ? false : trans_data_type_;
  }

  bool need_trans_backend() const {
    return stop_transform_ ? false : trans_backend_;
  }

  bool need_trans_layout() const {
    return stop_transform_ ? false : trans_layout_;
  }

 private:
  // Highest-priority flag; it can be set through
  // api[data_transform->skip_transform] in the yaml file.
  bool stop_transform_ = false;

  // Can be set through api[data_transform->support_trans_dtype] in the yaml
  // file. It only affects non-complex types; complex types are always
  // transferred unless stop_transform_ is true.
  bool trans_data_type_ = false;

  // Both are true by default and can only be set by a global flag.
  bool trans_backend_ = true;
  bool trans_layout_ = true;
};
// Prepares a single API input for a kernel call: returns the input's own
// DenseTensor when no transformation is required, otherwise a new DenseTensor
// transformed to match `target_args_def`, as permitted by `transform_flag`.
std::shared_ptr<pten::DenseTensor> PrepareData(
const Tensor& input,
const pten::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
// Vector overload: applies the same per-tensor preparation to every element
// of `inputs` and returns the resulting DenseTensors in input order.
std::unique_ptr<std::vector<pten::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const pten::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
} // namespace experimental
} // namespace paddle
......@@ -70,9 +70,9 @@ const Kernel& KernelFactory::SelectKernelOrThrowError(
auto kernel_iter = iter->second.find(kernel_key);
// TODO(chenweihang): polish refind impl here
if (kernel_iter == iter->second.end() &&
kernel_key.layout() != pten::DataLayout::ANY) {
kernel_key.layout() != pten::DataLayout::ALL_LAYOUT) {
pten::KernelKey any_layout_kernel_key(
kernel_key.backend(), pten::DataLayout::ANY, kernel_key.dtype());
kernel_key.backend(), pten::DataLayout::ALL_LAYOUT, kernel_key.dtype());
kernel_iter = iter->second.find(any_layout_kernel_key);
}
PADDLE_ENFORCE_NE(
......
......@@ -234,6 +234,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
......
......@@ -307,6 +307,14 @@ void ReduceInferMeta(const MetaTensor& x,
ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out);
}
// Infers the output metadata of a layout transfer: `out` takes the dims and
// dtype of `x` unchanged and the requested `layout`.
// NOTE(review): the dims are copied as-is here, not permuted -- presumably the
// kernel resizes the output itself; confirm.
void TransferLayoutInferMeta(const MetaTensor& x,
DataLayout layout,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
out->set_layout(layout);
}
} // namespace pten
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
......@@ -69,4 +69,9 @@ void SumInferMeta(const MetaTensor& x,
DataType dtype,
bool keep_dim,
MetaTensor* out);
void TransferLayoutInferMeta(const MetaTensor& x,
DataLayout layout,
MetaTensor* out);
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/transfer_layout_kernel.h"
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/transpose.h"
namespace pten {
// Returns the axis permutation that converts a 4-D tensor from layout `from`
// to layout `to`.
//
// Only NCHW -> NHWC and NHWC -> NCHW are supported. Identical layouts and any
// other combination raise InvalidArgument.
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
  PADDLE_ENFORCE_NE(
      from,
      to,
      pten::errors::InvalidArgument(
          "Layout transform should transform between different layout."));

  const bool nchw_to_nhwc =
      from == DataLayout::NCHW && to == DataLayout::NHWC;
  if (nchw_to_nhwc) {
    return {0, 2, 3, 1};
  }

  const bool nhwc_to_nchw =
      from == DataLayout::NHWC && to == DataLayout::NCHW;
  if (nhwc_to_nchw) {
    return {0, 3, 1, 2};
  }

  PADDLE_THROW(
      pten::errors::InvalidArgument("Unsupported layout transform."));
}
// Writes into `out` the transposition of `x` described by the permutation
// `axis`, using the rank-4 instantiation of math::Transpose for element type
// T on the given device context.
template <typename T, typename Context>
void CastDataLayout(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
// The transpose functor is instantiated for rank 4 only.
math::Transpose<Context, T, 4> trans4;
trans4(dev_ctx, x, out, axis);
}
// Converts `x` to `dst_layout` and stores the result in `out`.
//
// The axis permutation comes from GetAxis(), which throws for unsupported or
// identical layouts. The output is resized to the permuted dims and then
// filled by CastDataLayout, dispatched on the run-time dtype of `x`.
// NOTE(review): the permutation always has four entries, so `x` is presumably
// expected to be 4-D -- no rank check is performed here; confirm.
template <typename Context>
void TransferLayoutKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          DataLayout dst_layout,
                          DenseTensor* out) {
  auto src_dim = x.dims();
  auto axis = GetAxis(x.layout(), dst_layout);

  // Permute the source dims into the destination order.
  std::vector<int64_t> dst_dim;
  dst_dim.reserve(axis.size());
  for (const int src_index : axis) {
    dst_dim.push_back(src_dim[src_index]);
  }

  out->ResizeAndAllocate(framework::make_ddim(dst_dim));

  PD_VISIT_ALL_TYPES(x.dtype(), "CastDataLayout", ([&] {
                       CastDataLayout<data_t, Context>(dev_ctx, x, axis, out);
                     }));
}
} // namespace pten
// Registers the CPU instantiation of TransferLayoutKernel under the name
// `pten_transfer_layout`, for all layouts and all dtypes. The trailing empty
// braces are the body required by the registration macro.
PT_REGISTER_GENERAL_KERNEL(pten_transfer_layout,
CPU,
ALL_LAYOUT,
pten::TransferLayoutKernel<pten::CPUContext>,
ALL_DTYPE) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
// Kernel entry point: converts `x` to `dst_layout` and writes the result into
// the caller-provided `out`.
template <typename Context>
void TransferLayoutKernel(const Context& dev_ctx,
const DenseTensor& x,
DataLayout dst_layout,
DenseTensor* out);
// Convenience wrapper around TransferLayoutKernel that creates the output
// tensor itself and returns it by value.
//
// The output is created on the place of `dev_ctx` with the dtype and dims of
// `x` and the requested layout; its metadata is then inferred and the kernel
// is run to fill it.
template <typename Context>
DenseTensor TransferLayout(const Context& dev_ctx,
                           const DenseTensor& x,
                           DataLayout dst_layout) {
  pten::DenseTensor result(
      pten::make_intrusive<paddle::experimental::SharedStorage>(
          dev_ctx.GetPlace()),
      {x.dtype(), x.dims(), dst_layout});

  // Metadata inference writes through a MetaTensor view of the result.
  MetaTensor result_meta(&result);
  TransferLayoutInferMeta(x, dst_layout, &result_meta);

  TransferLayoutKernel<Context>(dev_ctx, x, dst_layout, &result);
  return result;
}
} // namespace pten
......@@ -22,4 +22,6 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_data_transform SRCS test_data_transform.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/common/complex.h"
#include "paddle/pten/core/compat/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
namespace paddle {
namespace tests {
// TODO(chenweihang): Remove this test after the API is used in the dygraph
// Multiplies a COMPLEX128 matrix by a FLOAT32 matrix that both live on the
// CPU and checks that the result is a COMPLEX128 tensor with the expected
// values, i.e. that the mixed-dtype call is accepted.
TEST(API, data_transform_same_place) {
// 1. create tensor: x is 3x3 complex128 filled with 1.0, y is 3x3 float32
// filled with 2.0
auto x = paddle::experimental::full({3, 3},
1.0,
experimental::DataType::COMPLEX128,
experimental::Backend::CPU);
auto y = paddle::experimental::full(
{3, 3}, 2.0, experimental::DataType::FLOAT32, experimental::Backend::CPU);
// Expected value of every output element: 3 * (1.0 * 2.0) = 6.0.
std::vector<pten::dtype::complex<double>> sum(9, 6.0);
// 2. test API
auto out = paddle::experimental::matmul(x, y, false, false);
// 3. check result: metadata first, then every element
ASSERT_EQ(out.dims().size(), 2);
ASSERT_EQ(out.dims()[0], 3);
ASSERT_EQ(out.dims()[1], 3);
ASSERT_EQ(out.numel(), 9);
ASSERT_EQ(out.type(), pten::DataType::COMPLEX128);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
// Real and imaginary parts are compared separately.
for (size_t i = 0; i < 9; i++) {
ASSERT_NEAR(sum[i].real,
dense_out->data<pten::dtype::complex<double>>()[i].real,
1e-6f);
ASSERT_NEAR(sum[i].imag,
dense_out->data<pten::dtype::complex<double>>()[i].imag,
1e-6f);
}
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Multiplies a CPU matrix by a GPU matrix (both FLOAT64) and checks that the
// result lives on the GPU with the expected values, i.e. that inputs on
// different places are accepted. Compiled only for CUDA / HIP builds.
TEST(Tensor, data_transform_diff_place) {
// 1. create tensor: x is on the CPU, y is on the GPU
auto x = paddle::experimental::full(
{3, 3}, 1.0, experimental::DataType::FLOAT64, experimental::Backend::CPU);
auto y = paddle::experimental::full(
{3, 3}, 2.0, experimental::DataType::FLOAT64, experimental::Backend::GPU);
// Expected value of every output element: 3 * (1.0 * 2.0) = 6.0.
std::vector<float> sum(9, 6.0);
// 2. test API
auto out = paddle::experimental::matmul(x, y, false, false);
// 3. check result: metadata and place first, then every element
ASSERT_EQ(out.dims().size(), 2);
ASSERT_EQ(out.dims()[0], 3);
ASSERT_EQ(out.dims()[1], 3);
ASSERT_EQ(out.numel(), 9);
ASSERT_EQ(out.dtype(), pten::DataType::FLOAT64);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
ASSERT_EQ(out.impl()->place(),
pten::TransToFluidPlace(experimental::Backend::GPU));
// Copy the GPU result back to the CPU so its values can be read directly.
auto ref_out = experimental::copy_to(out, experimental::Backend::CPU, true);
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(ref_out.impl());
for (size_t i = 0; i < 9; i++) {
ASSERT_NEAR(sum[i], dense_out->data<double>()[i], 1e-6f);
}
}
#endif
} // namespace tests
} // namespace paddle
......@@ -58,6 +58,18 @@ class API:
if 'param' not in self.infer_meta:
self.infer_meta['param'] = None
self.data_transform = {
'skip_transform': [],
'support_trans_dtype': []
}
if 'data_transform' in api_item_yaml:
if 'skip_transform' in api_item_yaml['data_transform']:
self.data_transform['skip_transform'] = api_item_yaml[
'data_transform']['skip_transform']
if 'support_trans_dtype' in api_item_yaml['data_transform']:
self.data_transform['support_trans_dtype'] = api_item_yaml[
'data_transform']['support_trans_dtype']
def gene_api_declaration(self):
return f"""
PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
......@@ -97,7 +109,7 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
if self.is_base_api:
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.out_type_list,
self.kernel['param'])
self.kernel['param'], self.data_transform)
outputs_args, output_names, output_create = self.gene_output(
self.out_type_list)
return f"""
......@@ -143,6 +155,7 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
......
......@@ -50,6 +50,20 @@ class BackwardAPI:
'param']) == 0:
self.infer_meta['param'] = None
self.data_transform = {
'skip_transform': [],
'support_trans_dtype': []
}
if 'data_transform' in backward_item_yaml:
if 'skip_transform' in backward_item_yaml['data_transform']:
self.data_transform['skip_transform'] = backward_item_yaml[
'data_transform']['skip_transform']
if 'support_trans_dtype' in backward_item_yaml[
'data_transform']:
self.data_transform[
'support_trans_dtype'] = backward_item_yaml[
'data_transform']['support_trans_dtype']
def parse_forward_config(self, forward_config):
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
result = re.search(
......@@ -144,7 +158,7 @@ class BackwardAPI:
if self.is_base_api:
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.output_type_list,
self.kernel['param'])
self.kernel['param'], self.data_transform)
outputs_args, output_names, output_create = self.gene_output(
self.output_type_list)
return f"""
......@@ -208,6 +222,7 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
......
......@@ -296,7 +296,7 @@ def gene_infer_meta(input_names, attr_names, output_names, infer_meta) -> str:
"""
def get_kernel_args(inputs, attrs, out_type_list, kernel_param):
def get_kernel_args(inputs, attrs, out_type_list, kernel_param, data_transform):
input_trans_map = {
'const Tensor&': 'const pten::DenseTensor&',
'const Tensor &': 'const pten::DenseTensor&',
......@@ -321,6 +321,22 @@ def get_kernel_args(inputs, attrs, out_type_list, kernel_param):
if kernel_param is None:
kernel_param = input_names + attr_names
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
if input_name in kernel_param:
trans_flag = "{}"
if input_name in data_transform['skip_transform']:
trans_flag = "{true}"
elif input_name in data_transform['support_trans_dtype']:
trans_flag = "{false, true}"
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
else:
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
if param in input_names:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册