From c1b4d1c15d228a5fbfe40374d65734d9d4a8991f Mon Sep 17 00:00:00 2001 From: Li Fuchen Date: Thu, 23 Apr 2020 10:36:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90cherry-pick=E3=80=91add=20diag=5Fembed?= =?UTF-8?q?=20op=20(#23385)=20(#24001)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add diag_embed op (#23385) * add diag_embed op, test=release/2.0-beta * solved a conflict, test=release/2.0-beta --- paddle/fluid/operators/diag_embed_op.cc | 113 ++++++++++++++++ paddle/fluid/operators/diag_embed_op.cu | 26 ++++ paddle/fluid/operators/diag_embed_op.h | 121 ++++++++++++++++++ .../fluid/tests/unittests/test_diag_embed.py | 73 +++++++++++ python/paddle/nn/__init__.py | 5 +- python/paddle/nn/functional/__init__.py | 10 +- python/paddle/nn/functional/extension.py | 86 ++++++++++++- 7 files changed, 427 insertions(+), 7 deletions(-) create mode 100644 paddle/fluid/operators/diag_embed_op.cc create mode 100644 paddle/fluid/operators/diag_embed_op.cu create mode 100644 paddle/fluid/operators/diag_embed_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_diag_embed.py diff --git a/paddle/fluid/operators/diag_embed_op.cc b/paddle/fluid/operators/diag_embed_op.cc new file mode 100644 index 00000000000..6d8bc4d219e --- /dev/null +++ b/paddle/fluid/operators/diag_embed_op.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/diag_embed_op.h" + +namespace paddle { +namespace operators { + +class DiagEmbedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("Input"), true, + platform::errors::NotFound("Input of DiagEmbedOp is not found.")); + + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output of DiagEmbedOp is not found.")); + + int offset = ctx->Attrs().Get("offset"); + int dim1 = ctx->Attrs().Get("dim1"); + int dim2 = ctx->Attrs().Get("dim2"); + + auto x_dims = ctx->GetInputDim("Input"); + + int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; + int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; + int offset_ = std::abs(offset); + + PADDLE_ENFORCE_LE( + dim1_, x_dims.size(), + platform::errors::OutOfRange( + "Dim1 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), x_dims.size(), dim1)); + PADDLE_ENFORCE_LE( + dim2_, x_dims.size(), + platform::errors::OutOfRange( + "Dim2 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), x_dims.size(), dim2)); + PADDLE_ENFORCE_NE(dim1_, dim2_, + platform::errors::InvalidArgument( + "diagonal dimensions should not be identical " + "%ld vs %ld.", + dim1, dim2)); + + int new_dim_len = offset_ + x_dims[x_dims.size() - 1]; + auto sizes = vectorize(x_dims); + sizes.pop_back(); + sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len); + sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len); + ctx->SetOutputDim("Out", framework::make_ddim(sizes)); + } +}; + +class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "The input tensor. Must be at least 1-dimensional."); + AddOutput("Out", "A matrix whose certain 2D planes is diagonal matrix."); + + AddAttr( + "offset", + R"DOC((int, default 0), which diagonal to consider. Default: 0 (main diagonal). + )DOC") + .SetDefault(0); + AddAttr( + "dim1", + R"DOC((int, default -2), first dimension with respect to which to take diagonal. Default: -2. + )DOC") + .SetDefault(-2); + AddAttr( + "dim2", + R"DOC((int, default -1), second dimension with respect to which to take diagonal. Default: -1. + )DOC") + .SetDefault(-1); + + AddComment(R"DOC(Creates a tensor whose diagonals of certain 2D planes + (specified by dim1 and dim2) are filled by input. + To facilitate creating batched diagonal matrices, + the 2D planes formed by the last two dimensions of the returned tensor + are chosen by default. + )DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace platform = paddle::platform; +REGISTER_OPERATOR( + diag_embed, ops::DiagEmbedOp, ops::DiagEmbedOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + diag_embed, ops::DiagEmbedKernel, + ops::DiagEmbedKernel, + ops::DiagEmbedKernel, + ops::DiagEmbedKernel); diff --git a/paddle/fluid/operators/diag_embed_op.cu b/paddle/fluid/operators/diag_embed_op.cu new file mode 100644 index 00000000000..2e03622e10f --- /dev/null +++ b/paddle/fluid/operators/diag_embed_op.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/diag_embed_op.h" + +namespace ops = paddle::operators; +namespace platform = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + diag_embed, ops::DiagEmbedKernel, + ops::DiagEmbedKernel, + ops::DiagEmbedKernel, + ops::DiagEmbedKernel, + ops::DiagEmbedKernel); diff --git a/paddle/fluid/operators/diag_embed_op.h b/paddle/fluid/operators/diag_embed_op.h new file mode 100644 index 00000000000..8c4c68fb1ff --- /dev/null +++ b/paddle/fluid/operators/diag_embed_op.h @@ -0,0 +1,121 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +struct DiagEmbedFunctor { + DiagEmbedFunctor(const T* input, int64_t numel, const int64_t* dim, + int64_t offset, int64_t dims_size, T* output, + const int64_t* strides) + : input_(input), + numel_(numel), + dim_(dim), + offset_(offset), + dims_size_(dims_size), + output_(output), + strides_(strides) {} + + HOSTDEVICE void operator()(size_t idx) const { + int64_t position = 0; + auto numel = numel_; + int64_t num = idx; + for (int64_t i = 0; i < dims_size_; i++) { + numel = numel / dim_[i]; + position += num / numel * strides_[i]; + num = num % numel; + } + output_[position + offset_] = input_[idx]; + } + + const T* input_; + int64_t numel_; + const int64_t* dim_; + int64_t offset_; + int64_t dims_size_; + T* output_; + const int64_t* strides_; +}; + +template +class DiagEmbedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("Input"); + auto* out = context.Output("Out"); + + const int64_t offset = context.Attr("offset"); + const int64_t dim1 = context.Attr("dim1"); + const int64_t dim2 = context.Attr("dim2"); + auto* input_data = input->data(); + + T* out_data = out->mutable_data(context.GetPlace()); + math::SetConstant set_zero; + auto& dev_ctx = context.template device_context(); + set_zero(dev_ctx, out, static_cast(0.0)); + + auto out_dims = out->dims(); + int dim1_ = dim1 < 0 ? out_dims.size() + dim1 : dim1; + int dim2_ = dim2 < 0 ? out_dims.size() + dim2 : dim2; + auto stride = framework::stride(out_dims); + int64_t diag_size; + int64_t storage_offset = 0; + if (offset >= 0) { + int64_t dim = out_dims[dim2_] - offset; + diag_size = std::max(std::min(out_dims[dim1_], dim), 0); + } else { + int64_t dim = out_dims[dim1_] + offset; + diag_size = std::max(std::min(dim, out_dims[dim2_]), 0); + } + if (diag_size == 0) { + // skip + } else if (offset >= 0) { + storage_offset += offset * stride[dim2_]; + } else { + storage_offset -= offset * stride[dim1_]; + } + auto strides = vectorize(stride); + strides.erase(strides.begin() + std::max(dim1_, dim2_)); + strides.erase(strides.begin() + std::min(dim1_, dim2_)); + strides.push_back(stride[dim1_] + stride[dim2_]); + const auto dims = vectorize(input->dims()); + +#ifdef __NVCC__ + thrust::device_vector dims_vec(dims); + const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data()); + thrust::device_vector strides_vec(strides); + const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data()); +#else + const int64_t* dims_arr = dims.data(); + const int64_t* strides_arr = strides.data(); +#endif + + platform::ForRange for_range(dev_ctx, input->numel()); + DiagEmbedFunctor functor(input_data, input->numel(), dims_arr, + storage_offset, dims.size(), out_data, + strides_arr); + for_range(functor); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_diag_embed.py b/python/paddle/fluid/tests/unittests/test_diag_embed.py new file mode 100644 index 00000000000..9df8fc7d575 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_diag_embed.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.dygraph as dg +import paddle.fluid.core as core + + +class TestDiagEmbedOp(OpTest): + def setUp(self): + self.op_type = "diag_embed" + self.init_config() + self.outputs = {'Out': self.target} + + def test_check_output(self): + self.check_output() + + def init_config(self): + self.case = np.random.randn(2, 3).astype('float32') + self.inputs = {'Input': self.case} + self.attrs = {'offset': 0, 'dim1': -2, 'dim2': -1} + self.target = np.stack([np.diag(r, 0) for r in self.inputs['Input']], 0) + + +class TestDiagEmbedOpCase1(TestDiagEmbedOp): + def init_config(self): + self.case = np.random.randn(2, 3).astype('float32') + self.inputs = {'Input': self.case} + self.attrs = {'offset': -1, 'dim1': 0, 'dim2': 2} + self.target = np.stack([np.diag(r, -1) for r in self.inputs['Input']], + 1) + + +class TestDiagEmbedAPICase(unittest.TestCase): + def test_case1(self): + diag_embed = np.random.randn(2, 3, 4).astype('float32') + data1 = fluid.data(name='data1', shape=[2, 3, 4], dtype='float32') + out1 = F.diag_embed(data1) + out2 = F.diag_embed(data1, offset=1, dim1=-2, dim2=3) + + place = core.CPUPlace() + exe = fluid.Executor(place) + results = exe.run(fluid.default_main_program(), + feed={"data1": diag_embed}, + fetch_list=[out1, out2], + return_numpy=True) + target1 = np.stack( + [np.stack([np.diag(s, 0) for s in r], 0) for r in diag_embed], 0) + target2 = np.stack( + [np.stack([np.diag(s, 1) for s in r], 0) for r in diag_embed], 0) + self.assertTrue(np.allclose(results[0], target1)) + self.assertTrue(np.allclose(results[1], target2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index ca755fdb725..91a107b3a19 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -14,12 +14,13 @@ # TODO: import all neural network related api under this directory, # including layers, linear, conv, rnn etc. -# __all__ = [] from .layer import norm +from .functional import extension __all__ = [] __all__ += norm.__all__ +__all__ += extension.__all__ # TODO: define alias in nn directory # from .clip import ErrorClipByValue #DEFINE_ALIAS @@ -220,7 +221,7 @@ from .functional.extension import row_conv #DEFINE_ALIAS # from .functional.extension import target_assign #DEFINE_ALIAS # from .functional.extension import temporal_shift #DEFINE_ALIAS # from .functional.extension import warpctc #DEFINE_ALIAS -# from .functional.extension import diag_embed #DEFINE_ALIAS +from .functional.extension import diag_embed #DEFINE_ALIAS # from .functional.rnn import gru_unit #DEFINE_ALIAS # from .functional.rnn import lstm #DEFINE_ALIAS # from .functional.rnn import lstm_unit #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index b516946aa74..58393fdcfd5 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -14,10 +14,11 @@ # TODO: import all neural network related api under this directory, # including layers, linear, conv, rnn etc. -# __all__ = [ ] +__all__ = [] # TODO: define alias in functional directory from . import conv +__all__ += conv.__all__ from .conv import conv2d #DEFINE_ALIAS from .conv import conv2d_transpose #DEFINE_ALIAS from .conv import conv3d #DEFINE_ALIAS @@ -103,6 +104,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS # from .vision import yolo_box #DEFINE_ALIAS # from .vision import yolov3_loss #DEFINE_ALIAS from . import activation +__all__ += activation.__all__ # from .activation import brelu #DEFINE_ALIAS # from .activation import elu #DEFINE_ALIAS # from .activation import erf #DEFINE_ALIAS @@ -128,6 +130,8 @@ from .activation import sigmoid #DEFINE_ALIAS # from .activation import tanh_shrink #DEFINE_ALIAS # from .activation import thresholded_relu #DEFINE_ALIAS from .activation import log_softmax #DEFINE_ALIAS +from . import extension +__all__ += extension.__all__ # from .extension import add_position_encoding #DEFINE_ALIAS # from .extension import autoincreased_step_counter #DEFINE_ALIAS # from .extension import continuous_value_model #DEFINE_ALIAS @@ -143,7 +147,7 @@ from .extension import row_conv #DEFINE_ALIAS # from .extension import target_assign #DEFINE_ALIAS # from .extension import temporal_shift #DEFINE_ALIAS # from .extension import warpctc #DEFINE_ALIAS -# from .extension import diag_embed #DEFINE_ALIAS +from .extension import diag_embed #DEFINE_ALIAS # from .rnn import gru_unit #DEFINE_ALIAS # from .rnn import lstm #DEFINE_ALIAS # from .rnn import lstm_unit #DEFINE_ALIAS @@ -176,6 +180,8 @@ from .extension import row_conv #DEFINE_ALIAS # from .lod import dynamic_gru #DEFINE_ALIAS # from .lod import dynamic_lstm #DEFINE_ALIAS # from .lod import dynamic_lstmp #DEFINE_ALIAS +from . import common +#__all__ += common.__all__ # from .common import dropout #DEFINE_ALIAS # from .common import embedding #DEFINE_ALIAS # from .common import fc #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index e00349c8cda..77687041083 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -29,15 +29,95 @@ __all__ = [ # 'target_assign', # 'temporal_shift', # 'warpctc', - # 'diag_embed' + 'diag_embed' ] -from ...fluid import core, dygraph_utils -from ...fluid.framework import in_dygraph_mode +import numpy as np +from ...fluid.data_feeder import check_dtype from ...fluid.layer_helper import LayerHelper +from ...fluid.framework import Variable, in_dygraph_mode +from ...fluid.layers.tensor import assign +from ...fluid import core, dygraph_utils from ...fluid.layers.layer_function_generator import templatedoc +def diag_embed(input, offset=0, dim1=-2, dim2=-1): + """ + This OP creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) + are filled by ``input``. By default, a 2D plane formed by the last two dimensions + of the returned tensor will be selected. + The argument ``offset`` determines which diagonal is generated: + - If offset = 0, it is the main diagonal. + - If offset > 0, it is above the main diagonal. + - If offset < 0, it is below the main diagonal. + Args: + input(Variable|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64. + offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal). + dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2. + dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1. + Returns: + Variable, the output data type is the same as input data type. + Examples: + .. code-block:: python + import paddle.nn.functional as F + import paddle.fluid.dygraph as dg + import numpy as np + + diag_embed = np.random.randn(2, 3).astype('float32') + with dg.guard(): + data1 = F.diag_embed(diag_embed) + data2 = F.diag_embed(diag_embed, offset=1, dim1=0, dim2=2) + """ + inputs = {'Input': [input]} + attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} + + if not isinstance(input, Variable): + input = assign(input) + + def __check_input(input, offset, dim1, dim2): + check_dtype(input.dtype, 'Input', + ['int32', 'int64', 'float16', 'float32', 'float64'], + 'diag_embed') + + input_shape = list(input.shape) + assert (len(input_shape) >= 1, \ + "Input must be at least 1-dimensional, " \ + "But received Input's dimensional: %s.\n" % \ + len(input_shape)) + + assert ( + np.abs(dim1) <= len(input_shape), + "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" + % (-(len(input_shape) + 1), len(input_shape), dim1)) + + assert ( + np.abs(dim2) <= len(input_shape), + "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" + % (-(len(input_shape) + 1), len(input_shape), dim2)) + + dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 + 1 + dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 + 1 + assert ( dim1_ != dim2_, + "dim1 and dim2 cannot be the same dimension." \ + "But received dim1 = %d, dim2 = %d\n"%(dim1, dim2)) + + if not in_dygraph_mode(): + __check_input(input, offset, dim1, dim2) + helper = LayerHelper("diag_embed", **locals()) + + out = helper.create_variable_for_type_inference(dtype=input.dtype) + + helper.append_op( + type='diag_embed', + inputs={'Input': [input]}, + attrs={'offset': offset, + 'dim1': dim1, + 'dim2': dim2}, + outputs={'Out': [out]}) + out.stop_gradient = True + return out + + @templatedoc() def row_conv(input, weight, act=None): """ -- GitLab