未验证 提交 abd4ab9c 编写于 作者: Z zyfncg 提交者: GitHub

[PTen] Adjust the param of full_like API in pten (#37088)

* adjust the param of full_like api  in pten

* adjust the code format

* adjust the code format

* adjust the code format
上级 1fe4513c
......@@ -50,8 +50,7 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
// Maps this fluid op onto the pten "fill_any_like" kernel.
// Note: the input "X" is deliberately NOT forwarded as a kernel input —
// per this change, the kernel consumes only the fill value attribute
// ("value") and writes "Out"; shape information is handled separately.
// (The previous revision passed {"X"}; a stale duplicate of that return
// statement had been left above the intended one, making it unreachable.)
framework::KernelSignature GetExpectedPtenKernelArgs(
    const framework::ExecutionContext &ctx) const override {
  return framework::KernelSignature("fill_any_like", {}, {"value"}, {"Out"});
}
};
......
......@@ -34,7 +34,6 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
float, T>::type>::type;
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
......@@ -62,12 +61,11 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
std::isnan(value), false,
platform::errors::InvalidArgument("The filled value is NaN."));
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
const auto& dev_ctx = context.template device_context<DeviceContext>();
// call new kernel
pten::FillAnyLike<T>(dev_ctx, *pt_x, value, pt_out.get());
pten::FillAnyLike<T>(dev_ctx, value, pt_out.get());
}
};
......
......@@ -29,11 +29,19 @@ Tensor full(const std::vector<int64_t>& shape,
Tensor full_like(const Tensor& x,
const Scalar& value,
DataType dtype = DataType::UNDEFINED);
Tensor ones_like(const Tensor& x, DataType dtype = DataType::UNDEFINED);
Tensor zeros_like(const Tensor& x, DataType dtype = DataType::UNDEFINED);
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED,
DataLayout layout = DataLayout::UNDEFINED);
Tensor ones_like(const Tensor& x,
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED,
DataLayout layout = DataLayout::UNDEFINED);
Tensor zeros_like(const Tensor& x,
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED,
DataLayout layout = DataLayout::UNDEFINED);
} // namespace experimental
} // namespace paddle
......@@ -63,15 +63,22 @@ Tensor full(const std::vector<int64_t>& shape,
Tensor full_like(const Tensor& x,
const Scalar& value,
paddle::experimental::DataType dtype) {
DataType dtype,
Backend backend,
DataLayout layout) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
DataType kernel_data_type =
dtype == DataType::UNDEFINED ? kernel_key.dtype() : dtype;
Backend kernel_backend =
backend == Backend::UNDEFINED ? kernel_key.backend() : backend;
DataLayout kernel_layout =
layout == DataLayout::UNDEFINED ? kernel_key.layout() : layout;
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"fill_any_like",
{kernel_key.backend(),
kernel_key.layout(),
dtype == DataType::UNDEFINED ? kernel_key.dtype() : dtype});
"fill_any_like", {kernel_backend, kernel_layout, kernel_data_type});
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
......@@ -79,21 +86,16 @@ Tensor full_like(const Tensor& x,
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
kernel_context.EmplaceBackAttr(value);
// 4. InferShape
auto out_meta = UnchangedInferShape(dense_x->meta());
auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout);
// 5. Prepare outputs
Tensor out;
// InferDataType
if (dtype != pten::DataType::UNDEFINED) {
const_cast<pten::DenseTensorMeta::DataType&>(out_meta.type) = dtype;
}
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
pten::TransToFluidPlace(kernel_backend));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);
......@@ -104,12 +106,18 @@ Tensor full_like(const Tensor& x,
return out;
}
// Returns a tensor with the same shape as `x`, filled with ones.
// UNDEFINED dtype/backend/layout mean "inherit the corresponding property
// from x". (Removed the fused remnant of the old two-argument overload
// that the stripped diff had left in front of this definition.)
Tensor ones_like(const Tensor& x,
                 DataType dtype,
                 Backend backend,
                 DataLayout layout) {
  return full_like(x, 1, dtype, backend, layout);
}
// Returns a tensor with the same shape as `x`, filled with zeros.
// UNDEFINED dtype/backend/layout mean "inherit the corresponding property
// from x". (Removed the fused remnant of the old two-argument overload
// that the stripped diff had left in front of this definition.)
Tensor zeros_like(const Tensor& x,
                  DataType dtype,
                  Backend backend,
                  DataLayout layout) {
  return full_like(x, 0, dtype, backend, layout);
}
} // namespace experimental
......
......@@ -24,15 +24,19 @@ namespace pten {
// TODO(YuanRisheng) This function name should be same as User API name.
// TODO(zyfncg) Automatic code generation
// Allocates and returns a DenseTensor shaped like `x`, filled with `val`.
// dtype/layout override x's meta when specified; UNDEFINED means "inherit
// from x" (resolved by FullLikeInferShape). The actual fill is delegated to
// the device-specific FillAnyLike kernel, which needs only the value and the
// pre-sized output. (Removed the fused remnant of the old 3-parameter
// signature that the stripped diff had left in front of this definition.)
template <typename T, typename ContextT>
DenseTensor FillAnyLike(
    const ContextT& dev_ctx,
    const DenseTensor& x,
    const Scalar& val,
    DataType dtype = DataType::UNDEFINED,
    Backend backend = Backend::UNDEFINED,  // Is backend needed here?
    DataLayout layout = DataLayout::UNDEFINED) {
  auto out_meta = FullLikeInferShape(x.meta(), dtype, layout);
  const auto allocator =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          dev_ctx.GetPlace());
  pten::DenseTensor dense_out(allocator, out_meta);
  // Device kernel signature no longer takes x: shape/dtype live in out_meta.
  FillAnyLike<T>(dev_ctx, val, &dense_out);
  return dense_out;
}
......
......@@ -74,4 +74,12 @@ DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta,
return return_meta;
}
// Builds the output meta for full_like-style kernels: the dims always come
// from x_meta, while an UNDEFINED dtype/layout falls back to x_meta's own
// type/layout.
DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
                                   DataType dtype,
                                   DataLayout layout) {
  const DataType out_type =
      (dtype == DataType::UNDEFINED) ? x_meta.type : dtype;
  const DataLayout out_layout =
      (layout == DataLayout::UNDEFINED) ? x_meta.layout : layout;
  return {out_type, x_meta.dims, out_layout};
}
} // namespace pten
......@@ -41,4 +41,8 @@ DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta,
int start_axis,
int stop_axis);
// Computes the output meta for full_like-style kernels: keeps x_meta.dims,
// substituting x_meta's type/layout when dtype/layout are UNDEFINED.
DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
DataType dtype,
DataLayout layout);
} // namespace pten
......@@ -21,7 +21,6 @@ namespace pten {
template <typename T>
void FillAnyLike(const CPUContext& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DenseTensor* out) {
auto value = val.to<float>();
......
......@@ -25,7 +25,6 @@ using CPUContext = paddle::platform::CPUDeviceContext;
template <typename T>
void FillAnyLike(const CPUContext& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DenseTensor* out);
......
......@@ -21,7 +21,6 @@ namespace pten {
template <typename T>
void FillAnyLike(const CUDAContext& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DenseTensor* out) {
auto value = val.to<float>();
......
......@@ -28,7 +28,6 @@ using CUDAContext = paddle::platform::CUDADeviceContext;
template <typename T>
void FillAnyLike(const CUDAContext& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DenseTensor* out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册