Unverified · Commit abd4ab9c · Authored by: zyfncg · Committed by: GitHub

[PTen] Adjust the param of full_like API in pten (#37088)

* adjust the param of full_like api in pten

* adjust the code format

* adjust the code format

* adjust the code format
Parent: 1fe4513c
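
As context for the diff below, here is a minimal usage sketch of the extended full_like / ones_like / zeros_like signatures. It is not part of the commit: the include path, the namespace shorthand, and the concrete enum values (DataType::FLOAT32, Backend::CPU, DataLayout::NCHW) are assumptions; only the function signatures are taken from the diff.

#include "paddle/pten/api/include/creation.h"  // assumed header path, not from this commit

namespace exp = paddle::experimental;

void full_like_usage_sketch(const exp::Tensor& x) {
  // The three trailing parameters default to UNDEFINED, so existing
  // call sites that only pass (x, value) keep compiling unchanged.
  exp::Tensor a = exp::full_like(x, /*value=*/3.14);

  // dtype / backend / layout can now be overridden individually;
  // UNDEFINED means "inherit from x", resolved from the kernel key
  // and FullLikeInferShape inside full_like.
  exp::Tensor b = exp::full_like(x,
                                 /*value=*/0,
                                 exp::DataType::FLOAT32,  // assumed enum value
                                 exp::Backend::CPU,       // assumed enum value
                                 exp::DataLayout::NCHW);  // assumed enum value

  // ones_like / zeros_like simply forward to full_like with 1 / 0.
  exp::Tensor c = exp::ones_like(x, exp::DataType::FLOAT32);
  exp::Tensor d = exp::zeros_like(x);
}
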
@@ -50,8 +50,7 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("fill_any_like", {"X"}, {"value"},
-                                      {"Out"});
+    return framework::KernelSignature("fill_any_like", {}, {"value"}, {"Out"});
   }
 };
......
@@ -34,7 +34,6 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
                                 float, T>::type>::type;

   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
@@ -62,12 +61,11 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
         std::isnan(value), false,
         platform::errors::InvalidArgument("The filled value is NaN."));

-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
     auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);

     const auto& dev_ctx = context.template device_context<DeviceContext>();
     // call new kernel
-    pten::FillAnyLike<T>(dev_ctx, *pt_x, value, pt_out.get());
+    pten::FillAnyLike<T>(dev_ctx, value, pt_out.get());
   }
 };
......
@@ -29,11 +29,19 @@ Tensor full(const std::vector<int64_t>& shape,
 Tensor full_like(const Tensor& x,
                  const Scalar& value,
-                 DataType dtype = DataType::UNDEFINED);
+                 DataType dtype = DataType::UNDEFINED,
+                 Backend backend = Backend::UNDEFINED,
+                 DataLayout layout = DataLayout::UNDEFINED);

-Tensor ones_like(const Tensor& x, DataType dtype = DataType::UNDEFINED);
+Tensor ones_like(const Tensor& x,
+                 DataType dtype = DataType::UNDEFINED,
+                 Backend backend = Backend::UNDEFINED,
+                 DataLayout layout = DataLayout::UNDEFINED);

-Tensor zeros_like(const Tensor& x, DataType dtype = DataType::UNDEFINED);
+Tensor zeros_like(const Tensor& x,
+                  DataType dtype = DataType::UNDEFINED,
+                  Backend backend = Backend::UNDEFINED,
+                  DataLayout layout = DataLayout::UNDEFINED);

 }  // namespace experimental
 }  // namespace paddle
@@ -63,15 +63,22 @@ Tensor full(const std::vector<int64_t>& shape,
 Tensor full_like(const Tensor& x,
                  const Scalar& value,
-                 paddle::experimental::DataType dtype) {
+                 DataType dtype,
+                 Backend backend,
+                 DataLayout layout) {
   // 1. Get kernel signature and kernel
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
   auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
+
+  DataType kernel_data_type =
+      dtype == DataType::UNDEFINED ? kernel_key.dtype() : dtype;
+  Backend kernel_backend =
+      backend == Backend::UNDEFINED ? kernel_key.backend() : backend;
+  DataLayout kernel_layout =
+      layout == DataLayout::UNDEFINED ? kernel_key.layout() : layout;
+
   auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
-      "fill_any_like",
-      {kernel_key.backend(),
-       kernel_key.layout(),
-       dtype == DataType::UNDEFINED ? kernel_key.dtype() : dtype});
+      "fill_any_like", {kernel_backend, kernel_layout, kernel_data_type});

   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
@@ -79,21 +86,16 @@ Tensor full_like(const Tensor& x,
   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
-  kernel_context.EmplaceBackInput(dense_x);
   kernel_context.EmplaceBackAttr(value);

   // 4. InferShape
-  auto out_meta = UnchangedInferShape(dense_x->meta());
+  auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout);

   // 5. Prepare outputs
   Tensor out;
-  // InferDataType
-  if (dtype != pten::DataType::UNDEFINED) {
-    const_cast<pten::DenseTensorMeta::DataType&>(out_meta.type) = dtype;
-  }
   const auto allocator =
       std::make_shared<paddle::experimental::DefaultAllocator>(
-          pten::TransToFluidPlace(kernel_key.backend()));
+          pten::TransToFluidPlace(kernel_backend));
   auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
   kernel_context.EmplaceBackOutput(dense_out);
   out.set_impl(dense_out);
@@ -104,12 +106,18 @@ Tensor full_like(const Tensor& x,
   return out;
 }

-Tensor ones_like(const Tensor& x, DataType dtype) {
-  return full_like(x, 1, dtype);
+Tensor ones_like(const Tensor& x,
+                 DataType dtype,
+                 Backend backend,
+                 DataLayout layout) {
+  return full_like(x, 1, dtype, backend, layout);
 }

-Tensor zeros_like(const Tensor& x, DataType dtype) {
-  return full_like(x, 0, dtype);
+Tensor zeros_like(const Tensor& x,
+                  DataType dtype,
+                  Backend backend,
+                  DataLayout layout) {
+  return full_like(x, 0, dtype, backend, layout);
 }

 }  // namespace experimental
......
@@ -24,15 +24,19 @@ namespace pten {
 // TODO(YuanRisheng) This function name should be same as User API name.
 // TODO(zyfncg) Automatic code generation
 template <typename T, typename ContextT>
-DenseTensor FillAnyLike(const ContextT& dev_ctx,
-                        const DenseTensor& x,
-                        const Scalar& val) {
-  auto out_meta = UnchangedInferShape(x.meta());
+DenseTensor FillAnyLike(
+    const ContextT& dev_ctx,
+    const DenseTensor& x,
+    const Scalar& val,
+    DataType dtype = DataType::UNDEFINED,
+    Backend backend = Backend::UNDEFINED,  // Is backend needed here?
+    DataLayout layout = DataLayout::UNDEFINED) {
+  auto out_meta = FullLikeInferShape(x.meta(), dtype, layout);
   const auto allocator =
       std::make_shared<paddle::experimental::DefaultAllocator>(
           dev_ctx.GetPlace());
   pten::DenseTensor dense_out(allocator, out_meta);
-  FillAnyLike<T>(dev_ctx, x, val, &dense_out);
+  FillAnyLike<T>(dev_ctx, val, &dense_out);
   return dense_out;
 }
......
@@ -74,4 +74,12 @@ DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta,
   return return_meta;
 }

+DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
+                                   DataType dtype,
+                                   DataLayout layout) {
+  return {dtype == DataType::UNDEFINED ? x_meta.type : dtype,
+          x_meta.dims,
+          layout == DataLayout::UNDEFINED ? x_meta.layout : layout};
+}
+
 }  // namespace pten
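
A short sketch of the fallback behavior encoded by FullLikeInferShape above: an UNDEFINED dtype or layout is inherited from x_meta, while the dims are always carried over. This is not from the commit; the field names (type, dims, layout) come from the diff, while the pten:: qualification and DataType::FLOAT32 are assumptions.

// Sketch only: shows which fields FullLikeInferShape keeps from x_meta
// and which it overrides.
pten::DenseTensorMeta full_like_meta_sketch(const pten::DenseTensorMeta& x_meta) {
  // Nothing specified: dtype and layout are inherited from x_meta.
  auto inherited = pten::FullLikeInferShape(
      x_meta, pten::DataType::UNDEFINED, pten::DataLayout::UNDEFINED);
  (void)inherited;  // inherited.type == x_meta.type, inherited.layout == x_meta.layout

  // dtype overridden, layout still inherited from x_meta; dims unchanged.
  auto as_fp32 = pten::FullLikeInferShape(
      x_meta, pten::DataType::FLOAT32, pten::DataLayout::UNDEFINED);
  return as_fp32;  // as_fp32.type == DataType::FLOAT32, as_fp32.dims == x_meta.dims
}
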
@@ -41,4 +41,8 @@ DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta,
                                   int start_axis,
                                   int stop_axis);

+DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta,
+                                   DataType dtype,
+                                   DataLayout layout);
+
 }  // namespace pten
@@ -21,7 +21,6 @@ namespace pten {
 template <typename T>
 void FillAnyLike(const CPUContext& dev_ctx,
-                 const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out) {
   auto value = val.to<float>();
......
@@ -25,7 +25,6 @@ using CPUContext = paddle::platform::CPUDeviceContext;
 template <typename T>
 void FillAnyLike(const CPUContext& dev_ctx,
-                 const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out);
......
@@ -21,7 +21,6 @@ namespace pten {
 template <typename T>
 void FillAnyLike(const CUDAContext& dev_ctx,
-                 const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out) {
   auto value = val.to<float>();
......
@@ -28,7 +28,6 @@ using CUDAContext = paddle::platform::CUDADeviceContext;
 template <typename T>
 void FillAnyLike(const CUDAContext& dev_ctx,
-                 const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out);
......