未验证 提交 94b31f90 编写于 作者: C chentianyu03 提交者: GitHub

[pten] add optional type for infermeta (#39848)

* modify infershape by args_def

* add optional type for infermeta

* add optional type for infermeta

* add optional type for infermeta

* support scalar type

* change OptionalInputAt function to none template

* support phi::DataType
上级 dd2c997d
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/convert_utils.h"
......@@ -376,47 +377,101 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name));
}
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::Scalar))) {
if (ctx->HasAttr(attr_name)) {
// TODO(chentianyu03): support other attrs later
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(float, attr)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(std::string, attr)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(int, attr)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"InferMetaContext.",
attr_name));
}
} else if (ctx->HasInput(attr_name)) {
const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);
if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePtenScalarFromVar(*var)));
} else {
phi::Scalar tensor_scalar(-1);
tensor_scalar.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid input.size() when cast op attribute `%s` to Scalar, "
"expected 1, but actually is %d .",
attr_name, infershape_input.size()));
}
}
} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) == std::type_index(typeid(int))) {
} else if (attr_defs[i].type_index == std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(float))) {
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (std::type_index(attr.type()) ==
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (std::type_index(attr.type()) ==
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<bool>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr));
} else if (std::type_index(attr.type()) ==
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else if (std::type_index(attr.type()) ==
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
infer_meta_context.EmplaceBackAttr(vector_int64_attr);
} else {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<float>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr));
} else if (std::type_index(attr.type()) ==
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<double>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<double>, attr));
} else if (std::type_index(attr.type()) ==
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<std::string>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
infer_meta_context.EmplaceBackAttr(data_type);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported attribute type is received when call "
......
......@@ -67,6 +67,14 @@ const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx);
}
// Returns the input tensor at `idx` wrapped in an optional. An empty
// (null) input slot is reported as paddle::none so that InferMeta
// functions taking paddle::optional<const MetaTensor&> can detect
// absent optional inputs. Throws std::out_of_range if `idx` is invalid
// (via inputs_.at).
paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt(
    size_t idx) const {
  const auto& in = inputs_.at(idx);
  if (in) {
    return paddle::optional<const phi::MetaTensor&>{
        static_cast<const phi::MetaTensor&>(*in)};
  }
  return paddle::optional<const phi::MetaTensor&>{paddle::none};
}
std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor> result;
......
......@@ -51,6 +51,9 @@ class InferMetaContext {
const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor> MutableOutputBetween(size_t start, size_t end);
......@@ -135,6 +138,24 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};
// Specialization handling an InferMeta argument declared as
// paddle::optional<const MetaTensor&>: consumes one input slot from the
// context (which may be empty) and forwards it, then recurses on the
// remaining argument types with the input index advanced.
template <typename... Tail>
struct InferMetaFnCallHelper<paddle::optional<const MetaTensor&>, Tail...> {
  template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
  static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
    // Inputs must be declared before attributes and outputs in an
    // InferMeta signature; enforce that ordering at compile time.
    static_assert(attr_idx == 0,
                  "InferMeta's Input should appear before Attributes.");
    static_assert(out_idx == 0,
                  "InferMeta's Input should appear before Outputs.");
    const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
    // OptionalInputAt yields paddle::none when the input slot is empty,
    // so `arg` carries the "input not provided" state through to the fn.
    auto arg = ctx->OptionalInputAt(range.first);
    InferMetaFnCallHelper<
        Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
                                                               pargs...,
                                                               arg);
  }
};
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册