Unverified Commit 50d5bf79, authored by Chen Weihang, committed by GitHub

[Phi] Change input vec tensor to pointer type (#40078)

* change input vec tensor to pointer

* update input between

* fix format error

* resolve conflict

* resolve conflict
Parent d2a911b4
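At its core, this change replaces owning value vectors (std::vector<DenseTensor>) with borrowed pointer vectors (std::vector<const DenseTensor*>) for multi-tensor kernel inputs, so the kernel call path no longer has to copy tensors or move them out of the KernelContext. A minimal standalone sketch of the difference (illustrative names only, not Paddle code):

    #include <vector>

    struct DenseTensor { /* shape, dtype, allocation ... */ };

    // Before: the argument vector owns whole tensors, so building it
    // copies (or moves) every input.
    void ConcatValue(const std::vector<DenseTensor>& x) { (void)x; }

    // After: the argument vector merely borrows read-only tensors.
    void ConcatPtr(const std::vector<const DenseTensor*>& x) { (void)x; }

    void Caller(const DenseTensor& a, const DenseTensor& b) {
      ConcatValue({a, b});  // copies a and b into the temporary vector
      ConcatPtr({&a, &b});  // copies two pointers; the tensors stay in place
    }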
@@ -70,7 +70,7 @@ using ValueVariantType =
                       backends::CpuPhiAllocator,
                       backends::CpuPhiContext,
                       ::phi::CPUContext,
-                      std::vector<phi::DenseTensor>,
+                      std::vector<const phi::DenseTensor*>,
                       paddle::experimental::ScalarBase<phi::DenseTensor>,
                       paddle::experimental::ScalarArrayBase<phi::DenseTensor>,
                       std::vector<phi::MetaTensor*>,
@@ -71,11 +71,11 @@ paddle::optional<phi::MetaTensor> MakeMetaTensor(
 }

 std::vector<phi::MetaTensor> MakeMetaTensor(
-    const std::vector<phi::DenseTensor>& tensors) {
+    const std::vector<const phi::DenseTensor*>& tensors) {
   std::vector<phi::MetaTensor> meta_tensors;
   meta_tensors.reserve(tensors.size());
-  for (const auto& t : tensors) {
-    meta_tensors.emplace_back(t);
+  for (const auto* t : tensors) {
+    meta_tensors.emplace_back(*t);
   }
   return meta_tensors;
 }
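With the overload above, infermeta plumbing builds the MetaTensor storage once and then hands InferMeta functions a vector of pointers into it, the same pattern the generated API code later in this diff uses. A hedged helper sketch (the function name is illustrative; only MakeMetaTensor comes from this commit):

    #include <vector>
    #include "paddle/phi/core/meta_tensor.h"

    std::vector<phi::MetaTensor*> BuildMetaPtrs(
        const std::vector<const phi::DenseTensor*>& dense_inputs,
        std::vector<phi::MetaTensor>* storage) {
      *storage = MakeMetaTensor(dense_inputs);  // overload declared above
      std::vector<phi::MetaTensor*> metas(storage->size());
      for (size_t i = 0; i < storage->size(); ++i) {
        metas[i] = &storage->at(i);  // valid while *storage is alive
      }
      return metas;
    }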
@@ -51,7 +51,7 @@ paddle::optional<phi::MetaTensor> MakeMetaTensor(
     const paddle::optional<const phi::DenseTensor&>& tensor);

 std::vector<phi::MetaTensor> MakeMetaTensor(
-    const std::vector<phi::DenseTensor>& tensors);
+    const std::vector<const phi::DenseTensor*>& tensors);

 phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor);
@@ -82,12 +82,11 @@ class KernelContext {
   }

   template <typename TensorType>
-  std::vector<TensorType> MoveInputsBetween(size_t start, size_t end) {
-    std::vector<TensorType> v;
+  std::vector<const TensorType*> InputsBetween(size_t start, size_t end) {
+    std::vector<const TensorType*> v;
     for (size_t i = start; i < end; ++i) {
-      auto t = static_cast<const TensorType*>(inputs_.at(i));
-      v.emplace_back(*t);
-      inputs_[i] = nullptr;
+      auto* t = static_cast<const TensorType*>(inputs_.at(i));
+      v.emplace_back(t);
     }
     return v;
   }
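The rename from MoveInputsBetween to InputsBetween tracks a real semantic change: the old version moved each tensor out and nulled its slot (inputs_[i] = nullptr;), so a range could be consumed only once; the new version only reads the type-erased slots and leaves the context intact. A minimal standalone model of the new behavior (an assumption for illustration; the real class stores const TensorBase* slots plus more state):

    #include <cstddef>
    #include <vector>

    class MiniKernelContext {
     public:
      void EmplaceBackInput(const void* t) { inputs_.push_back(t); }

      // Non-destructive: borrows the inputs in [start, end) as typed pointers.
      template <typename TensorType>
      std::vector<const TensorType*> InputsBetween(size_t start,
                                                   size_t end) const {
        std::vector<const TensorType*> v;
        v.reserve(end - start);
        for (size_t i = start; i < end; ++i) {
          v.emplace_back(static_cast<const TensorType*>(inputs_.at(i)));
        }
        return v;
      }

     private:
      std::vector<const void*> inputs_;  // type-erased input slots
    };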
@@ -87,8 +87,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
-    } else if (arg_type ==
-               std::type_index(typeid(const std::vector<DenseTensor>&))) {
+    } else if (arg_type == std::type_index(typeid(
+                               const std::vector<const DenseTensor*>&))) {
       args_def->AppendInput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
@@ -102,26 +102,26 @@ namespace phi {
     }                                                                       \
   }

 #define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)         \
   template <typename... Tail>                                               \
-  struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> {       \
+  struct KernelCallHelper<const std::vector<const tensor_type*>&, Tail...> { \
     template <int dev_ctx_idx,                                              \
               int in_idx,                                                   \
               int attr_idx,                                                 \
               int out_idx,                                                  \
               typename... PreviousArgs>                                     \
     static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {       \
       static_assert(attr_idx == 0,                                          \
                     "Kernel's Input should appear before Attributes.");     \
       static_assert(out_idx == 0,                                           \
                     "Kernel's Input should appear before Outputs.");        \
       const std::pair<int, int> range = ctx->InputRangeAt(in_idx);          \
-      std::vector<tensor_type> arg = std::move(                             \
-          ctx->MoveInputsBetween<tensor_type>(range.first, range.second));  \
+      std::vector<const tensor_type*> arg = std::move(                      \
+          ctx->InputsBetween<tensor_type>(range.first, range.second));      \
       KernelCallHelper<Tail...>::                                           \
           template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(     \
               ctx, pargs..., arg);                                          \
     }                                                                       \
   }

 #define PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type)             \
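Instantiated for DenseTensor, the rewritten macro expands to roughly the helper below (a manual expansion for readability; the real code stays inside the backslash-continued macro). It looks up which input slots belong to this argument position, borrows them via InputsBetween, and recurses on the remaining parameter types:

    template <typename... Tail>
    struct KernelCallHelper<const std::vector<const DenseTensor*>&, Tail...> {
      template <int dev_ctx_idx, int in_idx, int attr_idx, int out_idx,
                typename... PreviousArgs>
      static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {
        static_assert(attr_idx == 0,
                      "Kernel's Input should appear before Attributes.");
        static_assert(out_idx == 0,
                      "Kernel's Input should appear before Outputs.");
        // Slots [range.first, range.second) of the context feed this argument.
        const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
        // Borrow rather than move: the context still owns the tensors.
        std::vector<const DenseTensor*> arg =
            ctx->InputsBetween<DenseTensor>(range.first, range.second);
        KernelCallHelper<Tail...>::
            template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>(
                ctx, pargs..., arg);
      }
    };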
@@ -21,7 +21,7 @@ namespace phi {

 template <typename T, typename Context>
 void BroadcastTensorsGradKernel(const Context& ctx,
-                                const std::vector<DenseTensor>& dout,
+                                const std::vector<const DenseTensor*>& dout,
                                 std::vector<DenseTensor*> dx);

 }  // namespace phi
@@ -21,7 +21,7 @@ namespace phi {

 template <typename T, typename Context>
 void BroadcastTensorsKernel(const Context& ctx,
-                            const std::vector<DenseTensor>& x,
+                            const std::vector<const DenseTensor*>& x,
                             std::vector<DenseTensor*> out);

 }  // namespace phi
@@ -22,19 +22,19 @@ namespace phi {

 template <typename T, typename Context>
 void ConcatKernel(const Context& dev_ctx,
-                  const std::vector<DenseTensor>& x,
+                  const std::vector<const DenseTensor*>& x,
                   const Scalar& axis,
                   DenseTensor* out);

 template <typename T, typename Context>
 DenseTensor Concat(const Context& dev_ctx,
-                   const std::vector<DenseTensor>& x,
+                   const std::vector<const DenseTensor*>& x,
                    const Scalar& axis) {
   std::vector<MetaTensor> meta_x;
   meta_x.reserve(x.size());
   std::vector<MetaTensor*> meta_x_ptr;
-  for (const auto& t : x) {
-    meta_x.emplace_back(t);
+  for (const auto* t : x) {
+    meta_x.emplace_back(*t);
     meta_x_ptr.push_back(&meta_x.back());
   }
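Call sites change in the obvious way: collect addresses instead of copies. A usage sketch for the Concat helper above (CPUContext and float are illustrative choices; the updated TEST(DEV_API, concat) near the end of this diff builds its input vector the same way):

    #include <vector>
    #include "paddle/phi/backends/cpu/cpu_context.h"
    #include "paddle/phi/kernels/concat_kernel.h"

    phi::DenseTensor ConcatTwo(const phi::CPUContext& dev_ctx,
                               const phi::DenseTensor& a,
                               const phi::DenseTensor& b) {
      // Borrow the inputs; nothing is copied into the argument vector.
      std::vector<const phi::DenseTensor*> xs = {&a, &b};
      return phi::Concat<float>(dev_ctx, xs, /*axis=*/0);
    }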
@@ -59,7 +59,7 @@ namespace phi {

 template <typename T, typename Context>
 void BroadcastTensorsGradKernel(const Context& ctx,
-                                const std::vector<DenseTensor>& dout,
+                                const std::vector<const DenseTensor*>& dout,
                                 std::vector<DenseTensor*> dx) {
   // Find reduce dimensions
   const auto& in_tensors = dout;
@@ -85,7 +85,7 @@ void BroadcastTensorsGradKernel(const Context& ctx,
   // For each In-Out tensor pair,
   // Prepare and apply broadcast dims array
   for (size_t i = 0; i < num_ins; i++) {
-    const auto* input_tensor = &in_tensors[i];
+    const auto* input_tensor = in_tensors[i];
     auto* output_tensor = out_tensors[i];

     const auto& input_dims = input_tensor->dims();
@@ -29,17 +29,17 @@ namespace phi {

 template <typename T, typename Context>
 void ConcatKernel(const Context& dev_ctx,
-                  const std::vector<DenseTensor>& x,
+                  const std::vector<const DenseTensor*>& x,
                   const Scalar& axis_scalar,
                   DenseTensor* out) {
   int64_t axis = axis_scalar.to<int64_t>();
-  axis = phi::funcs::ComputeAxis(axis, x[0].dims().size());
+  axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());

   std::vector<phi::DDim> x_dims;
   x_dims.reserve(x.size());
   for (size_t i = 0; i < x.size(); ++i) {
-    x_dims.push_back(x[i].dims());
+    x_dims.push_back(x[i]->dims());
   }

   phi::DDim out_dims = phi::funcs::ComputeAndCheckShape(true, x_dims, axis);
@@ -47,13 +47,13 @@ void ConcatKernel(const Context& dev_ctx,
   out->mutable_data<T>(dev_ctx.GetPlace());

   // If axis is 0, the lod of the output is not the same as inputs.
-  if (axis == 0 && x[0].lod().size() > 0) {
-    size_t lod_size_0 = x[0].lod().size();
+  if (axis == 0 && x[0]->lod().size() > 0) {
+    size_t lod_size_0 = x[0]->lod().size();
     size_t lod_size = lod_size_0;
     for (size_t i = 1; i < x.size(); ++i) {
-      if (x[i].lod().size() > 0) {
+      if (x[i]->lod().size() > 0) {
         PADDLE_ENFORCE_EQ(
-            x[i].lod().size(),
+            x[i]->lod().size(),
             lod_size_0,
             phi::errors::Unimplemented(
                 "The lod level of all input LoDTensors should be same. "
@@ -61,7 +61,7 @@ void ConcatKernel(const Context& dev_ctx,
                 "it is not supported currently. The lod level of %dth input "
                 "is %d and first input is %d.",
                 i,
-                x[i].lod().size(),
+                x[i]->lod().size(),
                 lod_size_0));
       } else {
         lod_size = 0;
@@ -71,7 +71,7 @@ void ConcatKernel(const Context& dev_ctx,
     if (lod_size) {
       auto* out_lod = out->mutable_lod();
       for (size_t i = 1; i < x.size(); ++i) {
-        auto in_lod = phi::ConvertToLengthBasedLoD(x[i].lod());
+        auto in_lod = phi::ConvertToLengthBasedLoD(x[i]->lod());
         phi::AppendLoD(out_lod, in_lod);
       }
     }
@@ -80,28 +80,29 @@ void ConcatKernel(const Context& dev_ctx,
   // Sometimes direct copies will be faster, this maybe need deeply analysis.
   if (axis == 0 && x.size() < 10) {
     size_t output_offset = 0;
-    for (auto& in : x) {
-      if (in.numel() == 0UL) {
+    for (const auto* in : x) {
+      if (in->numel() == 0UL) {
         continue;
       }
-      auto in_stride = phi::stride_numel(in.dims());
+      auto in_stride = phi::stride_numel(in->dims());
       auto out_stride = phi::stride_numel(out->dims());
       paddle::operators::StridedNumelCopyWithAxis<T>(
           dev_ctx,
           axis,
           out->data<T>() + output_offset,
           out_stride,
-          in.data<T>(),
+          in->data<T>(),
           in_stride,
           in_stride[axis]);
       output_offset += in_stride[axis];
     }
   } else {
+    // TODO(chenweihang): concat functor support vector<DenseTensor*> input
     std::vector<phi::DenseTensor> inputs;
     inputs.reserve(x.size());
     for (size_t j = 0; j < x.size(); ++j) {
-      if (x[j].numel() > 0) {
-        inputs.emplace_back(x[j]);
+      if (x[j]->numel() > 0) {
+        inputs.emplace_back(*x[j]);
       } else {
         continue;
       }
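The TODO marks the one spot where copies remain: the underlying concat functor still takes owned tensors, so this fallback branch dereferences each non-empty input into a temporary value vector first. The bridge, isolated as a sketch (assuming the functor keeps its value-vector signature until the TODO is resolved; the helper name is hypothetical):

    #include <vector>

    std::vector<phi::DenseTensor> ToOwnedNonEmpty(
        const std::vector<const phi::DenseTensor*>& x) {
      std::vector<phi::DenseTensor> inputs;
      inputs.reserve(x.size());
      for (const auto* t : x) {
        if (t->numel() > 0) {
          inputs.emplace_back(*t);  // DenseTensor copy; storage is shared
        }
      }
      return inputs;
    }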
@@ -27,7 +27,7 @@ namespace phi {

 template <typename T, typename Context>
 void BroadcastTensorsGradKernel(const Context& ctx,
-                                const std::vector<DenseTensor>& dout,
+                                const std::vector<const DenseTensor*>& dout,
                                 std::vector<DenseTensor*> dx) {
   // Find reduce dimensions
   const auto& in_tensors = dout;
@@ -54,7 +54,7 @@ void BroadcastTensorsGradKernel(const Context& ctx,
   // For each In-Out tensor pair,
   // Prepare and apply broadcast dims array
   for (size_t i = 0; i < num_ins; i++) {
-    auto* input_tensor = &in_tensors[i];
+    auto* input_tensor = in_tensors[i];
     auto* output_tensor = out_tensors[i];

     const DDim& input_dims = input_tensor->dims();
@@ -29,16 +29,16 @@ namespace phi {

 template <typename T, typename Context>
 void ConcatKernel(const Context& dev_ctx,
-                  const std::vector<DenseTensor>& x,
+                  const std::vector<const DenseTensor*>& x,
                   const Scalar& axis_scalar,
                   DenseTensor* out) {
   int64_t axis = axis_scalar.to<int64_t>();
-  axis = phi::funcs::ComputeAxis(axis, x[0].dims().size());
+  axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());

   std::vector<phi::DDim> x_dims;
   for (size_t i = 0; i < x.size(); ++i) {
-    x_dims.push_back(x[i].dims());
+    x_dims.push_back(x[i]->dims());
   }

   phi::DDim out_dims = phi::funcs::ComputeAndCheckShape(true, x_dims, axis);
@@ -46,13 +46,13 @@ void ConcatKernel(const Context& dev_ctx,
   out->mutable_data<T>(dev_ctx.GetPlace());

   // If axis is 0, the lod of the output is not the same as inputs.
-  if (axis == 0 && x[0].lod().size() > 0) {
-    size_t lod_size_0 = x[0].lod().size();
+  if (axis == 0 && x[0]->lod().size() > 0) {
+    size_t lod_size_0 = x[0]->lod().size();
     size_t lod_size = lod_size_0;
     for (size_t i = 1; i < x.size(); ++i) {
-      if (x[i].lod().size() > 0) {
+      if (x[i]->lod().size() > 0) {
         PADDLE_ENFORCE_EQ(
-            x[i].lod().size(),
+            x[i]->lod().size(),
             lod_size_0,
             phi::errors::Unimplemented(
                 "The lod level of all input LoDTensors should be same. "
@@ -60,7 +60,7 @@ void ConcatKernel(const Context& dev_ctx,
                 "it is not supported currently. The lod level of %dth input "
                 "is %d and first input is %d.",
                 i,
-                x[i].lod().size(),
+                x[i]->lod().size(),
                 lod_size_0));
       } else {
         lod_size = 0;
@@ -70,7 +70,7 @@ void ConcatKernel(const Context& dev_ctx,
     if (lod_size) {
       auto* out_lod = out->mutable_lod();
       for (size_t i = 1; i < x.size(); ++i) {
-        auto in_lod = phi::ConvertToLengthBasedLoD(x[i].lod());
+        auto in_lod = phi::ConvertToLengthBasedLoD(x[i]->lod());
         phi::AppendLoD(out_lod, in_lod);
       }
     }
@@ -79,18 +79,18 @@ void ConcatKernel(const Context& dev_ctx,
   // Sometimes direct copies will be faster, this maybe need deeply analysis.
   if (axis == 0 && x.size() < 10) {
     size_t output_offset = 0;
-    for (auto& in : x) {
-      if (in.numel() == 0UL) {
+    for (auto* in : x) {
+      if (in->numel() == 0UL) {
         continue;
       }
-      auto in_stride = phi::stride_numel(in.dims());
+      auto in_stride = phi::stride_numel(in->dims());
       auto out_stride = phi::stride_numel(out->dims());
       paddle::operators::StridedNumelCopyWithAxis<T>(
           dev_ctx,
           axis,
           out->data<T>() + output_offset,
           out_stride,
-          in.data<T>(),
+          in->data<T>(),
           in_stride,
           in_stride[axis]);
       output_offset += in_stride[axis];
@@ -98,8 +98,8 @@ void ConcatKernel(const Context& dev_ctx,
   } else {
     std::vector<phi::DenseTensor> inputs;
     for (size_t j = 0; j < x.size(); ++j) {
-      if (x[j].numel() > 0) {
-        inputs.push_back(x[j]);
+      if (x[j]->numel() > 0) {
+        inputs.push_back(*x[j]);
       } else {
         continue;
       }
@@ -23,10 +23,10 @@
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 #define SWITCH_OUT_RANK_CASE(n)                                         \
   case n: {                                                             \
-    ApplyBroadcast<T, Context, n>(ctx, &in_tensors[i], out_tensors[i]); \
+    ApplyBroadcast<T, Context, n>(ctx, in_tensors[i], out_tensors[i]);  \
     break;                                                              \
   }

 namespace phi {
@@ -75,7 +75,7 @@ void ApplyBroadcast(const Context& ctx,

 template <typename T, typename Context>
 void BroadcastTensorsKernel(const Context& ctx,
-                            const std::vector<DenseTensor>& x,
+                            const std::vector<const DenseTensor*>& x,
                             std::vector<DenseTensor*> out) {
   const auto& in_tensors = x;
   auto out_tensors = out;
@@ -43,7 +43,7 @@ template <typename T, typename Context>
 void FakeDot(const Context& dev_ctx,
              const phi::DenseTensor& x,
              const phi::DenseTensor& y,
-             const std::vector<phi::DenseTensor>& fake_input_vec,
+             const std::vector<const phi::DenseTensor*>& fake_input_vec,
              bool fake_attr_bool,
              int fake_attr_int,
              float fake_attr_float,
@@ -53,7 +53,7 @@ TEST(DEV_API, concat) {
     }
   }

-  std::vector<phi::DenseTensor> inputs = {dense_x, dense_y};
+  std::vector<const phi::DenseTensor*> inputs = {&dense_x, &dense_y};

   // 2. test API
   phi::CPUContext dev_ctx;
@@ -458,7 +458,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
             elif self.inputs['input_info'][
                     param] == "const std::vector<Tensor>&":
                 meta_tensor_code = meta_tensor_code + f"""
-{code_indent}  auto {param}_meta_vec = MakeMetaTensor(*{PREFIX_TENSOR_NAME}{param});
+{code_indent}  auto {param}_meta_vec = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
 {code_indent}  std::vector<phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
 {code_indent}  for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
 {code_indent}    {param}_metas[i] = &{param}_meta_vec[i];
@@ -502,7 +502,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
         input_trans_map = {
             'const Tensor&': 'const phi::DenseTensor&',
             'const std::vector<Tensor>&':
-            'const std::vector<phi::DenseTensor>&',
+            'const std::vector<const phi::DenseTensor*>&',
             'const paddle::optional<Tensor>&':
             'paddle::optional<const phi::DenseTensor&>',
             'const paddle::optional<std::vector<Tensor>>&':
@@ -539,9 +539,22 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
 {code_indent}  }}"""

             else:
-                input_tensor_code = input_tensor_code + f"""
-{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
+                if self.inputs['input_info'][input_name] == "const Tensor&":
+                    input_tensor_code = input_tensor_code + f"""
+{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
+                elif self.inputs['input_info'][
+                        input_name] == "const std::vector<Tensor>&":
+                    input_tensor_code = input_tensor_code + f"""
+{code_indent}  auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});
+{code_indent}  std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
+{code_indent}  for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
+{code_indent}    {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
+{code_indent}  }}"""
+                else:
+                    # do nothing
+                    pass

         else:
             if input_name in self.optional_vars:
                 input_tensor_code = input_tensor_code + f"""
@@ -561,7 +574,14 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
                 if param in self.optional_vars:
                     kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
                 else:
-                    kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
+                    if self.inputs['input_info'][param] == "const Tensor&":
+                        kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
+                    elif self.inputs['input_info'][
+                            input_name] == "const std::vector<Tensor>&":
+                        kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
+                    else:
+                        # do nothing
+                        pass
                 kernel_in_type = input_trans_map[input_infos[param]]
                 kernel_args_type_list.append(kernel_in_type)
             elif param in attr_names:
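Put together, for a const std::vector<Tensor>& input named x, the new codegen branches emit C++ of roughly this shape (a sketch assuming PREFIX_TENSOR_NAME renders as input_, the input sits at kernel slot 0, and trans_flag renders as {}):

    // PrepareData materializes/transforms the dense tensors once ...
    auto input_x_vec = PrepareData(x, kernel.InputAt(0), {});
    // ... and the kernel receives only a borrowed pointer view of them.
    std::vector<const phi::DenseTensor*> input_x(input_x_vec->size());
    for (size_t i = 0; i < input_x.size(); ++i) {
      input_x[i] = &input_x_vec->at(i);
    }
    // Per the updated kernel_args branch, the argument is then passed as
    // `input_x` itself, with no leading '*' dereference.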