Unverified commit ad716551, authored by HongyuJia, committed by GitHub

[phi] Transfer fluid fill_any to PHI fill (#44879)

* transfer kernel, make complete

* add fill_sig file

* fix code style

* fix fill_sig, add yaml, modify python API

* fix inplace, add inplace testcase

* deprecated_op_names append fill

* resolve comments, add test_backward
Parent cf5742ac
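
For context, the user-facing entry point this migration keeps working is the in-place paddle.Tensor.fill_ API; in eager (dygraph) mode it now dispatches to the new PHI fill kernel. A minimal Python sketch, assuming a Paddle build that contains this change:

import numpy as np
import paddle

x = paddle.to_tensor(np.ones((4, 2, 3), dtype=np.float32))
x.fill_(0)                            # in-place fill; routes to the PHI fill kernel in eager mode
print(x.inplace_version)              # 1 -- fill_ bumps the tensor's inplace version
print(bool((x.numpy() == 0).all()))   # True
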
@@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
@@ -34,30 +38,11 @@ class FillAnyOpMaker : public framework::OpProtoAndCheckerMaker {
class FillAnyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillAny");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillAny");
auto x_dims = context->GetInputDim("X");
context->SetOutputDim("Out", x_dims);
}
};
class FillAnyGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"mul");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
template <typename T>
@@ -82,31 +67,22 @@ DECLARE_INPLACE_OP_INFERER(FillAnyGradInplaceInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fill_any,
FillInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(fill_any_grad,
FillAnyInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(fill_any,
ops::FillAnyOp,
ops::FillAnyOpMaker,
ops::FillAnyGradOpMaker<paddle::framework::OpDesc>,
ops::FillAnyGradOpMaker<paddle::imperative::OpBase>,
ops::FillAnyOpInplaceInferer);
ops::FillAnyOpInplaceInferer,
FillInferShapeFunctor);
REGISTER_OPERATOR(fill_any_grad,
ops::FillAnyGradOp,
ops::FillAnyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
fill_any,
ops::FillAnyKernel<phi::CPUContext, float>,
ops::FillAnyKernel<phi::CPUContext, double>,
ops::FillAnyKernel<phi::CPUContext, int64_t>,
ops::FillAnyKernel<phi::CPUContext, int>,
ops::FillAnyKernel<phi::CPUContext, paddle::platform::float16>,
ops::FillAnyKernel<phi::CPUContext, bool>);
REGISTER_OP_CPU_KERNEL(
fill_any_grad,
ops::FillAnyGradKernel<phi::CPUContext, float>,
ops::FillAnyGradKernel<phi::CPUContext, double>,
ops::FillAnyGradKernel<phi::CPUContext, int64_t>,
ops::FillAnyGradKernel<phi::CPUContext, int>,
ops::FillAnyGradKernel<phi::CPUContext, paddle::platform::float16>,
ops::FillAnyGradKernel<phi::CPUContext, bool>);
ops::FillAnyGradInplaceInferer,
FillAnyInferShapeFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_any,
ops::FillAnyKernel<phi::GPUContext, float>,
ops::FillAnyKernel<phi::GPUContext, double>,
ops::FillAnyKernel<phi::GPUContext, int64_t>,
ops::FillAnyKernel<phi::GPUContext, int>,
ops::FillAnyKernel<phi::GPUContext, paddle::platform::float16>,
ops::FillAnyKernel<phi::GPUContext, bool>);
REGISTER_OP_CUDA_KERNEL(
fill_any_grad,
ops::FillAnyGradKernel<phi::GPUContext, float>,
ops::FillAnyGradKernel<phi::GPUContext, double>,
ops::FillAnyGradKernel<phi::GPUContext, int64_t>,
ops::FillAnyGradKernel<phi::GPUContext, int>,
ops::FillAnyGradKernel<phi::GPUContext, paddle::platform::float16>,
ops::FillAnyGradKernel<phi::GPUContext, bool>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FillAnyKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
auto floatvar = ctx.template Attr<float>("value_float");
auto intvar = ctx.template Attr<int>("value_int");
auto isfloat = ((typeid(float) == typeid(T)) ||
(typeid(double) == typeid(T) ||
typeid(paddle::platform::float16) == typeid(T)));
T fill_var = static_cast<T>(floatvar);
if (!isfloat) {
fill_var = static_cast<T>(intvar);
}
PADDLE_ENFORCE_EQ(
std::isnan(static_cast<double>(fill_var)),
false,
platform::errors::InvalidArgument("fill value should not be NaN,"
" but received NaN"));
out->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> functor;
functor(reinterpret_cast<const DeviceContext &>(dev_ctx),
out,
static_cast<T>(fill_var));
}
};
template <typename DeviceContext, typename T>
class FillAnyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<DeviceContext, T> functor;
functor(reinterpret_cast<const DeviceContext &>(dev_ctx), dx, T(0));
}
}
};
} // namespace operators
} // namespace paddle
@@ -864,6 +864,17 @@
data_type : dtype
backend : place
- api : fill
args : (Tensor x, Scalar value)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : fill
inplace : (x -> out)
backward: fill_grad
- api : fill_diagonal
args : (Tensor x, float value, int offset, bool wrap)
output : Tensor(out)
......
@@ -811,7 +811,7 @@
infer_meta :
func : UnchangedInferMeta
invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
- backward_api : fill_diagonal_grad
forward : fill_diagonal (Tensor x, float value, int offset, bool wrap) -> Tensor(out)
args : (Tensor out_grad, float value, int offset, bool wrap)
@@ -831,6 +831,17 @@
func : fill_diagonal_tensor_grad
inplace : (out_grad -> x_grad)
- backward_api : fill_grad
forward : fill (Tensor x, Scalar value) -> Tensor(out)
args : (Tensor out_grad, Scalar value)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : fill_grad
inplace : (out_grad -> x_grad)
- backward_api : flatten_grad
forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
......
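
The fill_grad entry above registers the backward rule for fill: the forward output is a constant with respect to its input, so the incoming gradient is simply zeroed (see the FillGradKernel implementation further down). A short sketch mirroring the new test_backward case, assuming a build with this change:

import numpy as np
import paddle

x = paddle.full([10, 10], -1., dtype='float32')
x.stop_gradient = False
y = 2 * x
y.fill_(1)        # y no longer depends on x
y.backward()
print(np.array_equal(x.grad.numpy(), np.zeros([10, 10])))  # True
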
@@ -51,6 +51,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"squeeze_grad",
"isfinite",
"matmul",
"fill",
"matmul_grad",
"matmul_grad_grad",
"max",
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/fill_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(fill_grad,
CPU,
ALL_LAYOUT,
phi::FillGradKernel,
float,
double,
int64_t,
int,
paddle::platform::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/fill_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(fill,
CPU,
ALL_LAYOUT,
phi::FillKernel,
float,
double,
int64_t,
int,
paddle::platform::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FillGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const Scalar& value,
DenseTensor* in_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FillKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& value,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fill_grad_kernel_impl.h"
PD_REGISTER_KERNEL(fill_grad,
GPU,
ALL_LAYOUT,
phi::FillGradKernel,
float,
double,
int64_t,
int,
paddle::platform::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fill_kernel_impl.h"
PD_REGISTER_KERNEL(fill,
GPU,
ALL_LAYOUT,
phi::FillKernel,
float,
double,
int64_t,
int,
paddle::platform::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fill_grad_kernel.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void FillGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const Scalar& value,
DenseTensor* in_grad) {
if (in_grad) {
dev_ctx.template Alloc<T>(in_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, in_grad, T(0));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fill_kernel.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void FillKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& value,
DenseTensor* out) {
T fill_var = value.to<T>();
PADDLE_ENFORCE_EQ(std::isnan(static_cast<double>(fill_var)),
false,
phi::errors::InvalidArgument("fill value should not be NaN,"
" but received NaN"));
dev_ctx.template Alloc<T>(out);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, out, fill_var);
}
} // namespace phi
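
The FillKernel above rejects NaN fill values via PADDLE_ENFORCE_EQ. A quick hedged check from Python, assuming the enforce surfaces as a Python exception in this build:

import paddle

x = paddle.ones([2, 2], dtype='float32')
try:
    x.fill_(float('nan'))
except Exception as e:   # the NaN guard in FillKernel is expected to fire
    print('rejected NaN fill:', type(e).__name__)
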
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FillOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("fill", {"X"}, {"value_float"}, {"Out"});
}
KernelSignature FillGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"fill_grad", {"Out@GRAD"}, {"value_float"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(fill_any, fill);
PD_REGISTER_BASE_KERNEL_NAME(fill_any_grad, fill_grad);
PD_REGISTER_ARG_MAPPING_FN(fill_any, phi::FillOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(fill_any_grad, phi::FillGradOpArgumentMapping);
@@ -16,9 +16,11 @@ from __future__ import print_function
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
import unittest
import numpy as np
from op_test import OpTest
from paddle.tensor.manipulation import fill_
class TestFillAnyOp(OpTest):
@@ -75,5 +77,41 @@ class TestFillAnyOpvalue2(TestFillAnyOp):
self.value = 11111.1111
class TestFillAnyInplace(unittest.TestCase):
def test_fill_any_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
self.assertEqual(var.inplace_version, 0)
var.fill_(0)
self.assertEqual(var.inplace_version, 1)
var.fill_(0)
self.assertEqual(var.inplace_version, 2)
var.fill_(0)
self.assertEqual(var.inplace_version, 3)
def test_fill_any_equal(self):
with paddle.fluid.dygraph.guard():
tensor = paddle.to_tensor(
np.random.random((20, 30)).astype(np.float32))
target = tensor.numpy()
target[...] = 1
tensor.fill_(1)
self.assertEqual((tensor.numpy() == target).all().item(), True)
def test_backward(self):
with paddle.fluid.dygraph.guard():
x = paddle.full([10, 10], -1., dtype='float32')
x.stop_gradient = False
y = 2 * x
y.fill_(1)
y.backward()
self.assertTrue(np.array_equal(x.grad.numpy(), np.zeros([10, 10])))
if __name__ == "__main__":
unittest.main()
@@ -777,8 +777,11 @@ def fill_(x, value):
raise TypeError(
"The type of 'value' must be int or float, but received %s." %
(type(value)))
return _C_ops.fill_any_(x, "value_float", float(value), "value_int",
int(value))
if in_dygraph_mode():
return _C_ops.final_state_fill_(x, value)
else:
return _C_ops.fill_any_(x, "value_float", float(value), "value_int",
int(value))
@dygraph_only
@@ -806,7 +809,10 @@ def zero_(x):
print(tensor.tolist()) #[0, 0, 0, 0, 0]
"""
return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0))
if in_dygraph_mode():
return _C_ops.final_state_fill_(x, 0.)
else:
return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0))
@dygraph_only
......
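
For completeness, a usage sketch of the two in-place helpers touched above; whether they take the new final-state path or the legacy fill_any_ path is decided internally by in_dygraph_mode() and is transparent to the caller:

import paddle

t = paddle.to_tensor([1., 2., 3., 4., 5.])
t.fill_(7)
print(t.tolist())   # [7.0, 7.0, 7.0, 7.0, 7.0]
t.zero_()
print(t.tolist())   # [0.0, 0.0, 0.0, 0.0, 0.0]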