Unverified commit e8d45583, authored by zyfncg and committed by GitHub

[PHI] Support Multi Input and Output for InferShape (#39870)

* add multi input for infer_shape

* support multi output for infershape

* fix split bug

* fix bug of concat

* support vector<MetaTensor*> in infrt

* fix bug
Parent 8c237973
@@ -308,22 +308,25 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
// TODO(chenweihang): support multiple inputs and outputs later
phi::InferMetaContext infer_meta_context;
for (auto& in_name : input_names) {
if (ctx->HasInput(in_name)) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
if (ctx->HasInputs(in_name)) {
auto input_var = ctx->GetInputVarPtrs(in_name);
if (input_var.size() == 1) {
infer_meta_context.EmplaceBackInput(
std::make_shared<CompatMetaTensor>(input_var[0], ctx->IsRuntime()));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs;
inputs.reserve(input_var.size());
for (const auto& in : input_var) {
inputs.push_back(
std::make_shared<CompatMetaTensor>(in, ctx->IsRuntime()));
}
infer_meta_context.EmplaceBackInputs(std::move(inputs));
}
} else {
infer_meta_context.EmplaceBackInput({nullptr});
}
}
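The loop above is the heart of the change: a single variable behind an argument name is wrapped as one `CompatMetaTensor`, while several variables are collected into a `SmallVector` and registered together via `EmplaceBackInputs`. Below is a minimal, self-contained sketch of that dispatch; `MetaTensorStub` and `ContextStub` are illustrative stand-ins, not Paddle types, and the `[start, end)` range bookkeeping mirrors what `InputRangeAt`/`OutputRangeAt` expose later in this diff.

```cpp
#include <memory>
#include <utility>
#include <vector>

struct MetaTensorStub {
  explicit MetaTensorStub(int i) : id(i) {}
  int id;
};

struct ContextStub {
  std::vector<std::shared_ptr<MetaTensorStub>> inputs;
  // Each argument name covers a contiguous [start, end) range of tensors,
  // so "one name -> many tensors" groupings can be recovered later.
  std::vector<std::pair<size_t, size_t>> input_ranges;

  void EmplaceBackInput(std::shared_ptr<MetaTensorStub> t) {
    input_ranges.emplace_back(inputs.size(), inputs.size() + 1);
    inputs.push_back(std::move(t));
  }
  void EmplaceBackInputs(std::vector<std::shared_ptr<MetaTensorStub>> ts) {
    input_ranges.emplace_back(inputs.size(), inputs.size() + ts.size());
    for (auto& t : ts) inputs.push_back(std::move(t));
  }
};

int main() {
  ContextStub ctx;
  std::vector<int> vars = {1, 2, 3};  // stand-in for a GetInputVarPtrs result
  if (vars.size() == 1) {
    ctx.EmplaceBackInput(std::make_shared<MetaTensorStub>(vars[0]));
  } else {
    std::vector<std::shared_ptr<MetaTensorStub>> ins;
    ins.reserve(vars.size());
    for (int v : vars) ins.push_back(std::make_shared<MetaTensorStub>(v));
    ctx.EmplaceBackInputs(std::move(ins));
  }
  return ctx.input_ranges.size() == 1 ? 0 : 1;  // one range covering all three
}
```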
for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
for (size_t i = 0; i < attr_names.size(); ++i) {
auto attr_name = attr_names[i];
@@ -348,13 +351,13 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
} else {
// If not in runtime, we will set a default value (-1) for the ScalarArray
int64_t num_ele = 0;
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
for (size_t i = 0; i < infershape_inputs.size(); ++i) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}
int64_t num_ele = 0;
if (vars.size() == 1) {
num_ele = 1;
const auto& tensor_dims = vars[0]->GetShape();
@@ -362,16 +365,7 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
num_ele *= tensor_dims[i];
}
} else {
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
PADDLE_ENFORCE_EQ(tensor_dims.size(), 1,
platform::errors::InvalidArgument(
"The shape is constructed by multi-tensor, "
"every tensor's dims should be 1. But your "
"shape has tensor that dims is %s.",
tensor_dims.size()));
num_ele += tensor_dims[0];
}
num_ele = vars.size();
}
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
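At compile time the actual shape values are unavailable, so the branch above emits a `ScalarArray` filled with `-1`: one entry per element when the shape comes from a single 1-D tensor, or one entry per tensor when it comes from a list of scalar tensors. A minimal sketch of that rule; the function name is illustrative, not Paddle's:

```cpp
#include <cstdint>
#include <vector>

std::vector<int32_t> PlaceholderShape(size_t num_shape_vars,
                                      int64_t single_var_numel) {
  // One 1-D tensor: its element count is the target rank.
  // Several scalar tensors: each contributes exactly one dimension.
  size_t rank = (num_shape_vars == 1) ? static_cast<size_t>(single_var_numel)
                                      : num_shape_vars;
  return std::vector<int32_t>(rank, -1);  // every dim unknown until runtime
}
```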
@@ -383,10 +377,14 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(
phi::ScalarArray({BOOST_GET_CONST(int, attr)}));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
"construct KernelContext.",
"construct InferMetaContext.",
attr_name));
}
}
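The attribute branch above dispatches on `std::type_index` so that both a `std::vector<int32_t>` attribute and a plain `int` attribute can feed a `ScalarArray`. A self-contained sketch of the same pattern, with `boost::any` swapped for `std::any` purely for portability (an assumption, not the diff's actual type):

```cpp
#include <any>
#include <cstdint>
#include <stdexcept>
#include <typeindex>
#include <vector>

std::vector<int32_t> ToScalarArrayLike(const std::any& attr) {
  if (std::type_index(attr.type()) ==
      std::type_index(typeid(std::vector<int32_t>))) {
    return std::any_cast<std::vector<int32_t>>(attr);
  } else if (std::type_index(attr.type()) == std::type_index(typeid(int))) {
    return {std::any_cast<int>(attr)};  // promote a lone int to a 1-element array
  }
  throw std::runtime_error("unsupported attribute type for ScalarArray");
}

int main() {
  return ToScalarArrayLike(std::any(7)).size() == 1 ? 0 : 1;
}
```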
@@ -414,7 +412,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
} else if (ctx->HasInput(attr_name)) {
const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
@@ -490,6 +487,28 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else {
// do nothing
}
}
for (auto& out_name : output_names) {
if (ctx->HasOutputs(out_name)) {
auto output_var = ctx->GetOutputVarPtrs(out_name);
if (output_var.size() == 1) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
output_var[0], ctx->IsRuntime()));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
outputs.reserve(output_var.size());
for (const auto& out : output_var) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
}
infer_meta_context.EmplaceBackOutputs(std::move(outputs));
}
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
......
@@ -18,7 +18,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
#ifdef PADDLE_WITH_MKLDNN
@@ -33,41 +35,6 @@ class ConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "Concat");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Concat");
auto inputs_dims = ctx->GetInputsDim("X");
const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(
inputs_num, static_cast<size_t>(0),
platform::errors::InvalidArgument(
"The number of input tensors in concat op should > 0. But "
"received inputs' length is 0."));
if (inputs_num == 1) {
VLOG(3) << "Warning: concat op has only one input, which may waste memory";
}
if (ctx->HasInput("AxisTensor")) {
auto out_dims =
phi::make_ddim(std::vector<int>(inputs_dims[0].size(), -1));
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
size_t axis =
ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
static_cast<int64_t>(inputs_dims[0].size()));
framework::DDim out_dims =
phi::funcs::ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims, axis);
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -237,9 +204,14 @@ class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(concat, ConcatInferShapeFunctor,
PT_INFER_META(phi::ConcatInferMeta));
REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
ops::ConcatGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatGradOpMaker<paddle::imperative::OpBase>);
ops::ConcatGradOpMaker<paddle::imperative::OpBase>,
ConcatInferShapeFunctor);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
......
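With the hand-written `ConcatOp::InferShape` deleted above, shape inference is now supplied by `DELCARE_INFER_SHAPE_FUNCTOR`, which adapts `phi::ConcatInferMeta` to the framework's InferShape interface and is passed to `REGISTER_OPERATOR` like any other component. The toy below shows only the adapter idea, not the real macro expansion; every name in it is illustrative:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct ToyContext {
  std::vector<std::vector<int64_t>> in_dims;
  std::vector<std::vector<int64_t>> out_dims;
};

// A "phi-style" infer-meta function: pure shape logic, framework-agnostic.
void ToyConcatInferMeta(const std::vector<std::vector<int64_t>>& xs,
                        std::vector<int64_t>* out) {
  *out = xs.front();
  for (size_t i = 1; i < xs.size(); ++i) (*out)[0] += xs[i][0];  // concat on axis 0
}

// The "functor": adapts the typed function to a uniform context interface,
// which is what lets the framework call it like a regular InferShape.
using InferShapeFn = std::function<void(ToyContext*)>;

InferShapeFn MakeConcatFunctor() {
  return [](ToyContext* ctx) {
    ctx->out_dims.resize(1);
    ToyConcatInferMeta(ctx->in_dims, &ctx->out_dims[0]);
  };
}

int main() {
  ToyContext ctx{{{2, 3}, {4, 3}}, {}};
  MakeConcatFunctor()(&ctx);
  std::cout << ctx.out_dims[0][0] << "x" << ctx.out_dims[0][1] << "\n";  // 6x3
}
```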
@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/split_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using framework::Tensor;
@@ -23,52 +26,6 @@ class SplitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of SplitOp should not be null."));
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of SplitOp should not be empty."));
auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out");
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
std::vector<int> sections = static_cast<std::vector<int>>(
ctx->Attrs().Get<std::vector<int>>("sections"));
const size_t outs_number = outs_names.size();
if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(
sections.size(), outs_number,
platform::errors::InvalidArgument("tensor split sections size "
"should be equal to output size."));
}
if (ctx->HasInput("AxisTensor")) {
auto out_dims = phi::make_ddim(std::vector<int>(in_dims.size(), -1));
std::vector<framework::DDim> outs_dims(outs_number, out_dims);
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
return;
}
bool each_section_is_known =
(sections.size() > 0 && !ctx->HasInputs("SectionsTensorList"));
auto outs_dims = UpdateOutsDims(ctx->IsRuntime(), each_section_is_known,
in_dims, num, sections, axis, outs_number);
ctx->SetOutputsDim("Out", outs_dims);
if (axis != 0) {
// Only pass LoD when not splitting along the first dim.
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -168,6 +125,10 @@ Example:
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(split, SplitInferShapeFunctor,
PT_INFER_META(phi::SplitInferMeta));
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
ops::SplitGradMaker<paddle::imperative::OpBase>,
SplitInferShapeFunctor);
@@ -73,7 +73,7 @@ using ValueVariantType =
std::vector<phi::DenseTensor>,
paddle::experimental::ScalarBase<phi::DenseTensor>,
paddle::experimental::ScalarArrayBase<phi::DenseTensor>,
std::vector<phi::MetaTensor>,
std::vector<phi::MetaTensor*>,
phi::MetaConfig,
paddle::experimental::Backend,
paddle::experimental::DataLayout,
......
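Since InferMeta functions now traffic in `std::vector<phi::MetaTensor*>`, the infrt value variant must be able to hold that type, hence the member swap above. A minimal illustration of why the variant member matters:

```cpp
#include <variant>
#include <vector>

struct MetaTensor {};  // stand-in for phi::MetaTensor
using ValueVariant = std::variant<int, std::vector<MetaTensor*>>;

int main() {
  std::vector<MetaTensor*> outs;  // what multi-output InferMeta now passes
  ValueVariant v = outs;          // representable only after this change
  return std::holds_alternative<std::vector<MetaTensor*>>(v) ? 0 : 1;
}
```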
@@ -94,12 +94,16 @@ std::vector<Tensor> split_impl(const Tensor& x,
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}
phi::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);
MakeMetaTensor(*dense_x), num_or_sections, axis, meta_out_ptrs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
......
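Note the `reserve(out_number)` calls above: `meta_out_ptrs` stores addresses obtained via `&meta_outs.back()`, and a `push_back` that reallocated `meta_outs` would leave every earlier pointer dangling. A stripped-down illustration of the safe pattern (types here are illustrative):

```cpp
#include <vector>

struct Meta { int v; };

int main() {
  const size_t n = 4;
  std::vector<Meta> metas;
  metas.reserve(n);  // capacity fixed up front, so no reallocation below...
  std::vector<Meta*> ptrs;
  ptrs.reserve(n);
  for (size_t i = 0; i < n; ++i) {
    metas.push_back(Meta{static_cast<int>(i)});
    ptrs.push_back(&metas.back());  // ...and these pointers stay valid
  }
  return ptrs[0]->v;  // 0
}
```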
@@ -75,13 +75,13 @@ paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt(
: paddle::optional<const phi::MetaTensor&>{paddle::none};
}
std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor> result;
std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.emplace_back(*inputs_.at(i));
result.push_back(inputs_.at(i).get());
}
return result;
@@ -91,12 +91,12 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
}
std::vector<MetaTensor> InferMetaContext::MutableOutputBetween(size_t start,
size_t end) {
std::vector<MetaTensor> result;
std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
size_t end) {
std::vector<MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.emplace_back(*outputs_.at(i));
result.emplace_back(outputs_.at(i).get());
}
return result;
}
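Returning `MetaTensor*` instead of `MetaTensor` values is more than a style change: the fluid adapter in this diff (`CompatMetaTensor`) derives from `phi::MetaTensor` and overrides its virtuals, so copying elements out of `inputs_`/`outputs_` would slice that behavior off, while handing out `.get()` pointers keeps dynamic dispatch intact. A sketch with illustrative types only:

```cpp
#include <iostream>
#include <memory>
#include <vector>

struct MetaTensor {
  virtual ~MetaTensor() = default;
  virtual const char* kind() const { return "phi"; }
};

struct CompatMetaTensor : MetaTensor {  // stand-in for the fluid adapter
  const char* kind() const override { return "compat"; }
};

int main() {
  std::vector<std::shared_ptr<MetaTensor>> outputs;
  outputs.push_back(std::make_shared<CompatMetaTensor>());

  MetaTensor sliced = *outputs[0];     // copy: the dynamic type is lost
  MetaTensor* ptr = outputs[0].get();  // pointer: dispatch preserved
  std::cout << sliced.kind() << " vs " << ptr->kind() << "\n";  // phi vs compat
}
```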
......
@@ -50,13 +50,13 @@ class InferMetaContext {
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;
std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor> MutableOutputBetween(size_t start, size_t end);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);
template <typename AttrType>
AttrType AttrAt(size_t idx) {
@@ -157,7 +157,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
struct InferMetaFnCallHelper<const std::vector<MetaTensor*>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
@@ -165,7 +165,7 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor> arg =
std::vector<MetaTensor*> arg =
ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
@@ -210,13 +210,12 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};
template <typename... Tail>
struct InferMetaFnCallHelper<std::vector<MetaTensor>*, Tail...> {
struct InferMetaFnCallHelper<std::vector<MetaTensor*>, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
std::vector<MetaTensor> tmp =
std::vector<MetaTensor*> arg =
ctx->MutableOutputBetween(range.first, range.second);
std::vector<MetaTensor>* arg = &tmp;
InferMetaFnCallHelper<
Tail...>::template Call<in_idx, attr_idx, out_idx + 1>(ctx,
pargs...,
......
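The two specializations above extend the recursive `InferMetaFnCallHelper`: each one peels a parameter type off the InferMeta signature, fetches the matching tensor range from the context, and forwards everything else. A heavily stripped-down version of that recursion, with the compile-time index bookkeeping omitted; all names are illustrative, not Paddle's templates:

```cpp
#include <iostream>
#include <utility>
#include <vector>

struct MetaTensor { int dims; };
struct Ctx {
  std::vector<MetaTensor*> ins;
  std::vector<MetaTensor*> outs;
};

// Toy "infer meta" target: each output records how many inputs it saw.
void ToyInferMeta(const std::vector<MetaTensor*>& ins,
                  std::vector<MetaTensor*> outs) {
  for (auto* o : outs) o->dims = static_cast<int>(ins.size());
}

template <typename... Args> struct Helper;

// Recursion end: all parameter types consumed, invoke the target function.
template <> struct Helper<> {
  template <typename... Prev>
  static void Call(Ctx*, Prev&&... prev) {
    ToyInferMeta(std::forward<Prev>(prev)...);
  }
};

// Peel a vector-of-input-pointers parameter and pull it from the context.
template <typename... Tail>
struct Helper<const std::vector<MetaTensor*>&, Tail...> {
  template <typename... Prev>
  static void Call(Ctx* ctx, Prev&&... prev) {
    Helper<Tail...>::Call(ctx, std::forward<Prev>(prev)..., ctx->ins);
  }
};

// Peel a vector-of-output-pointers parameter (by value, as in the diff).
template <typename... Tail>
struct Helper<std::vector<MetaTensor*>, Tail...> {
  template <typename... Prev>
  static void Call(Ctx* ctx, Prev&&... prev) {
    Helper<Tail...>::Call(ctx, std::forward<Prev>(prev)..., ctx->outs);
  }
};

int main() {
  MetaTensor a{0}, b{0}, o{0};
  Ctx ctx{{&a, &b}, {&o}};
  Helper<const std::vector<MetaTensor*>&,
         std::vector<MetaTensor*>>::Call(&ctx);
  std::cout << o.dims << "\n";  // prints 2
}
```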
@@ -84,7 +84,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void ConcatInferMeta(const std::vector<MetaTensor>& x,
void ConcatInferMeta(const std::vector<MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
MetaConfig config) {
@@ -93,10 +93,19 @@ void ConcatInferMeta(const std::vector<MetaTensor>& x,
phi::errors::InvalidArgument(
"The size of input meta vector should be greater"
"than 0."));
if (axis_scalar.FromTensor()) {
auto out_dims =
phi::make_ddim(std::vector<int>(x.at(0)->dims().size(), -1));
out->set_dims(out_dims);
out->set_dtype(x.at(0)->dtype());
out->set_layout(x.at(0)->layout());
out->share_lod(*x.at(0));
return;
}
int axis = axis_scalar.to<int>();
// 1. calculate axis
int rank = x.at(0).dims().size();
int rank = x.at(0)->dims().size();
PADDLE_ENFORCE_EQ(
axis >= -rank && axis < rank,
true,
@@ -111,15 +120,17 @@ void ConcatInferMeta(const std::vector<MetaTensor>& x,
// 2. calculate out dims
std::vector<phi::DDim> x_dims;
for (auto& x_t : x) {
x_dims.push_back(x_t.dims());
x_dims.reserve(x.size());
for (const auto* x_t : x) {
x_dims.emplace_back(x_t->dims());
}
phi::DDim out_dim =
phi::funcs::ComputeAndCheckShape(config.is_runtime, x_dims, axis);
out->set_dims(out_dim);
out->set_dtype(x.at(0).dtype());
out->set_layout(x.at(0).layout());
out->set_dtype(x.at(0)->dtype());
out->set_layout(x.at(0)->layout());
out->share_lod(*x.at(0));
}
} // namespace phi
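Both ConcatInferMeta and SplitInferMeta accept a negative axis in `[-rank, rank)` and wrap it into `[0, rank)`, as the enforcement above and the `axis_value = axis_value + rank` line in SplitInferMeta show. The rule in isolation, as a minimal sketch (the function name is illustrative):

```cpp
#include <cassert>

int NormalizeAxis(int axis, int rank) {
  assert(axis >= -rank && axis < rank);  // mirrors the PADDLE_ENFORCE above
  return axis < 0 ? axis + rank : axis;
}

int main() {
  assert(NormalizeAxis(-1, 4) == 3);
  assert(NormalizeAxis(2, 4) == 2);
}
```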
@@ -25,7 +25,7 @@ void BilinearTensorProductInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ConcatInferMeta(const std::vector<MetaTensor>& x,
void ConcatInferMeta(const std::vector<MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
MetaConfig config = MetaConfig());
......
@@ -459,8 +459,19 @@ void TransferLayoutInferMeta(const MetaTensor& x,
void SplitInferMeta(const MetaTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<MetaTensor>* out,
std::vector<MetaTensor*> out,
MetaConfig config) {
if (!config.is_runtime) {
if (axis.FromTensor() || num_or_sections.FromTensor()) {
auto out_dims = phi::make_ddim(std::vector<int>(x.dims().size(), -1));
for (auto* item : out) {
item->set_dims(out_dims);
item->share_lod(x);
}
return;
}
}
int axis_value = axis.to<int>();
int rank = x.dims().size();
PADDLE_ENFORCE_EQ(
@@ -475,27 +486,34 @@ void SplitInferMeta(const MetaTensor& x,
axis_value = axis_value + rank;
}
std::vector<phi::DDim> out_dims(out.size(), x.dims());
auto input_axis_dim = x.dims().at(axis_value);
auto num_or_sections_data = num_or_sections.GetData();
// step 1: get formatted sections
std::vector<int64_t> sections;
// num_or_sections is a number
if (num_or_sections_data.size() == 1) {
int num = num_or_sections_data.at(0);
if (config.is_runtime || input_axis_dim > 0) {
int num = num_or_sections_data.at(0);
PADDLE_ENFORCE_EQ(
input_axis_dim % num,
0,
phi::errors::InvalidArgument(
"The input's size along the split dimension "
"must be evenly divisible by Attr(num_or_sections). "
"But received Attr(num_or_sections) "
"= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
num,
x.dims(),
axis_value));
PADDLE_ENFORCE_EQ(input_axis_dim % num,
0,
phi::errors::InvalidArgument(
"The input's size along the split dimension "
"must be evenly divisible by Attr(num_or_sections). "
"But received Attr(num_or_sections) "
"= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
num,
x.dims(),
axis_value));
for (int i = 0; i < num; ++i) {
sections.push_back(input_axis_dim / num);
size_t out_axis_dim = input_axis_dim / num;
for (auto& out_dim : out_dims) {
out_dim[axis_value] = out_axis_dim;
}
} else {
for (auto& out_dim : out_dims) {
out_dim[axis_value] = -1;
}
}
} else {
// num_or_sections is a list of sections
@@ -503,10 +521,9 @@ void SplitInferMeta(const MetaTensor& x,
int unknow_dim_idx = -1;
int num_of_unknow = 0;
int sum_of_section = 0;
std::vector<int64_t> sections = num_or_sections_data;
for (size_t i = 0; i < num_or_sections_data.size(); ++i) {
sections.push_back(num_or_sections_data[i]);
if (num_or_sections_data[i] == unknow_dim_val) {
num_of_unknow++;
unknow_dim_idx = i;
@@ -558,31 +575,22 @@ void SplitInferMeta(const MetaTensor& x,
x.dims(),
axis_value));
}
}
// step 2: fill out dims
std::vector<phi::DDim> out_dims(sections.size(), x.dims());
if (config.is_runtime || input_axis_dim > 0) {
for (size_t i = 0; i < sections.size(); ++i) {
for (size_t i = 0; i < out_dims.size(); ++i) {
out_dims[i][axis_value] = sections[i];
}
} else {
for (size_t i = 0; i < sections.size(); ++i) {
out_dims[i][axis_value] = -1;
}
}
for (size_t i = 0; i < sections.size(); ++i) {
for (size_t i = 0; i < out.size(); ++i) {
if (axis_value != 0) {
// Only pass LoD when not splitting along the first dim.
(*out)[i].set_dtype(x.dtype());
(*out)[i].set_dims(out_dims[i]);
(*out)[i].set_layout(x.layout());
out.at(i)->set_dtype(x.dtype());
out.at(i)->set_dims(out_dims[i]);
out.at(i)->set_layout(x.layout());
} else {
(*out)[i].set_dtype(x.dtype());
(*out)[i].set_dims(out_dims[i]);
(*out)[i].set_layout(x.layout());
(*out)[i].share_lod(x);
out.at(i)->set_dtype(x.dtype());
out.at(i)->set_dims(out_dims[i]);
out.at(i)->set_layout(x.layout());
out.at(i)->share_lod(x);
}
}
}
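When num_or_sections is a list, at most one entry may be -1 (the `unknow_dim_val` above); it absorbs whatever remains of the input's extent along the split axis after the known sections are subtracted. A compact sketch of that resolution; `ResolveSections` is an illustrative name, not the diff's function:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ResolveSections(std::vector<int64_t> sections,
                                     int64_t axis_dim) {
  int unknown_idx = -1;
  int64_t known_sum = 0;
  for (size_t i = 0; i < sections.size(); ++i) {
    if (sections[i] == -1) {
      assert(unknown_idx == -1 && "at most one section may be -1");
      unknown_idx = static_cast<int>(i);
    } else {
      known_sum += sections[i];
    }
  }
  if (unknown_idx >= 0) sections[unknown_idx] = axis_dim - known_sum;
  return sections;
}

int main() {
  auto s = ResolveSections({2, -1, 3}, 9);
  assert(s[1] == 4);  // the -1 section absorbs the remaining extent
}
```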
......
@@ -107,7 +107,7 @@ void TransferLayoutInferMeta(const MetaTensor& x,
void SplitInferMeta(const MetaTensor& x_meta,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<MetaTensor>* out,
std::vector<MetaTensor*> out,
MetaConfig config = MetaConfig());
void UnbindInferMeta(const MetaTensor& x,
......
@@ -31,13 +31,16 @@ DenseTensor Concat(const Context& dev_ctx,
const std::vector<DenseTensor>& x,
const Scalar& axis) {
std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<MetaTensor*> meta_x_ptr;
for (const auto& t : x) {
meta_x.emplace_back(t);
meta_x_ptr.push_back(&meta_x.back());
}
auto dense_out = phi::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out);
ConcatInferMeta(meta_x, axis.to<int>(), &meta_out, /*is_runtime=*/true);
ConcatInferMeta(meta_x_ptr, axis.to<int>(), &meta_out, /*is_runtime=*/true);
ConcatKernel<T, Context>(dev_ctx, x, axis, &dense_out);
return dense_out;
}
......
@@ -37,6 +37,7 @@ void ConcatKernel(const Context& dev_ctx,
axis = phi::funcs::ComputeAxis(axis, x[0].dims().size());
std::vector<phi::DDim> x_dims;
x_dims.reserve(x.size());
for (size_t i = 0; i < x.size(); ++i) {
x_dims.push_back(x[i].dims());
}
@@ -97,9 +98,10 @@ void ConcatKernel(const Context& dev_ctx,
}
} else {
std::vector<phi::DenseTensor> inputs;
inputs.reserve(x.size());
for (size_t j = 0; j < x.size(); ++j) {
if (x[j].numel() > 0) {
inputs.push_back(x[j]);
inputs.emplace_back(x[j]);
} else {
continue;
}
......
@@ -28,20 +28,6 @@ void SplitKernel(const Context& dev_ctx,
const ScalarArray& num_or_sections,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infer the output shape
if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
}
phi::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true);
for (size_t i = 0; i < out_metas.size(); ++i) {
outs[i]->Resize(out_metas[i].dims());
}
}
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.template Alloc<T>(outs[j]);
......
@@ -27,20 +27,6 @@ void SplitKernel(const Context& dev_ctx,
const ScalarArray& num_or_sections,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infer the output shape
if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
}
phi::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true);
for (size_t i = 0; i < out_metas.size(); ++i) {
outs[i]->Resize(out_metas[i].dims());
}
}
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.template Alloc<T>(outs[j]);
......
@@ -43,18 +43,18 @@ std::vector<DenseTensor> Split(const Context& dev_ctx,
}
std::vector<MetaTensor> out_meta;
std::vector<MetaTensor*> out_meta_ptr;
out_meta.reserve(out_number);
out_meta_ptr.reserve(out_number);
std::vector<DenseTensor> result;
result.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
auto dense_out = phi::Empty<T, Context>(dev_ctx);
MetaTensor tmp_meta(&dense_out);
result.push_back(dense_out);
out_meta.push_back(&result.back());
result.emplace_back(phi::Empty<T, Context>(dev_ctx));
out_meta.emplace_back(&result.back());
out_meta_ptr.push_back(&out_meta.back());
}
SplitInferMeta(x, num_or_sections, axis, &out_meta);
SplitInferMeta(x, num_or_sections, axis, out_meta_ptr);
std::vector<DenseTensor*> outs;
outs.reserve(out_meta.size());
......
@@ -451,7 +451,20 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
param_code = ""
for param in infer_meta_params:
if param in input_names:
if param in self.optional_vars:
if self.inputs['input_info'][param] == "const Tensor&":
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
elif self.inputs['input_info'][
param] == "const std::vector<Tensor>&":
meta_tensor_code = meta_tensor_code + f"""
{code_indent} auto {param}_meta_vec = MakeMetaTensor(*{PREFIX_TENSOR_NAME}{param});
{code_indent} std::vector<phi::MetaTensor*> {param}_metas({param}_meta_vec.size());
{code_indent} for (size_t i = 0; i < {param}_meta_vec.size(); ++i) {{
{code_indent} {param}_metas[i] = &{param}_meta_vec[i];
{code_indent} }}
"""
param_code = param_code + param + "_metas, "
elif param in self.optional_vars:
meta_tensor_code = meta_tensor_code + f"""
{code_indent} paddle::optional<const phi::MetaTensor&> {PREFIX_TENSOR_NAME}meta_ref_{param}(paddle::none);
{code_indent} auto {PREFIX_TENSOR_NAME}meta_{param} = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
@@ -461,7 +474,9 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
param_code = param_code + f"{PREFIX_TENSOR_NAME}meta_ref_{param}, "
else:
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
raise ValueError(
f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in kernel_output_names:
meta_tensor_code = meta_tensor_code + code_indent + " phi::MetaTensor " + param.replace(
'kernel_', PREFIX_META_TENSOR_NAME) + "(" + param + ");\n"
......