diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index a5bd3cd922070deb2c91e1485874c9b3d4e8d278..88a8ed06207d9969b042bf9d8cd4d2b3b969042d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -309,6 +309,7 @@ paddle.fluid.layers.linspace (ArgSpec(args=['start', 'stop', 'num', 'dtype'], va paddle.fluid.layers.zeros_like (ArgSpec(args=['x', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd88a23bcdc443719b3953593f7cef14a')) paddle.fluid.layers.ones_like (ArgSpec(args=['x', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', '642afd126553337d6796600e886a6525')) paddle.fluid.layers.diag (ArgSpec(args=['diagonal'], varargs=None, keywords=None, defaults=None), ('document', '88a15e15f0098d549f07a01eaebf9ce3')) +paddle.fluid.layers.eye (ArgSpec(args=['num_rows', 'num_columns', 'batch_shape', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 'float32')), ('document', '25389d1e239a5d1cda66298f908ec549')) paddle.fluid.layers.While ('paddle.fluid.layers.control_flow.While', ('document', '50110155608a00f43d3d3fd1be41dcb4')) paddle.fluid.layers.While.__init__ (ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.While.block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/operators/eye_op.cc b/paddle/fluid/operators/eye_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..40848b963350202b684dbfb7625eb8d4427bdb4a --- /dev/null +++ b/paddle/fluid/operators/eye_op.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eye_op.h" + +namespace paddle { +namespace operators { + +class EyeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of EyeOP should not be null."); + auto num_rows = ctx->Attrs().Get("num_rows"); + PADDLE_ENFORCE(num_rows >= 0, + "The value of Input(num_rows) should be non-negative int."); + auto num_columns = ctx->Attrs().Get("num_columns"); + if (num_columns == -1) num_columns = num_rows; + PADDLE_ENFORCE( + num_columns >= 0, + "The value of Input(num_columns) should be non-negative int."); + ctx->SetOutputDim("Out", {num_rows, num_columns}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } +}; + +class EyeOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const override { + auto data_type = static_cast( + boost::get(ctx->GetAttr("dtype"))); + auto& out_var_name = ctx->Output("Out").front(); + ctx->SetDataType(out_var_name, data_type); + } +}; + +class EyeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddAttr("dtype", + "(int, default 5 (FP32)) " + "Output data type") + .SetDefault(framework::proto::VarType::FP32); + AddAttr("num_rows", + "(int64_t) the number of rows in output tensor"); + AddAttr("num_columns", + "(int64_t) the number of columns in output tensor." + "Default -1 means that num_columns=num_rows") + .SetDefault(-1); + AddOutput("Out", + "(Tensor) Construct an identity tensor with " + "specified shape [num_rows, num_columns]"); + AddComment(R"DOC( +Return an identity tensor whose shape is [num_rows, num_columns]. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; +using float16 = paddle::platform::float16; + +REGISTER_OPERATOR(eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(eye, ops::EyeKernel, + ops::EyeKernel, + ops::EyeKernel, ops::EyeKernel, + ops::EyeKernel); diff --git a/paddle/fluid/operators/eye_op.cu b/paddle/fluid/operators/eye_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d55235a54c70b1a4db4bd7f355332c923207591 --- /dev/null +++ b/paddle/fluid/operators/eye_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eye_op.h" + +namespace ops = paddle::operators; +namespace plf = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + eye, ops::EyeKernel, + ops::EyeKernel, + ops::EyeKernel, + ops::EyeKernel, + ops::EyeKernel); diff --git a/paddle/fluid/operators/eye_op.h b/paddle/fluid/operators/eye_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0eefe7d2163bb967596480f2427b995a6a87ff6e --- /dev/null +++ b/paddle/fluid/operators/eye_op.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +struct EyeFunctor { + EyeFunctor(int64_t num_columns, T* output) + : num_columns_(num_columns), output_(output) {} + + HOSTDEVICE void operator()(size_t idx) const { + output_[idx * num_columns_ + idx] = static_cast(1); + } + + int64_t num_columns_; + T* output_; +}; + +template +class EyeKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto num_rows = ctx.Attr("num_rows"); + auto num_columns = ctx.Attr("num_columns"); + if (num_columns == -1) num_columns = num_rows; + + auto* out_tensor = ctx.Output("Out"); + T* out_data = out_tensor->mutable_data(ctx.GetPlace()); + + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, out_tensor, static_cast(0)); + + int64_t num_eyes = std::min(num_rows, num_columns); + platform::ForRange for_range(dev_ctx, num_eyes); + EyeFunctor functor(num_columns, out_data); + for_range(functor); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 81972ae798067d5ab6f34b677839e97ddf099121..24ad4a2b3f3e6217bfb5fdcad69a4ffea84b6e4e 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -28,7 +28,7 @@ __all__ = [ 'tensor_array_to_tensor', 'concat', 'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite', - 'range', 'linspace', 'zeros_like', 'ones_like', 'diag' + 'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye' ] @@ -991,6 +991,77 @@ def diag(diagonal): return out +def eye(num_rows, num_columns=None, batch_shape=None, dtype='float32'): + """ + **eye** + + This function constructs an identity tensor, or a batch of tensor. + + Args: + num_rows(int): the number of rows in each batch tensor. + num_columns(int): the number of columns in each batch tensor. + If None, default: num_rows. + batch_shape(list(int)): If provided, the returned tensor will have a leading + batch size of this shape. + dtype(string): 'float32'|'int32'|..., the data type of the returned tensor. + + Returns: + Variable: An identity tensor of shape batch_shape + [num_rows, num_columns]. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + data = fluid.layers.eye(3, dtype='int32') + # [[1, 0, 0] + # [0, 1, 0] + # [0, 0, 1]] + + data = fluid.layers.eye(2, 3, dtype='int32') + # [[1, 0, 0] + # [0, 1, 0]] + + data = fluid.layers.eye(2, batch_shape=[3]) + # Construct a batch of 3 identity tensors, each 2 x 2. + # data[i, :, :] is a 2 x 2 identity tensor, i = 0, 1, 2. + + """ + + helper = LayerHelper("eye", **locals()) + if not isinstance(num_rows, int) or num_rows < 0: + raise TypeError("num_rows should be a non-negative int") + if num_columns is not None: + if not isinstance(num_columns, int) or num_columns < 0: + raise TypeError("num_columns should be a non-negative int") + else: + num_columns = num_rows + out = helper.create_variable_for_type_inference(dtype=dtype) + c_dtype = convert_np_dtype_to_dtype_(dtype) + helper.append_op( + type='eye', + inputs={}, + outputs={'Out': [out]}, + attrs={ + 'num_rows': num_rows, + 'num_columns': num_columns, + 'dtype': c_dtype + }, + stop_gradient=True) + out.stop_gradient = True + + if batch_shape is not None: + if not isinstance(batch_shape, list): + raise TypeError("batch_shape should be a list") + from .nn import stack + for batch_val in reversed(batch_shape): + if batch_val <= 0: + raise TypeError("batch_shape should be a positive int list") + else: + stack_vars = [out for _ in numpy.arange(batch_val)] + out = stack(stack_vars, axis=0) + return out + + def ones_like(x, out=None): """ **ones_like** diff --git a/python/paddle/fluid/tests/unittests/test_eye_op.py b/python/paddle/fluid/tests/unittests/test_eye_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ea37584b6a5e1d72badc65c294898bdf08f32a2a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eye_op.py @@ -0,0 +1,74 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + +import paddle.fluid.framework as framework + + +class TestEyeOp(OpTest): + def setUp(self): + ''' + Test eye op with specified shape + ''' + self.op_type = "eye" + + self.inputs = {} + self.attrs = { + 'num_rows': 219, + 'num_columns': 319, + 'dtype': framework.convert_np_dtype_to_dtype_(np.int32) + } + self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)} + + def test_check_output(self): + self.check_output() + + +class TestEyeOp1(OpTest): + def setUp(self): + ''' + Test eye op with default parameters + ''' + self.op_type = "eye" + + self.inputs = {} + self.attrs = {'num_rows': 50} + self.outputs = {'Out': np.eye(50, dtype=float)} + + def test_check_output(self): + self.check_output() + + +class TestEyeOp2(OpTest): + def setUp(self): + ''' + Test eye op with specified shape + ''' + self.op_type = "eye" + + self.inputs = {} + self.attrs = {'num_rows': 99, 'num_columns': 1} + self.outputs = {'Out': np.eye(99, 1, dtype=float)} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index a4e51d6cfea1c3dd20516f4f9f1d76ff6492f91c..ce1305bfc2910e05a4d45e0644cd7fd892cf2b01 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -872,6 +872,37 @@ class TestLayer(LayerTest): self.assertTrue(np.allclose(static_rlt2, static_rlt)) self.assertTrue(np.allclose(dy_rlt.numpy(), static_rlt)) + def test_eye_op(self): + np_eye = np.eye(3, 2) + array_rlt1 = [np_eye for _ in range(3)] + stack_rlt1 = np.stack(array_rlt1, axis=0) + array_rlt2 = [stack_rlt1 for _ in range(4)] + stack_rlt2 = np.stack(array_rlt2, axis=0) + + with self.dynamic_graph(): + eye_tensor = layers.eye(num_rows=3, num_columns=2) + eye_tensor_rlt1 = layers.eye(num_rows=3, + num_columns=2, + batch_shape=[3]) + eye_tensor_rlt2 = layers.eye(num_rows=3, + num_columns=2, + batch_shape=[4, 3]) + diag_tensor = layers.eye(20) + + self.assertTrue(np.allclose(eye_tensor.numpy(), np_eye)) + self.assertTrue(np.allclose(eye_tensor_rlt1.numpy(), stack_rlt1)) + self.assertTrue(np.allclose(eye_tensor_rlt2.numpy(), stack_rlt2)) + self.assertTrue(np.allclose(diag_tensor.numpy(), np.eye(20))) + + with self.assertRaises(TypeError): + layers.eye(num_rows=3.1) + with self.assertRaises(TypeError): + layers.eye(num_rows=3, num_columns=2.2) + with self.assertRaises(TypeError): + layers.eye(num_rows=3, batch_shape=2) + with self.assertRaises(TypeError): + layers.eye(num_rows=3, batch_shape=[-1]) + class TestBook(LayerTest): def test_all_layers(self):