From 5de576b0af93519236a2307855b1182c86c5d142 Mon Sep 17 00:00:00 2001
From: zhiboniu <31800336+zhiboniu@users.noreply.github.com>
Date: Tue, 17 Aug 2021 11:22:55 +0800
Subject: [PATCH] add api fill_diagonal_inplace (#34460)

---
 paddle/fluid/operators/fill_diagonal_op.cc    | 217 ++++++++++++++++++
 paddle/fluid/operators/fill_diagonal_op.cu    | 122 ++++++++++
 paddle/fluid/operators/fill_diagonal_op.h     |  25 ++
 .../unittests/test_tensor_fill_diagonal_.py   | 173 ++++++++++++++
 python/paddle/tensor/manipulation.py          |  49 ++++
 5 files changed, 586 insertions(+)
 create mode 100644 paddle/fluid/operators/fill_diagonal_op.cc
 create mode 100644 paddle/fluid/operators/fill_diagonal_op.cu
 create mode 100644 paddle/fluid/operators/fill_diagonal_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py

diff --git a/paddle/fluid/operators/fill_diagonal_op.cc b/paddle/fluid/operators/fill_diagonal_op.cc
new file mode 100644
index 00000000000..db55c3e9969
--- /dev/null
+++ b/paddle/fluid/operators/fill_diagonal_op.cc
@@ -0,0 +1,217 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fill_diagonal_op.h"
+
+namespace paddle {
+namespace operators {
+
+int64_t CalStride(framework::DDim dim) {
+  int rank = dim.size();
+  int64_t dimsum = 1;
+  int64_t strides = 0;
+  for (int i = rank - 1; i >= 0; i--) {
+    strides += dimsum;
+    dimsum *= dim[i];
+  }
+  return strides;
+}
+
+class FillIDiagonalOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddComment(R"DOC(Fill replace operator
+                Fill the diagonal of a tensor with 'value'.
+                )DOC");
+    AddInput("X", "(Tensor) The input tensor.");
+    AddOutput("Out",
+              "Tensor, the output tensor, with the same shape and data type "
+              "as input(x)");
+    AddAttr<float>(
+        "value",
+        "The float value to fill the diagonal with; a scalar, no grad needed")
+        .SetDefault(0);
+    AddAttr<bool>("wrap",
+                  "whether the diagonal is 'wrapped' after N columns for tall "
+                  "matrices")
+        .SetDefault(false);
+    AddAttr<int>("offset",
+                 "offset of diagonal, zero means no offset, positive means "
+                 "offset to up-right corner; negative means offset to "
+                 "bottom-left corner")
+        .SetDefault(0);
+  }
+};
+
+class FillIDiagonalOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *context) const override {
+    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillIDiagonal");
+    OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillIDiagonal");
+    auto x_dims = context->GetInputDim("X");
+    context->SetOutputDim("Out", x_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class FillIDiagonalOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto var_type = ctx->GetInputType("X", 0);
+    auto data_type = ctx->GetInputDataType("X", 0);
+    ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS);
+    ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS);
+  }
+};
+
+template <typename T>
+class FillIDiagonalKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
+    auto fill_val = ctx.template Attr<float>("value");
+    auto *out = ctx.Output<framework::Tensor>("Out");
+    auto offset = ctx.Attr<int>("offset");
+    auto wrap = ctx.Attr<bool>("wrap");
+
+    auto *xin = ctx.Input<framework::Tensor>("X");
+
+    T temp_var = static_cast<T>(fill_val);
+
+    T *out_data = out->mutable_data<T>(ctx.GetPlace());
+    framework::TensorCopy(*xin, ctx.GetPlace(), out);
+
+    auto out_dims = out->dims();
+    auto strides = CalStride(out_dims);
+    auto size = out->numel();
+
+    // The wrap mode is supported only when the input has exactly 2 dims; in
+    // wrap mode, the value is filled in cycles
+    if (!wrap) {
+      size = std::min(size, out_dims[1] * out_dims[1]);
+    }
+
+    for (int64_t i = offset; i < size; i += strides) {
+      out_data[i] = temp_var;
+    }
+  }
+};
+
+class FillIDiagonalGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@GRAD", "mul");
+    auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    // Note: don't get data type from ctx.Input<framework::Tensor>("Input");
+    auto dtype =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type();
+    return framework::OpKernelType(dtype, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class FillIDiagonalGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
retv->SetType("fill_diagonal_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +template +class FillIDiagonalGradKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dout = ctx.Input(framework::GradVarName("Out")); + + auto offset = ctx.Attr("offset"); + auto wrap = ctx.Attr("wrap"); + + if (dx) { + auto *data = dx->mutable_data(ctx.GetPlace()); + framework::TensorCopy(*dout, ctx.GetPlace(), dx); + + auto dx_dims = dx->dims(); + auto strides = CalStride(dx_dims); + auto size = dx->numel(); + auto wrapsize = std::min(size, dx_dims[1] * dx_dims[1]); + + // The wrap mode supported only the dims equels to 2; In wrap mode, the + // value will be filled in cycles + if (wrap) { + wrapsize = size; + } + + for (int64_t i = offset; i < wrapsize; i += strides) { + data[i] = T(0); + } + } + } +}; + +DECLARE_INPLACE_OP_INFERER(FillIDiagonalOpInplaceInferer, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(FillIDiagonalGradOpInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fill_diagonal, ops::FillIDiagonalOp, + ops::FillIDiagonalOpMaker, + ops::FillIDiagonalOpVarTypeInference, + ops::FillIDiagonalGradOpMaker, + ops::FillIDiagonalGradOpMaker, + ops::FillIDiagonalOpInplaceInferer); + +REGISTER_OPERATOR(fill_diagonal_grad, ops::FillIDiagonalGradOp, + ops::FillIDiagonalGradOpInplaceInferer); + +REGISTER_OP_CPU_KERNEL(fill_diagonal, ops::FillIDiagonalKernel, + ops::FillIDiagonalKernel, + ops::FillIDiagonalKernel, + ops::FillIDiagonalKernel, + ops::FillIDiagonalKernel, + ops::FillIDiagonalKernel); + +REGISTER_OP_CPU_KERNEL(fill_diagonal_grad, ops::FillIDiagonalGradKernel, + ops::FillIDiagonalGradKernel, + ops::FillIDiagonalGradKernel, + ops::FillIDiagonalGradKernel, + ops::FillIDiagonalGradKernel, + ops::FillIDiagonalGradKernel); diff --git a/paddle/fluid/operators/fill_diagonal_op.cu b/paddle/fluid/operators/fill_diagonal_op.cu new file mode 100644 index 00000000000..5047059fb36 --- /dev/null +++ b/paddle/fluid/operators/fill_diagonal_op.cu @@ -0,0 +1,122 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#include "paddle/fluid/operators/fill_diagonal_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using CUDADeviceContext = paddle::platform::CUDADeviceContext;
+
+template <typename T>
+__global__ void fill_constant_kernel(const int64_t featuresize, T* in_data,
+                                     int64_t strides, int offset, T fillvar) {
+  for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
+       idx * strides + offset < (blockIdx.x + 1) * featuresize;
+       idx += blockDim.x) {
+    in_data[idx * strides + offset] = fillvar;
+  }
+}
+
+template <typename T>
+class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#ifdef __HIPCC__
+    const int64_t kMaxBlockDim = 256;
+#else
+    const int64_t kMaxBlockDim = 512;
+#endif
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    auto offset = ctx.Attr<int>("offset");
+    auto wrap = ctx.Attr<bool>("wrap");
+
+    auto* xin = ctx.Input<framework::Tensor>("X");
+    framework::TensorCopy(*xin, ctx.GetPlace(), out);
+
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+    auto fill_val = static_cast<T>(ctx.template Attr<float>("value"));
+    T temp_var = static_cast<T>(fill_val);
+
+    auto size = out->numel();
+    auto out_dims = out->dims();
+    auto strides = CalStride(out_dims);
+
+    // The wrap mode is supported only when the input has exactly 2 dims; in
+    // wrap mode, the value is filled in cycles
+    if (!wrap) {
+      size = std::min(size, out_dims[1] * out_dims[1]);
+    }
+
+    int64_t kBlockDim = std::min(int64_t(size / strides), kMaxBlockDim);
+    fill_constant_kernel<<<1, kBlockDim, 0>>>(size, out_data, strides,
+                                              offset, temp_var);
+  }
+};
+
+template <typename T>
+class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#ifdef __HIPCC__
+    const int64_t kMaxBlockDim = 256;
+#else
+    const int64_t kMaxBlockDim = 512;
+#endif
+    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* in_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto offset = ctx.Attr<int>("offset");
+    auto wrap = ctx.Attr<bool>("wrap");
+
+    framework::TensorCopy(*dout, ctx.GetPlace(), dx);
+
+    auto size = dx->numel();
+    auto out_dims = dx->dims();
+    auto strides = CalStride(out_dims);
+
+    auto wrapsize = std::min(size, out_dims[1] * out_dims[1]);
+    // The wrap mode is supported only when the input has exactly 2 dims; in
+    // wrap mode, the value is filled in cycles
+    if (wrap) {
+      wrapsize = size;
+    }
+
+    int64_t kBlockDim = std::min(int64_t(size), kMaxBlockDim);
+    fill_constant_kernel<<<1, kBlockDim, 0>>>(wrapsize, in_data, strides,
+                                              offset, T(0));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(fill_diagonal, ops::FillIDiagonalCUDAKernel<float>,
+                        ops::FillIDiagonalCUDAKernel<double>,
+                        ops::FillIDiagonalCUDAKernel<int64_t>,
+                        ops::FillIDiagonalCUDAKernel<int>,
+                        ops::FillIDiagonalCUDAKernel<plat::float16>,
+                        ops::FillIDiagonalCUDAKernel<bool>);
+
+REGISTER_OP_CUDA_KERNEL(fill_diagonal_grad,
+                        ops::FillIDiagonalGradCUDAKernel<float>,
+                        ops::FillIDiagonalGradCUDAKernel<double>,
+                        ops::FillIDiagonalGradCUDAKernel<int64_t>,
+                        ops::FillIDiagonalGradCUDAKernel<int>,
+                        ops::FillIDiagonalGradCUDAKernel<plat::float16>,
+                        ops::FillIDiagonalGradCUDAKernel<bool>);
diff --git a/paddle/fluid/operators/fill_diagonal_op.h b/paddle/fluid/operators/fill_diagonal_op.h
new file mode 100644
index 00000000000..4531503e30d
--- /dev/null
+++ b/paddle/fluid/operators/fill_diagonal_op.h
@@ -0,0 +1,25 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+int64_t CalStride(framework::DDim dim);
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
new file mode 100644
index 00000000000..41a8a9750cb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import unittest
+import numpy as np
+import six
+import paddle
+
+
+class TensorFillDiagonal_Test(unittest.TestCase):
+    def test_dim2_normal(self):
+        expected_np = np.array(
+            [[1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32')
+        expected_grad = np.array(
+            [[0, 1, 1], [1, 0, 1], [1, 1, 0]]).astype('float32')
+
+        typelist = ['float32', 'float64', 'int32', 'int64']
+        places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for idx, p in enumerate(places):
+            if idx == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('gpu')
+            for dtype in typelist:
+                x = paddle.ones((3, 3), dtype=dtype)
+                x.stop_gradient = False
+                y = x * 2
+                y.fill_diagonal_(1, offset=0, wrap=True)
+                loss = y.sum()
+                loss.backward()
+
+                self.assertEqual(
+                    (y.numpy().astype('float32') == expected_np).all(), True)
+                self.assertEqual(
+                    (y.grad.numpy().astype('float32') == expected_grad).all(),
+                    True)
+
+    def test_bool(self):
+        expected_np = np.array(
+            [[False, True, True], [True, False, True], [True, True, False]])
+
+        typelist = ['bool']
+        places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for idx, p in enumerate(places):
+            if idx == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('gpu')
+            for dtype in typelist:
+                x = paddle.ones((3, 3), dtype=dtype)
+                x.stop_gradient = True
+                x.fill_diagonal_(0, offset=0, wrap=True)
+
+                self.assertEqual((x.numpy() == expected_np).all(), True)
+
+    def test_dim2_unnormal_wrap(self):
+        expected_np = np.array([[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2],
+                                [1, 2, 2], [2, 1, 2],
+                                [2, 2, 1]]).astype('float32')
+        expected_grad = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
+                                  [0, 1, 1], [1, 0, 1],
+                                  [1, 1, 0]]).astype('float32')
+
+        typelist = ['float32', 'float64', 'int32', 'int64']
+        places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for idx, p in enumerate(places):
+            if idx == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('gpu')
+            for dtype in typelist:
+                x = paddle.ones((7, 3), dtype=dtype)
+                x.stop_gradient = False
+                y = x * 2
+                y.fill_diagonal_(1, offset=0, wrap=True)
+                loss = y.sum()
+                loss.backward()
+
+                self.assertEqual(
+                    (y.numpy().astype('float32') == expected_np).all(), True)
+                self.assertEqual(
+                    (y.grad.numpy().astype('float32') == expected_grad).all(),
+                    True)
+
+    def test_dim2_unnormal_unwrap(self):
+        expected_np = np.array([[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2],
+                                [2, 2, 2], [2, 2, 2],
+                                [2, 2, 2]]).astype('float32')
+        expected_grad = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1],
+                                  [1, 1, 1], [1, 1, 1],
+                                  [1, 1, 1]]).astype('float32')
+
+        typelist = ['float32', 'float64', 'int32', 'int64']
+        places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for idx, p in enumerate(places):
+            if idx == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('gpu')
+            for dtype in typelist:
+                x = paddle.ones((7, 3), dtype=dtype)
+                x.stop_gradient = False
+                y = x * 2
+                y.fill_diagonal_(1, offset=0, wrap=False)
+                loss = y.sum()
+                loss.backward()
+
+                self.assertEqual(
+                    (y.numpy().astype('float32') == expected_np).all(), True)
+                self.assertEqual(
+                    (y.grad.numpy().astype('float32') == expected_grad).all(),
+                    True)
+
+    def test_dim_larger2_normal(self):
+        expected_np = np.array(
+            [[[1, 2, 2], [2, 2, 2], [2, 2, 2]],
+             [[2, 2, 2], [2, 1, 2], [2, 2, 2]],
+             [[2, 2, 2], [2, 2, 2], [2, 2, 1]]]).astype('float32')
+        expected_grad = np.array(
+            [[[0, 1, 1], [1, 1, 1], [1, 1, 1]],
+             [[1, 1, 1], [1, 0, 1], [1, 1, 1]],
+             [[1, 1, 1], [1, 1, 1], [1, 1, 0]]]).astype('float32')
+
+        typelist = ['float32', 'float64', 'int32', 'int64']
+        places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for idx, p in enumerate(places):
+            if idx == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('gpu')
+            for dtype in typelist:
+                x = paddle.ones((3, 3, 3), dtype=dtype)
+                x.stop_gradient = False
+                y = x * 2
+                y.fill_diagonal_(1, offset=0, wrap=True)
+                loss = y.sum()
+                loss.backward()
+
+                self.assertEqual(
+                    (y.numpy().astype('float32') == expected_np).all(), True)
+                self.assertEqual(
+                    (y.grad.numpy().astype('float32') == expected_grad).all(),
+                    True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 4b84401aa09..1f0c0ba24d9 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -37,6 +37,55 @@ from paddle import _C_ops
 __all__ = []
 
 
+@dygraph_only
+def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
+    """
+    **Notes**:
+        **This API is ONLY available in Dygraph mode**
+    This function fills the value into the x Tensor's diagonal in place.
+    Args:
+        x(Tensor): ``x`` is the original Tensor
+        value(Scalar): ``value`` is the value to be filled in x
+        offset(int,optional): the offset to the main diagonal. Default: 0 (main diagonal).
+        wrap(bool,optional): the diagonal is 'wrapped' after N columns for tall matrices.
+        name(str,optional): Name for the operation (optional, default is None)
+    Returns:
+        Tensor: Tensor with diagonal filled with value.
+    Return type:
+        dtype is same as x Tensor
+    Examples:
+        .. code-block:: python
+            import paddle
+            x = paddle.ones((4, 3)) * 2
+            x.fill_diagonal_(1.0)
+            print(x.tolist())   #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]]
+    """
+    helper = LayerHelper("fill_diagonal_", **locals())
+    check_type(x, 'X', (Variable), 'fill_diagonal_')
+    dtype = helper.input_dtype('x')
+    check_dtype(dtype, 'X',
+                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'fill_diagonal_')
+    check_type(value, 'value', (bool, int, float), 'fill_diagonal_')
+    check_type(wrap, 'wrap', (bool), 'fill_diagonal_')
+
+    inshape = x.shape
+    inshapeset = set(inshape)
+    assert len(inshape) >= 2, ('Tensor dims should >= 2 in fill_diagonal_ API')
+    if len(inshape) > 2:
+        assert len(inshapeset) == 1, (
+            'Tensor dims should be equal while input dims > 2 in fill_diagonal_ API'
+        )
+    if len(inshape) == 2:
+        return core.ops.fill_diagonal_(x, 'value', value, 'offset', offset,
+                                       'wrap', wrap)
+    return core.ops.fill_diagonal_(x, 'value', value, 'offset', offset, 'wrap',
+                                   True)
+
+
+setattr(core.VarBase, 'fill_diagonal_', fill_diagonal_)
+
+
 @dygraph_only
 def tolist(x):
     """
--
GitLab
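
Reviewer note (not part of the patch): the CPU and CUDA kernels above share one indexing scheme. They walk the flattened tensor starting at `offset` in steps of `CalStride(dims)`, which is the sum of the row-major strides (for a 2-D tensor this is `cols + 1`), and when `wrap` is false on a tall 2-D input they stop after the leading `cols x cols` block. The sketch below mirrors that scheme in NumPy to cross-check the new API; `reference_fill_diagonal` is a hypothetical helper written only for this note, not something the patch adds, and it assumes a contiguous input array.

    import numpy as np
    import paddle

    def reference_fill_diagonal(arr, value, offset=0, wrap=False):
        # Hypothetical NumPy mirror of the kernel: step through the flattened
        # array from `offset` with stride = sum of the row-major strides.
        out = arr.copy()
        step = int(sum(out.strides) // out.itemsize)   # cols + 1 for a 2-D array
        limit = out.size
        if out.ndim == 2 and not wrap:
            # Without wrap, a tall matrix only has its leading cols x cols block filled.
            limit = min(limit, out.shape[1] * out.shape[1])
        out.reshape(-1)[offset:limit:step] = value
        return out

    x = paddle.ones((7, 3), dtype='float32') * 2
    x.fill_diagonal_(1.0, offset=0, wrap=True)     # diagonal restarts every cols + 1 rows

    ref = reference_fill_diagonal(np.full((7, 3), 2.0, dtype='float32'), 1.0, wrap=True)
    assert np.array_equal(x.numpy(), ref)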
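
A second illustrative note on the Python entry point, again a sketch rather than part of the change: the wrapper honors `wrap` only for 2-D inputs and forces `wrap=True` for higher-rank tensors, which must have all dimensions equal, so only the main hyper-diagonal is touched.

    import paddle

    # 2-D tall matrix: wrap decides whether the pattern restarts below the top block.
    a = paddle.ones((5, 3)) * 2
    a.fill_diagonal_(0.0)                 # wrap=False: only the top 3 x 3 block is touched
    b = paddle.ones((5, 3)) * 2
    b.fill_diagonal_(0.0, wrap=True)      # the diagonal restarts at row 4 (flat index 12)

    # N-D input: every dimension must match; the wrapper always wraps internally.
    c = paddle.ones((3, 3, 3))
    c.fill_diagonal_(7.0)                 # sets c[i, i, i] for i in range(3)

    print(a.tolist())
    print(b.tolist())
    print(c.numpy()[2, 2, 2])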