未验证 提交 4e5d6743 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Adjust the params of creation kernel for inference (#39573)

* remove manual_api

* change sig map of full and empty

* fix fill_any_like_xpu_op

* fix fill_any_like_xpu_op

* fix problem of fill_any_like_xpu_op

* fix conflict

* polish code
上级 06b177c0
...@@ -43,7 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue( ...@@ -43,7 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
bool is_leaf) { bool is_leaf) {
paddle::experimental::Tensor out = paddle::experimental::full( paddle::experimental::Tensor out = paddle::experimental::full(
paddle::framework::vectorize(ddim), paddle::experimental::Scalar(value), paddle::framework::vectorize(ddim), paddle::experimental::Scalar(value),
dtype, pten::TransToPtenBackend(place), layout); dtype, pten::TransToPtenBackend(place));
auto meta = EagerUtils::autograd_meta(&out); auto meta = EagerUtils::autograd_meta(&out);
if (is_leaf) { if (is_leaf) {
......
...@@ -33,6 +33,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> { ...@@ -33,6 +33,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
float, T>::type>::type; float, T>::type>::type;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
...@@ -65,7 +66,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> { ...@@ -65,7 +66,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
pten::FullLikeKernel<T>( pten::FullLikeKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext< static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx), DeviceContext>::TYPE&>(dev_ctx),
value, out); *x, value, pten::DataType::UNDEFINED, out);
} }
}; };
......
...@@ -31,6 +31,7 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> { ...@@ -31,6 +31,7 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
using XPUInTDType = typename XPUTypeTrait<T>::Type; using XPUInTDType = typename XPUTypeTrait<T>::Type;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
...@@ -63,7 +64,7 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> { ...@@ -63,7 +64,7 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
pten::FullLikeKernel<T>( pten::FullLikeKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext< static_cast<const typename paddle::framework::ConvertToPtenContext<
paddle::platform::XPUDeviceContext>::TYPE&>(dev_ctx), paddle::platform::XPUDeviceContext>::TYPE&>(dev_ctx),
value, out); *x, value, pten::DataType::UNDEFINED, out);
} }
}; };
......
...@@ -28,9 +28,8 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape, ...@@ -28,9 +28,8 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
void CreateInferMeta(const ScalarArray& shape, void CreateInferMeta(const ScalarArray& shape,
DataType dtype, DataType dtype,
DataLayout layout,
MetaTensor* out) { MetaTensor* out) {
CreateInferMetaBase(shape.GetData(), dtype, layout, out); CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
} }
} // namespace pten } // namespace pten
...@@ -33,9 +33,6 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape, ...@@ -33,9 +33,6 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataLayout layout, DataLayout layout,
MetaTensor* out); MetaTensor* out);
void CreateInferMeta(const ScalarArray& shape, void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);
DataType dtype,
DataLayout layout,
MetaTensor* out);
} // namespace pten } // namespace pten
...@@ -79,13 +79,10 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { ...@@ -79,13 +79,10 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
void CreateLikeInferMeta(const MetaTensor& x, void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
DataType dtype,
DataLayout layout,
MetaTensor* out) {
out->set_dims(x.dims()); out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
out->set_layout(layout == DataLayout::UNDEFINED ? x.layout() : layout); out->set_layout(x.layout());
} }
static pten::framework::DDim ValidateShape( static pten::framework::DDim ValidateShape(
......
...@@ -41,10 +41,7 @@ void FlattenInferMeta(const MetaTensor& x, ...@@ -41,10 +41,7 @@ void FlattenInferMeta(const MetaTensor& x,
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CreateLikeInferMeta(const MetaTensor& x, void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
DataType dtype,
DataLayout layout,
MetaTensor* out);
void InferMetaFromVecValue(const MetaTensor& x, void InferMetaFromVecValue(const MetaTensor& x,
const std::vector<int64_t>& shape, const std::vector<int64_t>& shape,
......
...@@ -16,7 +16,62 @@ limitations under the License. */ ...@@ -16,7 +16,62 @@ limitations under the License. */
#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/full_kernel_impl.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
namespace pten {
template <typename T, typename Context, typename VType>
void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
dev_ctx.template Alloc<T>(tensor);
auto t = pten::EigenVector<T>::Flatten(*tensor);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(val));
}
template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
DenseTensor* out) {
out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData()));
FullValue<T>(dev_ctx, out, val.to<T>());
}
template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DataType dtype,
DenseTensor* out) {
auto value = val.to<float>();
using CommonType = typename std::common_type<
float,
typename std::conditional<std::is_same<T, pten::dtype::float16>::value,
float,
T>::type>::type;
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE_EQ(
(common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
pten::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f.",
typeid(T).name(),
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()),
static_cast<float>(value)));
FullValue<T>(dev_ctx, out, value);
}
} // namespace pten
PT_REGISTER_KERNEL(full, PT_REGISTER_KERNEL(full,
CPU, CPU,
......
...@@ -23,12 +23,16 @@ namespace pten { ...@@ -23,12 +23,16 @@ namespace pten {
template <typename T, typename Context> template <typename T, typename Context>
void EmptyKernel(const Context& dev_ctx, void EmptyKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
DataType dtype,
DenseTensor* out) { DenseTensor* out) {
out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData())); out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData()));
} }
template <typename T, typename Context> template <typename T, typename Context>
void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { void EmptyLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
DataType dtype,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
} }
......
...@@ -25,10 +25,14 @@ namespace pten { ...@@ -25,10 +25,14 @@ namespace pten {
template <typename T, typename Context> template <typename T, typename Context>
void EmptyKernel(const Context& dev_ctx, void EmptyKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
DataType dtype,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out); void EmptyLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
DataType dtype,
DenseTensor* out);
// TODO(chenweihang): the tensor creation method need to be replaced later, // TODO(chenweihang): the tensor creation method need to be replaced later,
// all kernel api call Empty here instead of making tensor self // all kernel api call Empty here instead of making tensor self
...@@ -52,27 +56,22 @@ DenseTensor Empty(const Context& dev_ctx) { ...@@ -52,27 +56,22 @@ DenseTensor Empty(const Context& dev_ctx) {
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx, DenseTensor Empty(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
DataType dtype = DataType::FLOAT32, DataType dtype = DataType::FLOAT32) {
Backend backend = Backend::CPU, // Is backend needed here?
DataLayout layout = DataLayout::NCHW) {
auto dense_out = Empty<T, Context>(dev_ctx); auto dense_out = Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
CreateInferMeta(shape, dtype, layout, &meta_out); CreateInferMeta(shape, dtype, &meta_out);
EmptyKernel<T, Context>(dev_ctx, shape, &dense_out); EmptyKernel<T, Context>(dev_ctx, shape, dtype, &dense_out);
return dense_out; return dense_out;
} }
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor EmptyLike( DenseTensor EmptyLike(const Context& dev_ctx,
const Context& dev_ctx, const DenseTensor& x,
const DenseTensor& x, DataType dtype = DataType::UNDEFINED) {
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED, // Is backend needed here?
DataLayout layout = DataLayout::UNDEFINED) {
auto dense_out = Empty<T, Context>(dev_ctx); auto dense_out = Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
CreateLikeInferMeta(x, dtype, layout, &meta_out); CreateLikeInferMeta(x, dtype, &meta_out);
EmptyLikeKernel<T, Context>(dev_ctx, &dense_out); EmptyLikeKernel<T, Context>(dev_ctx, x, dtype, &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -27,39 +27,37 @@ template <typename T, typename Context> ...@@ -27,39 +27,37 @@ template <typename T, typename Context>
void FullKernel(const Context& dev_ctx, void FullKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DataType dtype,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx, void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& val, const Scalar& val,
DataType dtype,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor Full(const Context& dev_ctx, DenseTensor Full(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DataType dtype = DataType::FLOAT32, DataType dtype = DataType::FLOAT32) {
Backend backend = Backend::CPU, // Is backend needed here?
DataLayout layout = DataLayout::NCHW) {
auto dense_out = Empty<T, Context>(dev_ctx); auto dense_out = Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
CreateInferMeta(shape, dtype, layout, &meta_out); CreateInferMeta(shape, dtype, &meta_out);
FullKernel<T, Context>(dev_ctx, shape, val, &dense_out); FullKernel<T, Context>(dev_ctx, shape, val, dtype, &dense_out);
return dense_out; return dense_out;
} }
template <typename T, typename Context> template <typename T, typename Context>
DenseTensor FullLike( DenseTensor FullLike(const Context& dev_ctx,
const Context& dev_ctx, const DenseTensor& x,
const DenseTensor& x, const Scalar& val,
const Scalar& val, DataType dtype = DataType::UNDEFINED) {
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED, // Is backend needed here?
DataLayout layout = DataLayout::UNDEFINED) {
auto dense_out = Empty<T, Context>(dev_ctx); auto dense_out = Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
CreateLikeInferMeta(x, dtype, layout, &meta_out); CreateLikeInferMeta(x, dtype, &meta_out);
FullLikeKernel<T, Context>(dev_ctx, val, &dense_out); FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -33,10 +33,11 @@ struct FullFuctor { ...@@ -33,10 +33,11 @@ struct FullFuctor {
} }
}; };
template <typename T, typename ContextT> template <typename T, typename Context>
void FullKernel(const ContextT& dev_ctx, void FullKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DataType dtype,
DenseTensor* out) { DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData())); out->Resize(paddle::framework::make_ddim(shape.GetData()));
int numel = out->numel(); int numel = out->numel();
...@@ -53,9 +54,11 @@ void FullKernel(const ContextT& dev_ctx, ...@@ -53,9 +54,11 @@ void FullKernel(const ContextT& dev_ctx,
} }
} }
template <typename T, typename ContextT> template <typename T, typename Context>
void FullLikeKernel(const ContextT& dev_ctx, void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& val, const Scalar& val,
DataType dtype,
DenseTensor* out) { DenseTensor* out) {
auto value = val.to<float>(); auto value = val.to<float>();
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
......
...@@ -57,6 +57,7 @@ template <typename T, typename Context> ...@@ -57,6 +57,7 @@ template <typename T, typename Context>
void FullKernel(const Context& dev_ctx, void FullKernel(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DataType dtype,
DenseTensor* out) { DenseTensor* out) {
out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData())); out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData()));
FullValueXPU<T>(dev_ctx, out, val.to<T>()); FullValueXPU<T>(dev_ctx, out, val.to<T>());
...@@ -64,7 +65,9 @@ void FullKernel(const Context& dev_ctx, ...@@ -64,7 +65,9 @@ void FullKernel(const Context& dev_ctx,
template <typename T, typename Context> template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx, void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& val, const Scalar& val,
DataType dtype,
DenseTensor* out) { DenseTensor* out) {
auto value = val.to<float>(); auto value = val.to<float>();
using XPUInTDType = typename XPUTypeTrait<T>::Type; using XPUInTDType = typename XPUTypeTrait<T>::Type;
......
...@@ -18,11 +18,11 @@ namespace pten { ...@@ -18,11 +18,11 @@ namespace pten {
KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("ShapeTensor")) { if (ctx.HasInput("ShapeTensor")) {
return KernelSignature("empty", {}, {"ShapeTensor"}, {"Out"}); return KernelSignature("empty", {}, {"ShapeTensor", "dtype"}, {"Out"});
} else if (ctx.InputSize("ShapeTensorList") > 0) { } else if (ctx.InputSize("ShapeTensorList") > 0) {
return KernelSignature("empty", {}, {"ShapeTensorList"}, {"Out"}); return KernelSignature("empty", {}, {"ShapeTensorList", "dtype"}, {"Out"});
} else { } else {
return KernelSignature("empty", {}, {"shape"}, {"Out"}); return KernelSignature("empty", {}, {"shape", "dtype"}, {"Out"});
} }
} }
......
...@@ -18,7 +18,7 @@ namespace pten { ...@@ -18,7 +18,7 @@ namespace pten {
KernelSignature FillAnyLikeOpArgumentMapping( KernelSignature FillAnyLikeOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("full_like", {}, {"value"}, {"Out"}); return KernelSignature("full_like", {"X"}, {"value", "dtype"}, {"Out"});
} }
} // namespace pten } // namespace pten
......
...@@ -23,42 +23,46 @@ KernelSignature FillConstantOpArgumentMapping( ...@@ -23,42 +23,46 @@ KernelSignature FillConstantOpArgumentMapping(
if (ctx.HasInput("ShapeTensor")) { if (ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ValueTensor")) { if (ctx.HasInput("ValueTensor")) {
return KernelSignature( return KernelSignature(
"full", {}, {"ShapeTensor", "ValueTensor"}, {"Out"}); "full", {}, {"ShapeTensor", "ValueTensor", "dtype"}, {"Out"});
} else { } else {
const auto& str_value = const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value")); paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) { if (str_value.empty()) {
return KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"}); return KernelSignature(
"full", {}, {"ShapeTensor", "value", "dtype"}, {"Out"});
} else { } else {
return KernelSignature( return KernelSignature(
"full", {}, {"ShapeTensor", "str_value"}, {"Out"}); "full", {}, {"ShapeTensor", "str_value", "dtype"}, {"Out"});
} }
} }
} else if (ctx.InputSize("ShapeTensorList") > 0) { } else if (ctx.InputSize("ShapeTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) { if (ctx.HasInput("ValueTensor")) {
return KernelSignature( return KernelSignature(
"full", {}, {"ShapeTensorList", "ValueTensor"}, {"Out"}); "full", {}, {"ShapeTensorList", "ValueTensor", "dtype"}, {"Out"});
} else { } else {
const auto& str_value = const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value")); paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) { if (str_value.empty()) {
return KernelSignature( return KernelSignature(
"full", {}, {"ShapeTensorList", "value"}, {"Out"}); "full", {}, {"ShapeTensorList", "value", "dtype"}, {"Out"});
} else { } else {
return KernelSignature( return KernelSignature(
"full", {}, {"ShapeTensorList", "str_value"}, {"Out"}); "full", {}, {"ShapeTensorList", "str_value", "dtype"}, {"Out"});
} }
} }
} else { } else {
if (ctx.HasInput("ValueTensor")) { if (ctx.HasInput("ValueTensor")) {
return KernelSignature("full", {}, {"shape", "ValueTensor"}, {"Out"}); return KernelSignature(
"full", {}, {"shape", "ValueTensor", "dtype"}, {"Out"});
} else { } else {
const auto& str_value = const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value")); paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) { if (str_value.empty()) {
return KernelSignature("full", {}, {"shape", "value"}, {"Out"}); return KernelSignature(
"full", {}, {"shape", "value", "dtype"}, {"Out"});
} else { } else {
return KernelSignature("full", {}, {"shape", "str_value"}, {"Out"}); return KernelSignature(
"full", {}, {"shape", "str_value", "dtype"}, {"Out"});
} }
} }
} }
......
...@@ -51,30 +51,28 @@ ...@@ -51,30 +51,28 @@
func : dot func : dot
- api : empty - api : empty
args : (ScalarArray shape, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU, DataLayout layout=DataLayout::NCHW) args : (ScalarArray shape, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU)
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateInferMeta func : CreateInferMeta
param : [shape, dtype, layout] param : [shape, dtype]
kernel : kernel :
func : empty func : empty
param : [shape] param : [shape, dtype]
data_type : dtype data_type : dtype
backend : place backend : place
layout : layout
- api : empty_like - api : empty_like
args : (Tensor x, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED, DataLayout layout = DataLayout::UNDEFINED) args : (Tensor x, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED)
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateLikeInferMeta func : CreateLikeInferMeta
param : [x, dtype, layout] param : [x, dtype]
kernel : kernel :
func : empty_like func : empty_like
param : [] param : [x, dtype]
data_type : dtype > x data_type : dtype > x
backend : place > x backend : place > x
layout : layout > x
- api : flatten - api : flatten
args : (Tensor x, int start_axis, int stop_axis) args : (Tensor x, int start_axis, int stop_axis)
...@@ -85,30 +83,28 @@ ...@@ -85,30 +83,28 @@
func : flatten func : flatten
- api : full - api : full
args : (ScalarArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU, DataLayout layout=DataLayout::NCHW) args : (ScalarArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU)
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateInferMeta func : CreateInferMeta
param : [shape, dtype, layout] param : [shape, dtype]
kernel : kernel :
func : full func : full
param : [shape, value] param : [shape, value, dtype]
data_type : dtype data_type : dtype
backend : place backend : place
layout : layout
- api : full_like - api : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED, DataLayout layout = DataLayout::UNDEFINED) args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED)
output: Tensor output: Tensor
infer_meta : infer_meta :
func : CreateLikeInferMeta func : CreateLikeInferMeta
param : [x, dtype, layout] param : [x, dtype]
kernel : kernel :
func : full_like func : full_like
param : [value] param : [x, value, dtype]
data_type : dtype > x data_type : dtype > x
backend : place > x backend : place > x
layout : layout > x
- api : matmul - api : matmul
args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false) args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
...@@ -136,9 +132,9 @@ ...@@ -136,9 +132,9 @@
func : multiply func : multiply
- api : ones_like - api : ones_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED) args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED)
output : Tensor output : Tensor
invoke : full_like(x, 1, dtype, place, layout) invoke : full_like(x, 1, dtype, place)
- api : reshape - api : reshape
args : (Tensor x, ScalarArray shape) args : (Tensor x, ScalarArray shape)
...@@ -185,6 +181,6 @@ ...@@ -185,6 +181,6 @@
data_type : x data_type : x
- api : zeros_like - api : zeros_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED) args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED)
output : Tensor output : Tensor
invoke : full_like(x, 0, dtype, place, layout) invoke : full_like(x, 0, dtype, place)
...@@ -358,8 +358,8 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self. ...@@ -358,8 +358,8 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
""" """
if len(input_names) == 0: if len(input_names) == 0:
assert attr_backend_count > 0 and attr_layout_count > 0 and attr_data_type_count > 0, \ assert attr_backend_count > 0 and attr_data_type_count > 0, \
f"{api} api: When there is no input tensor, the args must have 'Backend', 'DataLayout' and 'DataType'." f"{api} api: When there is no input tensor, the args must have 'Backend' and 'DataType'."
kernel_select_args = "" kernel_select_args = ""
for input_name in input_names: for input_name in input_names:
......
...@@ -29,30 +29,36 @@ def gene_wrapped_infermeta_and_register(api): ...@@ -29,30 +29,36 @@ def gene_wrapped_infermeta_and_register(api):
PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});""" PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});"""
if api.infer_meta['param'] is not None: if api.infer_meta['param'] is not None:
kernel_params = api.kernel['param']
if kernel_params is None:
kernel_params = api.inputs['names'] + api.attrs['names']
if kernel_params == api.infer_meta['param']:
return '', '', register_code
assert len(api.infer_meta['param']) <= len(kernel_params), \
f"{api.api} api: Parameters error. The params of infer_meta should be a subset of kernel params."
tensor_type_map = { tensor_type_map = {
'const Tensor&': 'const MetaTensor&', 'const Tensor&': 'const MetaTensor&',
'const std::vector<Tensor>&': 'const std::vector<MetaTensor>&', 'const std::vector<Tensor>&': 'const std::vector<MetaTensor>&',
'Tensor': 'MetaTensor*', 'Tensor': 'MetaTensor*',
'std::vector<Tensor>': 'std::vector<MetaTensor>*', 'std::vector<Tensor>': 'std::vector<MetaTensor>*',
} }
wrapped_infermeta_name = get_wrapped_infermeta_name(api.api) wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
args = [] args = []
check_args = []
for input_name in api.inputs['names']: for input_name in api.inputs['names']:
args.append(tensor_type_map[api.inputs['input_info'][ if input_name in kernel_params:
input_name]] + ' ' + input_name) args.append(tensor_type_map[api.inputs['input_info'][
check_args.append(input_name) input_name]] + ' ' + input_name)
for attr_name in api.attrs['names']: for attr_name in api.attrs['names']:
args.append(api.attrs['attr_info'][attr_name][0] + ' ' + if attr_name in kernel_params:
attr_name) args.append(api.attrs['attr_info'][attr_name][0] + ' ' +
check_args.append(attr_name) attr_name)
for i, out_type in enumerate(api.outputs['types']): for i, out_type in enumerate(api.outputs['types']):
args.append(tensor_type_map[out_type] + ' ' + api.outputs[ args.append(tensor_type_map[out_type] + ' ' + api.outputs[
'names'][i]) 'names'][i])
if check_args == api.infer_meta['param']:
return '', '', register_code
invoke_param = api.infer_meta['param'] invoke_param = api.infer_meta['param']
invoke_param.extend(api.outputs['names']) invoke_param.extend(api.outputs['names'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册