Unverified commit 69e9e9d5, authored by zyfncg, committed by GitHub

[PHI] Remove fill_any_like kernel register in fluid (#39807)

* remove fill_any_like kernel in fluid and fix data transform bug

* support scalar in infershape

* recover infershape in fill_any_like
Parent: edc3ba13
@@ -1972,6 +1972,9 @@ Scope* OperatorWithKernel::PreparePtenData(
continue;
}
if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue;
}
auto expected_place = phi::TransToPtenPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue;
......
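The three-line check added in the hunk above (and repeated in the following hunk) is the data-transform fix from the commit message: when a kernel declares an input's backend as phi::Backend::ALL_BACKEND, the place-transform step is skipped entirely for that input. A minimal, self-contained sketch of that decision logic; the type and function names below are illustrative, not the actual Paddle declarations:

#include <iostream>

enum class Backend { CPU, GPU, ALL_BACKEND };

struct TensorArgDef {
  Backend backend;
};

// Transform is needed only when the kernel pins the input to a concrete
// backend that differs from where the tensor currently lives.
bool NeedBackendTransform(const TensorArgDef& in_def, Backend tensor_place) {
  if (in_def.backend == Backend::ALL_BACKEND) {
    return false;  // kernel accepts any backend, so never copy across devices
  }
  return in_def.backend != tensor_place;
}

int main() {
  TensorArgDef any_backend{Backend::ALL_BACKEND};
  TensorArgDef cpu_only{Backend::CPU};
  std::cout << NeedBackendTransform(any_backend, Backend::GPU) << "\n";  // 0: skipped
  std::cout << NeedBackendTransform(cpu_only, Backend::GPU) << "\n";     // 1: transform
  return 0;
}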
@@ -479,6 +479,9 @@ void PreparePtenData(const phi::Kernel& pt_kernel,
auto var = ins_vector[offset];
const auto* tensor_in = GetTensorFromVar(var->Var());
if (tensor_in && tensor_in->IsInitialized()) {
if (in_def.backend == phi::Backend::ALL_BACKEND) {
continue;
}
auto expected_place = phi::TransToPtenPlace(in_def.backend);
if (platform::is_same_place(tensor_in->place(), expected_place)) {
continue;
......
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_like_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -91,14 +92,3 @@ REGISTER_OPERATOR(
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::FillAnyLikeVarTypeInference)
REGISTER_OP_CPU_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, bool>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_any_like_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, bool>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <limits>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FillAnyLikeKernel : public framework::OpKernel<T> {
public:
using CommonType = typename std::common_type<
float,
typename std::conditional<std::is_same<T, platform::float16>::value,
float, T>::type>::type;
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
// TODO(fangzeyang): Once context.Attribute supports double dtype, this
// kernel should be updated to support double dtype, too.
float value = context.Attr<float>("value");
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE_EQ(
(common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
platform::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f.",
typeid(T).name(),
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()), value));
PADDLE_ENFORCE_EQ(
std::isnan(value), false,
platform::errors::InvalidArgument("The filled value is NaN."));
const auto& dev_ctx = context.template device_context<DeviceContext>();
// call new kernel
phi::FullLikeKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, value, phi::DataType::UNDEFINED, out);
}
};
} // namespace operators
} // namespace paddle
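For reference, a standalone sketch of the bounds check the header above performs before forwarding to phi::FullLikeKernel: the float attribute "value" must fit into the target dtype T. The real kernel additionally maps platform::float16 to float for the comparison; that detail is dropped in this simplified version, and ValueFitsDtype is a made-up name:

#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>

template <typename T>
bool ValueFitsDtype(float value) {
  // Compare in a type wide enough to hold both float and T.
  using CommonType = typename std::common_type<float, T>::type;
  auto v = static_cast<CommonType>(value);
  return v >= static_cast<CommonType>(std::numeric_limits<T>::lowest()) &&
         v <= static_cast<CommonType>(std::numeric_limits<T>::max());
}

int main() {
  std::cout << ValueFitsDtype<int16_t>(1e6f) << "\n";   // 0: 1e6 overflows int16
  std::cout << ValueFitsDtype<int16_t>(42.0f) << "\n";  // 1: fits
  return 0;
}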
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_like_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/fill_any_like_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
......
@@ -91,8 +91,9 @@ cc_library(pten_tensor SRCS tensor_method.cc DEPS pten_tensor_raw pten_function_
cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform wrapped_infermeta)
cc_library(pten_dygraph_api SRCS ${dygraph_api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten)
@@ -25,7 +25,6 @@ namespace experimental {
template <typename T>
class ScalarBase {
public:
bool FromTensor() const { return is_from_tensor_; }
// Constructor support implicit
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val;
@@ -157,6 +156,10 @@ class ScalarBase {
CopyScalar(other, this);
}
bool FromTensor() const { return is_from_tensor_; }
void SetFromTensor(bool from_tensor) { is_from_tensor_ = from_tensor; }
template <typename RT>
inline RT to() const {
switch (dtype_) {
......
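The hunk above places FromTensor() alongside a SetFromTensor() setter on ScalarBase, which appears to be what the "support scalar in infershape" item in the commit message refers to: code that builds a Scalar from a tensor-backed attribute can record that origin and query it later. A toy illustration of the flag; ToyScalar is a made-up class, not paddle::experimental::ScalarBase:

#include <iostream>

class ToyScalar {
 public:
  ToyScalar(double val) : value_(val) {}  // NOLINT: implicit, mirroring ScalarBase

  bool FromTensor() const { return is_from_tensor_; }
  void SetFromTensor(bool from_tensor) { is_from_tensor_ = from_tensor; }
  double to_double() const { return value_; }

 private:
  double value_{0.0};
  bool is_from_tensor_{false};
};

int main() {
  ToyScalar s = 3.14;     // built directly from a constant
  s.SetFromTensor(true);  // mark the value as coming from a tensor input
  std::cout << s.FromTensor() << " " << s.to_double() << "\n";  // prints: 1 3.14
  return 0;
}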
@@ -99,4 +99,6 @@ PD_REGISTER_KERNEL(full_like,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
@@ -123,4 +123,6 @@ PD_REGISTER_KERNEL(full_like,
int,
int64_t,
bool,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
@@ -139,4 +139,6 @@ PD_REGISTER_KERNEL(full_like,
float,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
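The three full_like registration hunks above now pass a customization body that marks input 0 as ALL_BACKEND, so the dispatcher never forces a device copy of the reference tensor, presumably because that input is only consulted for its metadata; this pairs with the PreparePtenData skip at the top of the diff. A toy sketch of that register-then-customize pattern, using made-up types rather than the real phi registry:

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

enum class Backend { CPU, GPU, ALL_BACKEND };

struct ArgDef {
  Backend backend = Backend::CPU;
  ArgDef& SetBackend(Backend b) {
    backend = b;
    return *this;
  }
};

struct Kernel {
  std::vector<ArgDef> inputs{ArgDef{}};  // one input slot, default backend CPU
  ArgDef& InputAt(std::size_t i) { return inputs[i]; }
};

// Stand-in for what a registration macro might expand to: build the kernel
// with default argument defs, then run the user-supplied customization block.
Kernel RegisterKernel(const std::function<void(Kernel*)>& customize) {
  Kernel k;
  customize(&k);
  return k;
}

int main() {
  Kernel full_like = RegisterKernel([](Kernel* kernel) {
    kernel->InputAt(0).SetBackend(Backend::ALL_BACKEND);
  });
  std::cout << (full_like.InputAt(0).backend == Backend::ALL_BACKEND) << "\n";  // 1
  return 0;
}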