提交 d3b3443d 编写于 作者: Z zhoukunsheng 提交者: Tao Luo

add ones_like op (#17388)

上级 67b48d7f
......@@ -300,6 +300,7 @@ paddle.fluid.layers.isfinite (ArgSpec(args=['x'], varargs=None, keywords=None, d
paddle.fluid.layers.range (ArgSpec(args=['start', 'end', 'step', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', 'a45b42f21bc5a4e84b60981a3d629ab3'))
paddle.fluid.layers.linspace (ArgSpec(args=['start', 'stop', 'num', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '3663d1148946eed4c1c34c81be586b9e'))
paddle.fluid.layers.zeros_like (ArgSpec(args=['x', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd88a23bcdc443719b3953593f7cef14a'))
paddle.fluid.layers.ones_like (ArgSpec(args=['x', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', '642afd126553337d6796600e886a6525'))
paddle.fluid.layers.diag (ArgSpec(args=['diagonal'], varargs=None, keywords=None, defaults=None), ('document', '88a15e15f0098d549f07a01eaebf9ce3'))
paddle.fluid.layers.While ('paddle.fluid.layers.control_flow.While', ('document', '50110155608a00f43d3d3fd1be41dcb4'))
paddle.fluid.layers.While.__init__ (ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_any_like_op.h"
namespace paddle {
namespace operators {
class FillAnyLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FillAnyLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillAnyLikeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with specified value.");
AddAttr<float>("value", "The filled value").SetDefault(0.0);
AddComment(R"DOC(
FillAnyLike Operator.
Fill up a variable with Attr(value).
The output will have the same shape and dtype as the input.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_any_like, ops::FillAnyLikeOp,
ops::FillAnyLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, bool>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_any_like_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, bool>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <limits>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FillAnyLikeKernel : public framework::OpKernel<T> {
public:
using CommonType = typename std::common_type<
float,
typename std::conditional<std::is_same<T, platform::float16>::value,
float, T>::type>::type;
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
// TODO(fangzeyang): Once context.Attribute supports double dtype, this
// kernel should be updated to support double dtype, too.
float value = context.Attr<float>("value");
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE(
(common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
"filled value is out of range for targeted type in fill_any_like "
"kernel");
PADDLE_ENFORCE(!std::isnan(value), "filled value is NaN");
math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), out,
static_cast<T>(value));
}
};
} // namespace operators
} // namespace paddle
......@@ -28,7 +28,7 @@ __all__ = [
'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite',
'range', 'linspace', 'zeros_like', 'diag'
'range', 'linspace', 'zeros_like', 'ones_like', 'diag'
]
......@@ -989,3 +989,38 @@ def diag(diagonal):
out.stop_gradient = True
return out
def ones_like(x, out=None):
"""
**ones_like**
This function creates a ones tensor which has identical shape and dtype
with `x`.
Args:
x(Variable): The input tensor which specifies shape and dtype.
out(Variable): The output tensor.
Returns:
x(Variable): The tensor variable storing the output.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', dtype='float32', shape=[3], append_batch_size=False)
data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0]
"""
helper = LayerHelper("ones_like", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 1.0},
outputs={'Out': [out]})
return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid.core as core
import paddle.compat as cpt
import unittest
import numpy as np
from op_test import OpTest
class TestFillAnyLikeOp(OpTest):
def setUp(self):
self.op_type = "fill_any_like"
self.dtype = np.int32
self.value = 0.0
self.init()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.attrs = {'value': self.value}
self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
def init(self):
pass
def test_check_output(self):
self.check_output()
class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
def init(self):
self.dtype = np.float32
self.value = 0.0
class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
def init(self):
self.value = 1.0
class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
def init(self):
self.value = 1e-10
class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
def init(self):
self.value = 1e-100
class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp):
def init(self):
self.value = 1e100
def test_check_output(self):
exception = None
try:
self.check_output()
except core.EnforceNotMet as ex:
exception = ex
self.assertIsNotNone(exception)
class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
def init(self):
self.dtype = np.float16
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册