未验证 提交 0c02d2ed 编写于 作者: Z zyfncg 提交者: GitHub

【PTen】Adjust the format of full kernel (#38596)

* adjust the full kernel

* remove creation.h

* use Empty to create tensor in full
上级 c1adced7
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/creation.h"
#include "paddle/pten/kernels/full_kernel.h"
namespace paddle {
namespace operators {
......@@ -65,7 +65,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
const auto& dev_ctx = context.template device_context<DeviceContext>();
// call new kernel
pten::FullLike<T>(dev_ctx, value, pt_out.get());
pten::FullLikeKernel<T>(dev_ctx, value, pt_out.get());
}
};
......
......@@ -16,7 +16,6 @@ limitations under the License. */
// developer apis
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/creation.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/include/linalg.h"
#include "paddle/pten/include/manipulation.h"
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/full_kernel.h"
namespace pten {
// TODO(YuanRisheng) This function name should be same as User API name.
// TODO(zyfncg) Automatic code generation
template <typename T, typename ContextT>
DenseTensor Full(const ContextT& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype = DataType::FLOAT32,
Backend backend = Backend::CPU, // Is backend needed here?
DataLayout layout = DataLayout::NCHW) {
auto out_meta = CreateInferMeta(shape, dtype, layout);
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Full<T, ContextT>(dev_ctx, shape, val, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor FullLike(
const ContextT& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED, // Is backend needed here?
DataLayout layout = DataLayout::UNDEFINED) {
auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout);
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
FullLike<T, ContextT>(dev_ctx, val, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -21,7 +21,7 @@ limitations under the License. */
PT_REGISTER_CTX_KERNEL(full,
CPU,
ALL_LAYOUT,
pten::Full,
pten::FullKernel,
float,
double,
uint8_t,
......@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(full,
PT_REGISTER_CTX_KERNEL(full_like,
CPU,
ALL_LAYOUT,
pten::FullLike,
pten::FullLikeKernel,
float,
double,
int,
......
......@@ -18,15 +18,47 @@
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/nullary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void Full(const Context& dev_ctx,
void FullKernel(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
template <typename T, typename Context>
void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out);
void FullLikeKernel(const Context& dev_ctx,
const Scalar& val,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Full(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype = DataType::FLOAT32,
Backend backend = Backend::CPU, // Is backend needed here?
DataLayout layout = DataLayout::NCHW) {
auto out_meta = CreateInferMeta(shape, dtype, layout);
auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
FullKernel<T, Context>(dev_ctx, shape, val, &dense_out);
return dense_out;
}
template <typename T, typename Context>
DenseTensor FullLike(
const Context& dev_ctx,
const DenseTensor& x,
const Scalar& val,
DataType dtype = DataType::UNDEFINED,
Backend backend = Backend::UNDEFINED, // Is backend needed here?
DataLayout layout = DataLayout::UNDEFINED) {
auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout);
auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
FullLikeKernel<T, Context>(dev_ctx, val, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -21,7 +21,7 @@ limitations under the License. */
PT_REGISTER_CTX_KERNEL(full,
GPU,
ALL_LAYOUT,
pten::Full,
pten::FullKernel,
float,
double,
uint8_t,
......@@ -36,7 +36,7 @@ PT_REGISTER_CTX_KERNEL(full,
PT_REGISTER_CTX_KERNEL(full_like,
GPU,
ALL_LAYOUT,
pten::FullLike,
pten::FullLikeKernel,
float,
double,
int,
......
......@@ -24,7 +24,7 @@ limitations under the License. */
namespace pten {
template <typename Context, typename T, typename VType>
template <typename T, typename Context, typename VType>
void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
tensor->mutable_data<T>();
auto t = pten::EigenVector<T>::Flatten(*tensor);
......@@ -32,16 +32,18 @@ void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
}
template <typename T, typename Context>
void Full(const Context& dev_ctx,
void FullKernel(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
FullValue<Context, T>(dev_ctx, out, val.to<T>());
FullValue<T>(dev_ctx, out, val.to<T>());
}
template <typename T, typename Context>
void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) {
void FullLikeKernel(const Context& dev_ctx,
const Scalar& val,
DenseTensor* out) {
auto value = val.to<float>();
using CommonType = typename std::common_type<
float,
......@@ -66,7 +68,7 @@ void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) {
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()),
static_cast<float>(value)));
FullValue<Context, T>(dev_ctx, out, value);
FullValue<T>(dev_ctx, out, value);
}
} // namespace pten
......@@ -15,7 +15,8 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/creation.h"
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/full_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册