未验证 提交 efeec79b 编写于 作者: Z zhiboniu 提交者: GitHub

add paddle.Tensor api fill_(inplace), zero_(inplace) (#33829)

add fill_ backward
上级 cb7d8595
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_op.h"
namespace paddle {
namespace operators {
class FillAnyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input tensor.");
AddOutput("Out", "Tensor, the tensor filled with input value ");
AddAttr<float>("value_float", "The float var to fill in Tensor")
.SetDefault(0);
AddAttr<int>("value_int", "The int var to fill in Tensor").SetDefault(0);
AddComment(R"DOC(Fill operator with backward;
Fill an tensor with `value`.
)DOC");
};
};
class FillAnyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillAny");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillAny");
auto x_dims = context->GetInputDim("X");
context->SetOutputDim("Out", x_dims);
}
};
class FillAnyGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "mul");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
template <typename T>
class FillAnyGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType(this->ForwardOpType() + "_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(FillAnyOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FillAnyGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fill_any, ops::FillAnyOp, ops::FillAnyOpMaker,
ops::FillAnyGradOpMaker<paddle::framework::OpDesc>,
ops::FillAnyGradOpMaker<paddle::imperative::OpBase>,
ops::FillAnyOpInplaceInferer);
REGISTER_OPERATOR(fill_any_grad, ops::FillAnyGradOp,
ops::FillAnyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
fill_any, ops::FillAnyKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
fill_any_grad,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyGradKernel<paddle::platform::CPUDeviceContext, bool>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_any, ops::FillAnyKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
fill_any_grad,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyGradKernel<paddle::platform::CUDADeviceContext, bool>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FillAnyKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *out = ctx.Output<framework::Tensor>("Out");
auto floatvar = ctx.template Attr<float>("value_float");
auto intvar = ctx.template Attr<int>("value_int");
auto isfloat = ((typeid(float) == typeid(T)) ||
(typeid(double) == typeid(T) ||
typeid(paddle::platform::float16) == typeid(T)));
T fill_var = static_cast<T>(floatvar);
if (!isfloat) {
fill_var = static_cast<T>(intvar);
}
PADDLE_ENFORCE_EQ(
std::isnan(static_cast<double>(fill_var)), false,
platform::errors::InvalidArgument("fill value should not be NaN,"
" but received NaN"));
out->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> functor;
functor(reinterpret_cast<const DeviceContext &>(dev_ctx), out,
static_cast<T>(fill_var));
}
};
template <typename DeviceContext, typename T>
class FillAnyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> functor;
functor(reinterpret_cast<const DeviceContext &>(dev_ctx), dx, T(0));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -276,6 +276,8 @@ from .device import get_device # noqa: F401
from .fluid.framework import is_compiled_with_cuda # noqa: F401
from .fluid.framework import is_compiled_with_rocm # noqa: F401
from .fluid.framework import disable_signal_handler # noqa: F401
from .fluid.framework import get_flags # noqa: F401
from .fluid.framework import set_flags # noqa: F401
from .device import is_compiled_with_xpu # noqa: F401
from .device import is_compiled_with_npu # noqa: F401
from .device import XPUPlace # noqa: F401
......@@ -521,5 +523,7 @@ __all__ = [ # noqa
'standard_normal',
'diagonal',
'broadcast_tensors',
'einsum'
'einsum',
'set_flags',
'get_flags'
]
......@@ -6273,6 +6273,7 @@ def device_guard(device=None):
def set_flags(flags):
"""
This function sets the GFlags value in Paddle.
For FLAGS please refer to :ref:`en_guides_flags_flags`
Args:
flags (dict): A dict contains flags and its value.
......@@ -6280,8 +6281,8 @@ def set_flags(flags):
Examples:
.. code-block:: python
import paddle.fluid as fluid
fluid.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})
import paddle
paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})
"""
if not isinstance(flags, dict):
raise TypeError('flags in set_flags should be a dict')
......@@ -6296,6 +6297,7 @@ def set_flags(flags):
def get_flags(flags):
"""
This function gets the GFlags value in Paddle.
For FLAGS please refer to :ref:`en_guides_flags_flags`
Args:
flags(list|tuple|str): A list/tuple of string or a string which is the flag's name.
......@@ -6306,10 +6308,10 @@ def get_flags(flags):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
flags = ['FLAGS_eager_delete_tensor_gb', 'FLAGS_check_nan_inf']
res = fluid.get_flags(flags)
res = paddle.get_flags(flags)
print(res)
# {'FLAGS_eager_delete_tensor_gb': 0.0, 'FLAGS_check_nan_inf': False}
"""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid.core as core
import unittest
import numpy as np
from op_test import OpTest
class TestFillAnyOp(OpTest):
def setUp(self):
self.op_type = "fill_any"
self.dtype = 'float64'
self.value = 0.0
self.init()
self.inputs = {'X': np.random.random((20, 30)).astype(self.dtype)}
self.attrs = {
'value_float': float(self.value),
'value_int': int(self.value)
}
self.outputs = {
'Out':
self.value * np.ones_like(self.inputs["X"]).astype(self.dtype)
}
def init(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestFillAnyOpFloat32(TestFillAnyOp):
def init(self):
self.dtype = np.float32
self.value = 0.0
class TestFillAnyOpFloat16(TestFillAnyOp):
def init(self):
self.dtype = np.float16
class TestFillAnyOpvalue1(TestFillAnyOp):
def init(self):
self.dtype = np.float32
self.value = 111111555
class TestFillAnyOpvalue2(TestFillAnyOp):
def init(self):
self.dtype = np.float32
self.value = 11111.1111
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import six
import paddle
class TensorFill_Test(unittest.TestCase):
def setUp(self):
self.shape = [32, 32]
def test_tensor_fill_true(self):
typelist = ['float32', 'float64', 'int32', 'int64', 'float16']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
places.append(fluid.CUDAPinnedPlace())
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
np_arr = np.reshape(
np.array(six.moves.range(np.prod(self.shape))), self.shape)
for dtype in typelist:
var = 1.
tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
newtensor = tensor.clone()
newtensor[...] = var
tensor.fill_(var) #var type is basic type in typelist
self.assertEqual((tensor.numpy() == newtensor.numpy()).all(),
True)
def test_tensor_fill_backward(self):
typelist = ['float32']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
places.append(fluid.CUDAPinnedPlace())
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
np_arr = np.reshape(
np.array(six.moves.range(np.prod(self.shape))), self.shape)
for dtype in typelist:
var = int(1)
tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
tensor.stop_gradient = False
y = tensor * 2
y.fill_(var)
loss = y.sum()
loss.backward()
self.assertEqual((y.grad.numpy() == 0).all().item(), True)
def test_errors(self):
def test_list():
x = paddle.to_tensor([2, 3, 4])
x.fill_([1])
self.assertRaises(TypeError, test_list)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import six
import paddle
class TensorFill_Test(unittest.TestCase):
def setUp(self):
self.shape = [32, 32]
def test_tensor_fill_true(self):
typelist = ['float32', 'float64', 'int32', 'int64', 'float16']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
places.append(fluid.CUDAPinnedPlace())
for p in places:
np_arr = np.reshape(
np.array(six.moves.range(np.prod(self.shape))), self.shape)
for dtype in typelist:
tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
newtensor = tensor.clone()
newtensor[...] = 0
tensor.zero_()
self.assertEqual(
(tensor.numpy() == newtensor.numpy()).all().item(), True)
if __name__ == '__main__':
unittest.main()
......@@ -29,6 +29,7 @@ from ..fluid.layers import unstack # noqa: F401
from ..fluid.layers import scatter_nd # noqa: F401
from ..fluid.layers import shard_index # noqa: F401
from ..fluid.layers.nn import _elementwise_op_in_dygraph
from ..fluid import layers
from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only
import paddle
......@@ -37,6 +38,74 @@ from paddle import _C_ops
__all__ = []
@dygraph_only
def fill_(x, value):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**
This function fill the Tensor with value inplace.
Args:
x(Tensor): ``x`` is the Tensor we want to filled data inplace
value(Scale): ``value`` is the value to be filled in x
Returns:
x(Tensor): Tensor x filled with value inplace
Examples:
.. code-block:: python
import paddle
tensor = paddle.to_tensor([0, 1, 2, 3, 4])
tensor.fill_(0)
print(tensor.tolist()) #[0, 0, 0, 0, 0]
"""
if not isinstance(value, (float, int)):
raise TypeError(
"The type of 'value' must be int or float, but received %s." %
(type(value)))
return core.ops.fill_any_(x, "value_float",
float(value), "value_int", int(value))
setattr(core.VarBase, 'fill_', fill_)
@dygraph_only
def zero_(x):
"""
**Notes**:
**This API is ONLY available in Dygraph mode**
This function fill the Tensor with zero inplace.
Args:
x(Tensor): ``x`` is the Tensor we want to filled with zero inplace
Returns:
x(Tensor): Tensor x filled with zero inplace
Examples:
.. code-block:: python
import paddle
tensor = paddle.to_tensor([0, 1, 2, 3, 4])
tensor.zero_()
print(tensor.tolist()) #[0, 0, 0, 0, 0]
"""
return core.ops.fill_any_(x, "value_float", 0., "value_int", int(0))
setattr(core.VarBase, 'zero_', zero_)
@dygraph_only
def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
"""
......
......@@ -722,5 +722,6 @@ STATIC_MODE_TESTING_LIST = [
'test_c_embedding_op',
'test_class_center_sample_op',
'test_fill_diagonal_tensor_op',
'test_fill_any_op',
'test_margin_cross_entropy_op',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册