Unverified commit 0c02d2ed, authored by zyfncg, committed by GitHub

【PTen】Adjust the format of full kernel (#38596)

* adjust the full kernel

* remove creation.h

* use Empty to create tensor in full
Parent c1adced7
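In short, this commit gives the device-level functions a Kernel suffix (pten::Full becomes pten::FullKernel, pten::FullLike becomes pten::FullLikeKernel), moves the Full/FullLike convenience wrappers from paddle/pten/include/creation.h into paddle/pten/kernels/full_kernel.h, and has those wrappers allocate their output through Empty instead of hand-constructing a SharedStorage-backed tensor. A minimal usage sketch of the resulting API; the pten::CPUContext spelling and the setup of dev_ctx are assumptions for illustration, not part of this diff:

#include <cstdint>
#include <vector>

#include "paddle/pten/kernels/full_kernel.h"

// Sketch only: assumes an initialized CPU device context is available.
void FullUsageSketch(const pten::CPUContext& dev_ctx) {
  const pten::ScalarArray shape(std::vector<int64_t>{2, 3});

  // High-level wrapper: infers the meta, allocates via Empty, then fills.
  pten::DenseTensor t = pten::Full<float>(dev_ctx, shape, 1.0f);

  // Low-level form: the renamed kernel fills a preallocated tensor in place.
  pten::FullKernel<float>(dev_ctx, shape, 0.0f, &t);
}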
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/pten/include/core.h"
-#include "paddle/pten/include/creation.h"
+#include "paddle/pten/kernels/full_kernel.h"

 namespace paddle {
 namespace operators {
@@ -65,7 +65,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
     const auto& dev_ctx = context.template device_context<DeviceContext>();
     // call new kernel
-    pten::FullLike<T>(dev_ctx, value, pt_out.get());
+    pten::FullLikeKernel<T>(dev_ctx, value, pt_out.get());
   }
 };
...
@@ -16,7 +16,6 @@ limitations under the License. */
 // developer apis
 #include "paddle/pten/include/core.h"
-#include "paddle/pten/include/creation.h"
 #include "paddle/pten/include/infermeta.h"
 #include "paddle/pten/include/linalg.h"
 #include "paddle/pten/include/manipulation.h"
...
Deleted file: paddle/pten/include/creation.h
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/pten/api/lib/utils/storage.h"
-#include "paddle/pten/include/infermeta.h"
-#include "paddle/pten/kernels/empty_kernel.h"
-#include "paddle/pten/kernels/full_kernel.h"
-
-namespace pten {
-
-// TODO(YuanRisheng) This function name should be same as User API name.
-// TODO(zyfncg) Automatic code generation
-template <typename T, typename ContextT>
-DenseTensor Full(const ContextT& dev_ctx,
-                 const ScalarArray& shape,
-                 const Scalar& val,
-                 DataType dtype = DataType::FLOAT32,
-                 Backend backend = Backend::CPU,  // Is backend needed here?
-                 DataLayout layout = DataLayout::NCHW) {
-  auto out_meta = CreateInferMeta(shape, dtype, layout);
-  pten::DenseTensor dense_out(
-      pten::make_intrusive<paddle::experimental::SharedStorage>(
-          dev_ctx.GetPlace()),
-      std::move(out_meta));
-  Full<T, ContextT>(dev_ctx, shape, val, &dense_out);
-  return dense_out;
-}
-
-template <typename T, typename ContextT>
-DenseTensor FullLike(
-    const ContextT& dev_ctx,
-    const DenseTensor& x,
-    const Scalar& val,
-    DataType dtype = DataType::UNDEFINED,
-    Backend backend = Backend::UNDEFINED,  // Is backend needed here?
-    DataLayout layout = DataLayout::UNDEFINED) {
-  auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout);
-  pten::DenseTensor dense_out(
-      pten::make_intrusive<paddle::experimental::SharedStorage>(
-          dev_ctx.GetPlace()),
-      std::move(out_meta));
-  FullLike<T, ContextT>(dev_ctx, val, &dense_out);
-  return dense_out;
-}
-
-}  // namespace pten
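For reference, the only substantive difference after the move is the allocation step: the deleted wrapper above builds its DenseTensor from a SharedStorage by hand, while the relocated wrapper (in full_kernel.h below) gets the same tensor from Empty. A side-by-side sketch, wrapped in hypothetical helpers (AllocBefore/AllocAfter are illustrative names, and the DenseTensorMeta spelling is an assumption):

#include <utility>

#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/empty_kernel.h"

namespace pten {

// Old style (creation.h): construct storage and tensor explicitly.
template <typename Context>
DenseTensor AllocBefore(const Context& dev_ctx, DenseTensorMeta out_meta) {
  return DenseTensor(
      make_intrusive<paddle::experimental::SharedStorage>(dev_ctx.GetPlace()),
      std::move(out_meta));
}

// New style (full_kernel.h): the same allocation, hidden behind Empty.
template <typename T, typename Context>
DenseTensor AllocAfter(const Context& dev_ctx, DenseTensorMeta out_meta) {
  return Empty<T, Context>(dev_ctx, std::move(out_meta));
}

}  // namespace pten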
@@ -21,7 +21,7 @@ limitations under the License. */
 PT_REGISTER_CTX_KERNEL(full,
                        CPU,
                        ALL_LAYOUT,
-                       pten::Full,
+                       pten::FullKernel,
                        float,
                        double,
                        uint8_t,
@@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(full,
 PT_REGISTER_CTX_KERNEL(full_like,
                        CPU,
                        ALL_LAYOUT,
-                       pten::FullLike,
+                       pten::FullLikeKernel,
                        float,
                        double,
                        int,
...
@@ -18,15 +18,47 @@
 #include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/infermeta/nullary.h"
+#include "paddle/pten/kernels/empty_kernel.h"

 namespace pten {

 template <typename T, typename Context>
-void Full(const Context& dev_ctx,
-          const ScalarArray& shape,
-          const Scalar& val,
-          DenseTensor* out);
+void FullKernel(const Context& dev_ctx,
+                const ScalarArray& shape,
+                const Scalar& val,
+                DenseTensor* out);
+
+template <typename T, typename Context>
+void FullLikeKernel(const Context& dev_ctx,
+                    const Scalar& val,
+                    DenseTensor* out);
+
+template <typename T, typename Context>
+DenseTensor Full(const Context& dev_ctx,
+                 const ScalarArray& shape,
+                 const Scalar& val,
+                 DataType dtype = DataType::FLOAT32,
+                 Backend backend = Backend::CPU,  // Is backend needed here?
+                 DataLayout layout = DataLayout::NCHW) {
+  auto out_meta = CreateInferMeta(shape, dtype, layout);
+  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  FullKernel<T, Context>(dev_ctx, shape, val, &dense_out);
+  return dense_out;
+}

 template <typename T, typename Context>
-void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out);
+DenseTensor FullLike(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const Scalar& val,
+    DataType dtype = DataType::UNDEFINED,
+    Backend backend = Backend::UNDEFINED,  // Is backend needed here?
+    DataLayout layout = DataLayout::UNDEFINED) {
+  auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout);
+  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  FullLikeKernel<T, Context>(dev_ctx, val, &dense_out);
+  return dense_out;
+}

 }  // namespace pten
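With the wrappers living next to their kernels, a caller gets a filled tensor in one expression, and leaving dtype/backend/layout as UNDEFINED makes FullLike inherit them from x through CreateLikeInferMeta. A small sketch; the OnesLike helper is illustrative, not part of the diff:

#include "paddle/pten/kernels/full_kernel.h"

// Illustrative helper: fill a float tensor shaped like x with ones.
// Assumes x holds float data; the UNDEFINED defaults are resolved from
// x's meta inside CreateLikeInferMeta.
template <typename Context>
pten::DenseTensor OnesLike(const Context& dev_ctx, const pten::DenseTensor& x) {
  return pten::FullLike<float>(dev_ctx, x, 1.0f);
}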
@@ -21,7 +21,7 @@ limitations under the License. */
 PT_REGISTER_CTX_KERNEL(full,
                        GPU,
                        ALL_LAYOUT,
-                       pten::Full,
+                       pten::FullKernel,
                        float,
                        double,
                        uint8_t,
@@ -36,7 +36,7 @@ PT_REGISTER_CTX_KERNEL(full,
 PT_REGISTER_CTX_KERNEL(full_like,
                        GPU,
                        ALL_LAYOUT,
-                       pten::FullLike,
+                       pten::FullLikeKernel,
                        float,
                        double,
                        int,
...
@@ -24,7 +24,7 @@ limitations under the License. */
 namespace pten {

-template <typename Context, typename T, typename VType>
+template <typename T, typename Context, typename VType>
 void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
   tensor->mutable_data<T>();
   auto t = pten::EigenVector<T>::Flatten(*tensor);
@@ -32,16 +32,18 @@ void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
 }

 template <typename T, typename Context>
-void Full(const Context& dev_ctx,
-          const ScalarArray& shape,
-          const Scalar& val,
-          DenseTensor* out) {
+void FullKernel(const Context& dev_ctx,
+                const ScalarArray& shape,
+                const Scalar& val,
+                DenseTensor* out) {
   out->Resize(paddle::framework::make_ddim(shape.GetData()));
-  FullValue<Context, T>(dev_ctx, out, val.to<T>());
+  FullValue<T>(dev_ctx, out, val.to<T>());
 }

 template <typename T, typename Context>
-void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) {
+void FullLikeKernel(const Context& dev_ctx,
+                    const Scalar& val,
+                    DenseTensor* out) {
   auto value = val.to<float>();
   using CommonType = typename std::common_type<
       float,
@@ -66,7 +68,7 @@ void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) {
       static_cast<CommonType>(std::numeric_limits<T>::lowest()),
       static_cast<CommonType>(std::numeric_limits<T>::max()),
       static_cast<float>(value)));
-  FullValue<Context, T>(dev_ctx, out, value);
+  FullValue<T>(dev_ctx, out, value);
 }

 }  // namespace pten
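Note the template-parameter reorder on FullValue: with T first, call sites only have to spell the one type that cannot be deduced, while Context and VType are deduced from the arguments, which is why the calls shrink from FullValue<Context, T>(...) to FullValue<T>(...). A self-contained illustration of the deduction rule, using plain C++ stand-ins rather than Paddle types:

#include <iostream>

// T is not deducible from the arguments, so it is listed first;
// Context and VType are deduced at the call site.
template <typename T, typename Context, typename VType>
void FillValue(const Context& /*dev_ctx*/, VType val) {
  std::cout << static_cast<T>(val) << '\n';
}

int main() {
  int dummy_ctx = 0;               // stand-in for a device context
  FillValue<float>(dummy_ctx, 3);  // Context and VType both deduced as int
  return 0;
}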
@@ -15,7 +15,8 @@ limitations under the License. */
 #include <gtest/gtest.h>
 #include <memory>

-#include "paddle/pten/include/creation.h"
+#include "paddle/pten/kernels/empty_kernel.h"
+#include "paddle/pten/kernels/full_kernel.h"

 #include "paddle/pten/api/lib/utils/allocator.h"
 #include "paddle/pten/core/dense_tensor.h"
...