diff --git a/paddle/fluid/operators/fill_diagonal_tensor_op.cc b/paddle/fluid/operators/fill_diagonal_tensor_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0c348693202e1c06310eeedefa3a820cd466ef95
--- /dev/null
+++ b/paddle/fluid/operators/fill_diagonal_tensor_op.cc
@@ -0,0 +1,292 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fill_diagonal_tensor_op.h"
+
+namespace paddle {
+namespace operators {
+
+// Calculate the offset, new_dims, the strides of dim1/dim2, and matoffset.
+void CalMatDims(framework::DDim out_dims, int dim1, int dim2, int64_t *offset,
+                int64_t *new_dims, int64_t *strides, int64_t *matoffset) {
+  int64_t dimprod = 1, batchdim = 1;
+  int rank = out_dims.size();
+  int matoffidx = 0;
+  for (int i = rank - 1; i >= 0; i--) {
+    if (i == dim2) {
+      strides[0] = dimprod;
+    } else if (i == dim1) {
+      strides[1] = dimprod;
+    } else {
+      batchdim *= out_dims[i];
+      // matoffset holds the starting offset of each diagonal defined by dim1
+      // and dim2:
+      // the first pass fills in the offsets of the last free dimension,
+      // and the remaining free dims are then accumulated one by one.
+      if (matoffidx == 0) {
+        for (int64_t j = 0; j < out_dims[i]; j++) {
+          matoffset[matoffidx] = dimprod * j;
+          matoffidx++;
+        }
+      } else {
+        auto size = matoffidx;
+        for (int64_t j = 1; j < out_dims[i]; j++) {
+          for (int64_t k = 0; k < size; k++) {
+            matoffset[matoffidx] = matoffset[k] + dimprod * j;
+            matoffidx++;
+          }
+        }
+      }
+    }
+    dimprod *= out_dims[i];
+  }
+
+  auto diagdim = dim1;
+  if (*offset >= 0) {
+    diagdim = std::min(out_dims[dim1], out_dims[dim2] - *offset);
+    *offset *= strides[0];
+  } else {
+    diagdim = std::min(out_dims[dim1] + *offset, out_dims[dim2]);
+    *offset *= -strides[1];
+  }
+  new_dims[0] = batchdim;
+  new_dims[1] = diagdim;
+  return;
+}
+
+class FillDiagonalTensorOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddComment(R"DOC(Fill replace operator
+                Fill the diagonal of a tensor with the `Y` Tensor.
+ )DOC"); + AddInput("X", "(Tensor) The input tensor."); + AddInput("Y", "(Tensor) The input tensor to fill in."); + AddOutput("Out", + "Tensor, the output tensor, with the same shape and data type " + "as input(x)"); + AddAttr("dim1", "the first dim to figure out the diagonal") + .SetDefault(0); + AddAttr("dim2", "the second dim to figure out the diagonal") + .SetDefault(1); + AddAttr("offset", + "offset of diagonal, zero means no offset, positive means " + "offset to up-right corner; negtive means offset to " + "bottom-left corner") + .SetDefault(0); + } +}; + +class FillDiagonalTensorOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillDiagonalTensor"); + OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", + "FillDiagonalTensor"); + auto x_dims = context->GetInputDim("X"); + context->SetOutputDim("Out", x_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class FillDiagonalTensorOpVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto var_type = ctx->GetInputType("X", 0); + auto data_type = ctx->GetInputDataType("X", 0); + ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS); + ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS); + } +}; + +template +class FillDiagonalTensorKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto *out = ctx.Output("Out"); + auto *srctensor = ctx.Input("Y"); + auto dim1 = ctx.Attr("dim1"); + auto dim2 = ctx.Attr("dim2"); + auto offset = ctx.Attr("offset"); + auto *xin = ctx.Input("X"); + + T *out_data = out->mutable_data(ctx.GetPlace()); + const T *fill_data = srctensor->data(); + + framework::TensorCopy(*xin, ctx.GetPlace(), out); + auto out_dims = out->dims(); + auto matdims = srctensor->dims(); + auto fill_dims = flatten_to_2d(matdims, matdims.size() - 1); + + int64_t new_dims[2], strides[2]; + std::vector matdim; + matdim.resize(fill_dims[0]); + CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim.data()); + PADDLE_ENFORCE_EQ( + new_dims[0], fill_dims[0], + platform::errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], new_dims[1], + fill_dims[0], fill_dims[1])); + PADDLE_ENFORCE_EQ( + new_dims[1], fill_dims[1], + platform::errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], new_dims[1], + fill_dims[0], fill_dims[1])); + + auto size = out->numel(); + for (int64_t i = 0; i < fill_dims[0]; i += 1) { + auto sumoff = matdim[i] + offset; + for (int64_t j = 0; j < fill_dims[1]; j += 1) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + out_data[fill_index] = fill_data[i * fill_dims[1] + j]; + } + } + } + } +}; + +class FillDiagonalTensorGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", 
"mul"); + auto x_dims = ctx->GetInputDim(framework::GradVarName("Out")); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + // Note: don't get data type from ctx.Input("Input"); + auto dtype = + ctx.Input(framework::GradVarName("Out"))->type(); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class FillDiagonalTensorGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("fill_diagonal_tensor_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +template +class FillDiagonalTensorGradKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dout = ctx.Input(framework::GradVarName("Out")); + + auto dim1 = ctx.Attr("dim1"); + auto dim2 = ctx.Attr("dim2"); + auto offset = ctx.Attr("offset"); + auto matrows = 1; + + if (dx) { + auto *data = dx->mutable_data(ctx.GetPlace()); + + auto dx_dims = dx->dims(); + for (int i = 0; i < dx_dims.size(); i++) { + if (i != dim1 && i != dim2) { + matrows *= dx_dims[i]; + } + } + + int64_t new_dims[2], strides[2]; + std::vector matdim; + matdim.resize(matrows); + CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, + matdim.data()); + + auto size = dx->numel(); + framework::TensorCopy(*dout, ctx.GetPlace(), dx); + + for (int64_t i = 0; i < new_dims[0]; i += 1) { + auto sumoff = matdim[i] + offset; + for (int64_t j = 0; j < new_dims[1]; j += 1) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + data[fill_index] = 0; + } + } + } + } + } +}; + +DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorOpInplaceInferer, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorGradOpInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OPERATOR( + fill_diagonal_tensor, ops::FillDiagonalTensorOp, + ops::FillDiagonalTensorOpMaker, ops::FillDiagonalTensorOpVarTypeInference, + ops::FillDiagonalTensorGradOpMaker, + ops::FillDiagonalTensorGradOpMaker, + ops::FillDiagonalTensorOpInplaceInferer); + +REGISTER_OPERATOR(fill_diagonal_tensor_grad, ops::FillDiagonalTensorGradOp, + ops::FillDiagonalTensorGradOpInplaceInferer); + +REGISTER_OP_CPU_KERNEL( + fill_diagonal_tensor, ops::FillDiagonalTensorKernel, + ops::FillDiagonalTensorKernel, + ops::FillDiagonalTensorKernel, ops::FillDiagonalTensorKernel, + ops::FillDiagonalTensorKernel, + ops::FillDiagonalTensorKernel, + ops::FillDiagonalTensorKernel, + ops::FillDiagonalTensorKernel>, + ops::FillDiagonalTensorKernel>, + ops::FillDiagonalTensorKernel); + +REGISTER_OP_CPU_KERNEL( + fill_diagonal_tensor_grad, ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel, + ops::FillDiagonalTensorGradKernel>, + ops::FillDiagonalTensorGradKernel>, + 
ops::FillDiagonalTensorGradKernel); diff --git a/paddle/fluid/operators/fill_diagonal_tensor_op.cu b/paddle/fluid/operators/fill_diagonal_tensor_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..834964079fd397a1f056cee6b947c557c7541179 --- /dev/null +++ b/paddle/fluid/operators/fill_diagonal_tensor_op.cu @@ -0,0 +1,207 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fill_diagonal_tensor_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using CUDADeviceContext = paddle::platform::CUDADeviceContext; + +template +__global__ void fill_diagonal_tensor_kernel(int64_t size, T *out_data, + const T *fill_data, + int64_t *strides, int64_t *matdim, + int64_t offset, int64_t fill_dims0, + int64_t fill_dims1) { + int64_t i = blockIdx.x; + auto sumoff = matdim[i] + offset; + for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + out_data[fill_index] = fill_data[i * fill_dims1 + j]; + } + } +} + +template +__global__ void fill_grad_kernel(int64_t size, T *out_data, int64_t *strides, + int64_t *matdim, int64_t offset, + int64_t fill_dims0, int64_t fill_dims1) { + int64_t i = blockIdx.x; + auto sumoff = matdim[i] + offset; + for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + out_data[fill_index] = T(0); + } + } +} + +template +class FillDiagonalTensorCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { +#ifdef __HIPCC__ + const int64_t kMaxBlockDim = 256; +#else + const int64_t kMaxBlockDim = 512; +#endif + auto *out = ctx.Output("Out"); + auto *srctensor = ctx.Input("Y"); + auto dim1 = ctx.Attr("dim1"); + auto dim2 = ctx.Attr("dim2"); + auto offset = ctx.Attr("offset"); + + auto *xin = ctx.Input("X"); + framework::TensorCopy(*xin, ctx.GetPlace(), out); + + T *out_data = out->mutable_data(ctx.GetPlace()); + const T *fill_data = srctensor->data(); + + auto out_dims = out->dims(); + auto matdims = srctensor->dims(); + auto fill_dims = flatten_to_2d(matdims, matdims.size() - 1); + + int64_t new_dims[2]; + std::vector memory_block; + memory_block.resize(2 + fill_dims[0]); + int64_t *strides = &(memory_block[0]); + int64_t *matdim = &(memory_block[2]); + CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim); + PADDLE_ENFORCE_EQ( + new_dims[0], fill_dims[0], + platform::errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], new_dims[1], + fill_dims[0], fill_dims[1])); + PADDLE_ENFORCE_EQ( + new_dims[1], fill_dims[1], + platform::errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], new_dims[1], + fill_dims[0], fill_dims[1])); + + auto size = out->numel(); + + auto &dev_ctx = 
ctx.template device_context(); + auto stream = dev_ctx.stream(); + Tensor tensor_tmp; + int64_t *memory_block_cu = + tensor_tmp.mutable_data({2 + fill_dims[0]}, ctx.GetPlace()); + const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + memory::Copy(gpu_place, memory_block_cu, platform::CPUPlace(), + memory_block.data(), sizeof(int64_t) * (2 + fill_dims[0]), + stream); + + int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2]; + + auto kGridDim = new_dims[0]; + auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim); + fill_diagonal_tensor_kernel<<>>( + size, out_data, fill_data, strides_cu, matdim_cu, offset, fill_dims[0], + fill_dims[1]); + } +}; + +template +class FillDiagonalTensorGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { +#ifdef __HIPCC__ + const int64_t kMaxBlockDim = 256; +#else + const int64_t kMaxBlockDim = 512; +#endif + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dout = ctx.Input(framework::GradVarName("Out")); + + auto dim1 = ctx.Attr("dim1"); + auto dim2 = ctx.Attr("dim2"); + auto offset = ctx.Attr("offset"); + auto matrows = 1; + + if (dx) { + auto *data = dx->mutable_data(ctx.GetPlace()); + auto dx_dims = dx->dims(); + framework::TensorCopy(*dout, ctx.GetPlace(), dx); + + for (int i = 0; i < dx_dims.size(); i++) { + if (i != dim1 && i != dim2) { + matrows *= dx_dims[i]; + } + } + + int64_t new_dims[2]; + std::vector memory_block; + memory_block.resize(2 + matrows); + int64_t *strides = &memory_block[0]; + int64_t *matdim = &memory_block[2]; + CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim); + + auto size = dx->numel(); + + auto &dev_ctx = + ctx.template device_context(); + auto stream = dev_ctx.stream(); + Tensor tensor_tmp; + int64_t *memory_block_cu = + tensor_tmp.mutable_data({2 + matrows}, ctx.GetPlace()); + const auto gpu_place = + BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + memory::Copy(gpu_place, memory_block_cu, platform::CPUPlace(), + memory_block.data(), sizeof(int64_t) * (2 + matrows), + stream); + + int64_t *strides_cu = &memory_block_cu[0], + *matdim_cu = &memory_block_cu[2]; + + auto kGridDim = new_dims[0]; + auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim); + fill_grad_kernel<<>>( + size, data, strides_cu, matdim_cu, offset, new_dims[0], new_dims[1]); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + fill_diagonal_tensor, ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel, + ops::FillDiagonalTensorCUDAKernel>, + ops::FillDiagonalTensorCUDAKernel>, + ops::FillDiagonalTensorCUDAKernel); + +REGISTER_OP_CUDA_KERNEL( + fill_diagonal_tensor_grad, ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel, + ops::FillDiagonalTensorGradCUDAKernel>, + ops::FillDiagonalTensorGradCUDAKernel>, + ops::FillDiagonalTensorGradCUDAKernel); diff --git a/paddle/fluid/operators/fill_diagonal_tensor_op.h 
b/paddle/fluid/operators/fill_diagonal_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ebb980b66af85d3a3508f233b749bf5188560e3b --- /dev/null +++ b/paddle/fluid/operators/fill_diagonal_tensor_op.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +void CalMatDims(framework::DDim out_dims, int dim1, int dim2, int64_t *offset, + int64_t *new_dims, int64_t *strides, int64_t *matoffset); + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py b/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac7a9586cb425dfac16c36b926f1c0db17759a6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py @@ -0,0 +1,149 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.nn.functional as F +import unittest +import numpy as np +import six +import paddle +from op_test import OpTest +from paddle.fluid.layers import core + + +def fill_diagonal_ndarray(x, value, offset=0, dim1=0, dim2=1): + """Fill value into the diagonal of x that offset is ${offset} and the coordinate system is (dim1, dim2).""" + strides = x.strides + shape = x.shape + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + assert 0 <= dim1 < dim2 <= 2 + assert len(x.shape) == 3 + + dim_sum = dim1 + dim2 + dim3 = len(x.shape) - dim_sum + if offset >= 0: + diagdim = min(shape[dim1], shape[dim2] - offset) + diagonal = np.lib.stride_tricks.as_strided( + x[:, offset:] if dim_sum == 1 else x[:, :, offset:], + shape=(shape[dim3], diagdim), + strides=(strides[dim3], strides[dim1] + strides[dim2])) + else: + diagdim = min(shape[dim2], shape[dim1] + offset) + diagonal = np.lib.stride_tricks.as_strided( + x[-offset:, :] if dim_sum in [1, 2] else x[:, -offset:], + shape=(shape[dim3], diagdim), + strides=(strides[dim3], strides[dim1] + strides[dim2])) + + diagonal[...] 
= value + return x + + +def fill_gt(x, y, offset, dim1, dim2): + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + xshape = x.shape + yshape = y.shape + if len(xshape) != 3: + perm_list = [] + unperm_list = [0] * len(xshape) + idx = 0 + + for i in range(len(xshape)): + if i != dim1 and i != dim2: + perm_list.append(i) + unperm_list[i] = idx + idx += 1 + perm_list += [dim1, dim2] + unperm_list[dim1] = idx + unperm_list[dim2] = idx + 1 + + x = np.transpose(x, perm_list) + y = y.reshape(-1, yshape[-1]) + nxshape = x.shape + x = x.reshape((-1, xshape[dim1], xshape[dim2])) + out = fill_diagonal_ndarray(x, y, offset, 1, 2) + + if len(xshape) != 3: + out = out.reshape(nxshape) + out = np.transpose(out, unperm_list) + return out + + +class TensorFillDiagTensor_Test(OpTest): + def setUp(self): + self.op_type = "fill_diagonal_tensor" + self.init_kernel_type() + x = np.random.random((10, 10)).astype(self.dtype) + y = np.random.random((10, )).astype(self.dtype) + dim1 = 0 + dim2 = 1 + offset = 0 + out = fill_gt(x, y, offset, dim1, dim2) + + self.inputs = {"X": x, "Y": y} + self.outputs = {'Out': out} + self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset} + + def init_kernel_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test): + def setUp(self): + self.op_type = "fill_diagonal_tensor" + self.init_kernel_type() + x = np.random.random((2, 20, 25)).astype(self.dtype) + y = np.random.random((2, 20)).astype(self.dtype) + dim1 = 2 + dim2 = 1 + offset = -3 + out = fill_gt(x, y, offset, dim1, dim2) + + self.inputs = {"X": x, "Y": y} + self.outputs = {'Out': out} + self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset} + + def init_kernel_type(self): + self.dtype = np.float32 + + +class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test): + def setUp(self): + self.op_type = "fill_diagonal_tensor" + self.init_kernel_type() + x = np.random.random((2, 20, 20, 3)).astype(self.dtype) + y = np.random.random((2, 3, 18)).astype(self.dtype) + dim1 = 1 + dim2 = 2 + offset = 2 + out = fill_gt(x, y, offset, dim1, dim2) + + self.inputs = {"X": x, "Y": y} + self.outputs = {'Out': out} + self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset} + + def init_kernel_type(self): + self.dtype = np.float16 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..47316809189b7cecc5ad66fa6154dcba1eec6acb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle.fluid as fluid +import paddle.nn.functional as F +import unittest +import numpy as np +import six +import paddle + + +class TensorFillDiagTensor_Test(unittest.TestCase): + def setUp(self): + self.typelist = ['float32', 'float64', 'int32', 'int64'] + self.places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dim2(self): + expected_np = np.array( + [[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2]]).astype('float32') + expected_grad = np.array( + [[0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.ones((3, ), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((4, 3), dtype=dtype) + x.stop_gradient = False + y = x * 2 + ny = y.fill_diagonal_tensor(v, offset=0, dim1=0, dim2=1) + loss = ny.sum() + loss.backward() + + self.assertEqual( + (ny.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_dim2_offset_1(self): + expected_np = np.array( + [[2, 2, 2], [1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32') + expected_grad = np.array( + [[1, 1, 1], [0, 1, 1], [1, 0, 1], [1, 1, 0]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.ones((3, ), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((4, 3), dtype=dtype) + x.stop_gradient = False + y = x * 2 + ny = y.fill_diagonal_tensor(v, offset=-1, dim1=0, dim2=1) + loss = ny.sum() + loss.backward() + + self.assertEqual( + (ny.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_dim2_offset1(self): + expected_np = np.array( + [[2, 1, 2], [2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32') + expected_grad = np.array( + [[1, 0, 1], [1, 1, 0], [1, 1, 1], [1, 1, 1]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.ones((2, ), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((4, 3), dtype=dtype) + x.stop_gradient = False + y = x * 2 + ny = y.fill_diagonal_tensor(v, offset=1, dim1=0, dim2=1) + loss = ny.sum() + loss.backward() + + self.assertEqual( + (ny.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_dim4(self): + expected_np = np.array( + [[[[0, 3], [2, 2], [2, 2]], [[2, 2], [1, 4], [2, 2]], + [[2, 2], [2, 2], [2, 5]], [[2, 2], [2, 2], [2, 2]]], + [[[6, 9], [2, 2], [2, 2]], [[2, 2], [7, 10], [2, 2]], + [[2, 2], [2, 2], [8, 11]], + [[2, 2], [2, 2], [2, 2]]]]).astype('float32') + expected_grad = np.array( + [[[[0, 0], [1, 1], [1, 1]], [[1, 1], [0, 0], [1, 1]], + [[1, 1], [1, 1], [0, 0]], [[1, 1], [1, 1], [1, 1]]], + [[[0, 0], [1, 1], [1, 1]], [[1, 1], [0, 0], [1, 1]], + [[1, 1], [1, 1], [0, 0]], + [[1, 1], [1, 1], [1, 1]]]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.to_tensor( + np.arange(12).reshape(2, 2, 3), dtype=dtype) + var = (np.random.random() + 1) + x = 
paddle.ones((2, 4, 3, 2), dtype=dtype) + x.stop_gradient = False + y = x * 2 + ny = y.fill_diagonal_tensor(v, offset=0, dim1=1, dim2=2) + loss = ny.sum() + loss.backward() + + self.assertEqual( + (ny.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_largedim(self): + if len(self.places) > 1: + bsdim = 1024 + fsdim = 128 + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.arange( + bsdim * fsdim, dtype=dtype).reshape((bsdim, fsdim)) + y = paddle.ones((bsdim, fsdim, fsdim), dtype=dtype) + y.stop_gradient = False + y = y * 2 + ny = y.fill_diagonal_tensor(v, offset=0, dim1=1, dim2=2) + loss = ny.sum() + loss.backward() + + expected_pred = v - 2 + expected_pred = F.diag_embed(expected_pred) + 2 + expected_grad = paddle.ones(v.shape, dtype=dtype) - 2 + expected_grad = F.diag_embed(expected_grad) + 1 + + self.assertEqual((ny == expected_pred).all(), True) + self.assertEqual((y.grad == expected_grad).all(), True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py new file mode 100644 index 0000000000000000000000000000000000000000..2f37ccf219eb08aa3e8ae1fd9853801a01a893be --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor_.py @@ -0,0 +1,173 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle.fluid as fluid +import paddle.nn.functional as F +import unittest +import numpy as np +import six +import paddle + + +class TensorFillDiagTensor_Test(unittest.TestCase): + def setUp(self): + self.typelist = ['float32', 'float64', 'int32', 'int64'] + self.places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dim2(self): + expected_np = np.array( + [[1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2]]).astype('float32') + expected_grad = np.array( + [[0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.ones((3, ), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((4, 3), dtype=dtype) + x.stop_gradient = False + y = x * 2 + y.fill_diagonal_tensor_(v, offset=0, dim1=0, dim2=1) + loss = y.sum() + loss.backward() + + self.assertEqual( + (y.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_dim2_offset_1(self): + expected_np = np.array( + [[2, 2, 2], [1, 2, 2], [2, 1, 2], [2, 2, 1]]).astype('float32') + expected_grad = np.array( + [[1, 1, 1], [0, 1, 1], [1, 0, 1], [1, 1, 0]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.ones((3, ), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((4, 3), dtype=dtype) + x.stop_gradient = False + y = x * 2 + y.fill_diagonal_tensor_(v, offset=-1, dim1=0, dim2=1) + loss = y.sum() + loss.backward() + + self.assertEqual( + (y.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_dim2_offset1(self): + expected_np = np.array( + [[2, 1, 2], [2, 2, 1], [2, 2, 2], [2, 2, 2]]).astype('float32') + expected_grad = np.array( + [[1, 0, 1], [1, 1, 0], [1, 1, 1], [1, 1, 1]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.ones((2, ), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((4, 3), dtype=dtype) + x.stop_gradient = False + y = x * 2 + y.fill_diagonal_tensor_(v, offset=1, dim1=0, dim2=1) + loss = y.sum() + loss.backward() + + self.assertEqual( + (y.numpy().astype('float32') == expected_np).all(), True) + self.assertEqual( + (y.grad.numpy().astype('float32') == expected_grad).all(), + True) + + def test_dim4(self): + expected_np = np.array( + [[[[0, 3], [2, 2], [2, 2]], [[2, 2], [1, 4], [2, 2]], + [[2, 2], [2, 2], [2, 5]], [[2, 2], [2, 2], [2, 2]]], + [[[6, 9], [2, 2], [2, 2]], [[2, 2], [7, 10], [2, 2]], + [[2, 2], [2, 2], [8, 11]], + [[2, 2], [2, 2], [2, 2]]]]).astype('float32') + expected_grad = np.array( + [[[[0, 0], [1, 1], [1, 1]], [[1, 1], [0, 0], [1, 1]], + [[1, 1], [1, 1], [0, 0]], [[1, 1], [1, 1], [1, 1]]], + [[[0, 0], [1, 1], [1, 1]], [[1, 1], [0, 0], [1, 1]], + [[1, 1], [1, 1], [0, 0]], + [[1, 1], [1, 1], [1, 1]]]]).astype('float32') + + for idx, p in enumerate(self.places): + if idx == 0: + paddle.set_device('cpu') + else: + paddle.set_device('gpu') + for dtype in self.typelist: + v = paddle.to_tensor( + np.arange(12).reshape(2, 2, 3), dtype=dtype) + var = (np.random.random() + 1) + x = paddle.ones((2, 4, 3, 
2), dtype=dtype)
+                x.stop_gradient = False
+                y = x * 2
+                y.fill_diagonal_tensor_(v, offset=0, dim1=1, dim2=2)
+                loss = y.sum()
+                loss.backward()
+
+                self.assertEqual(
+                    (y.numpy().astype('float32') == expected_np).all(), True)
+                self.assertEqual(
+                    (y.grad.numpy().astype('float32') == expected_grad).all(),
+                    True)
+
+    def test_largedim(self):
+        # The large-dim case is only tested on GPU because the CPU version is
+        # too slow for CI, and memory is limited.
+        if len(self.places) > 1:
+            bsdim = 1024
+            fsdim = 128
+            paddle.set_device('gpu')
+            for dtype in self.typelist:
+                v = paddle.arange(
+                    bsdim * fsdim, dtype=dtype).reshape((bsdim, fsdim))
+                y = paddle.ones((bsdim, fsdim, fsdim), dtype=dtype)
+                y.stop_gradient = False
+                y = y * 2
+                y.fill_diagonal_tensor_(v, offset=0, dim1=1, dim2=2)
+                loss = y.sum()
+                loss.backward()
+
+                expected_pred = v - 2
+                expected_pred = F.diag_embed(expected_pred) + 2
+                expected_grad = paddle.ones(v.shape, dtype=dtype) - 2
+                expected_grad = F.diag_embed(expected_grad) + 1
+
+                self.assertEqual((y == expected_pred).all(), True)
+                self.assertEqual((y.grad == expected_grad).all(), True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 7218254b34a5408f5c643a7ad56dafea2cf67784..5f36917996a5d5a91f6b21b05803c327db5e6c9e 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -86,6 +86,112 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
 setattr(core.VarBase, 'fill_diagonal_', fill_diagonal_)
 
 
+def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False):
+    inshape = x.shape
+    assert dim1 < len(inshape) and dim1 >= -len(inshape), (
+        'dim1 should be between [-rank, rank) in fill_diagonal_tensor_')
+    assert dim2 < len(inshape) and dim2 >= -len(inshape), (
+        'dim2 should be between [-rank, rank) in fill_diagonal_tensor_')
+    assert len(inshape) >= 2, (
+        'Tensor dims should be >= 2 in fill_diagonal_tensor_')
+    dim1 %= len(inshape)
+    dim2 %= len(inshape)
+
+    predshape = []
+    for i in range(len(inshape)):
+        if i != dim1 and i != dim2:
+            predshape.append(inshape[i])
+    diaglen = min(
+        min(inshape[dim1], inshape[dim1] + offset),
+        min(inshape[dim2], inshape[dim2] - offset))
+    predshape.append(diaglen)
+    assert tuple(predshape) == tuple(y.shape), (
+        "the shape of y should be {}".format(predshape))
+    if len(y.shape) == 1:
+        y = y.reshape([1, -1])
+
+    if inplace:
+        return core.ops.fill_diagonal_tensor_(x, y, 'dim1', dim1, 'dim2', dim2,
+                                              'offset', offset)
+    return core.ops.fill_diagonal_tensor(x, y, 'dim1', dim1, 'dim2', dim2,
+                                         'offset', offset)
+
+
+def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None):
+    """
+    **Notes**:
+        **This API is ONLY available in Dygraph mode**
+
+    This function fills the source Tensor ``y`` into the diagonal of Tensor ``x`` in place.
+
+    Args:
+        x(Tensor): ``x`` is the original Tensor
+        y(Tensor): ``y`` is the Tensor to be filled into x
+        dim1(int,optional): first dimension with respect to which to fill diagonal. Default: 0.
+        dim2(int,optional): second dimension with respect to which to fill diagonal. Default: 1.
+        offset(int,optional): the offset to the main diagonal. Default: 0 (main diagonal).
+        name(str,optional): Name for the operation (optional, default is None)
+
+    Returns:
+        Tensor: Tensor with diagonal filled with y, with the same dtype as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones((4, 3)) * 2
+            y = paddle.ones((3,))
+            x.fill_diagonal_tensor_(y)
+            print(x.tolist())   #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]]
+
+    """
+    return _fill_diagonal_tensor_impl(
+        x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=True)
+
+
+setattr(core.VarBase, 'fill_diagonal_tensor_', fill_diagonal_tensor_)
+
+
+def fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1, name=None):
+    """
+    This function fills the source Tensor ``y`` into the diagonal of Tensor ``x`` and returns the result as a new Tensor.
+
+    Args:
+        x(Tensor): ``x`` is the original Tensor
+        y(Tensor): ``y`` is the Tensor to be filled into x
+        dim1(int,optional): first dimension with respect to which to fill diagonal. Default: 0.
+        dim2(int,optional): second dimension with respect to which to fill diagonal. Default: 1.
+        offset(int,optional): the offset to the main diagonal. Default: 0 (main diagonal).
+        name(str,optional): Name for the operation (optional, default is None)
+
+    Returns:
+        Tensor: Tensor with diagonal filled with y, with the same dtype as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.ones((4, 3)) * 2
+            y = paddle.ones((3,))
+            nx = x.fill_diagonal_tensor(y)
+            print(nx.tolist())   #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]]
+
+    """
+    return _fill_diagonal_tensor_impl(
+        x, y, offset=offset, dim1=dim1, dim2=dim2, inplace=False)
+
+
+setattr(core.VarBase, 'fill_diagonal_tensor', fill_diagonal_tensor)
+
+
 @dygraph_only
 def tolist(x):
     """
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 4442ff538cb0179ed3e050cd9832bf10e535a0c0..3609cfd183bf304555277977d79ef3e73ce9e087 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -721,5 +721,6 @@ STATIC_MODE_TESTING_LIST = [
     'test_marker_op',
     'test_c_embedding_op',
     'test_class_center_sample_op',
+    'test_fill_diagonal_tensor_op',
     'test_margin_cross_entropy_op',
 ]
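
Note on the indexing used by both the CPU and CUDA kernels above: CalMatDims flattens every dimension other than dim1/dim2 into a batch dimension, records a per-batch starting offset in matoffset, scales the offset attribute by the stride of dim2 (or dim1 for negative offsets), and element j of each diagonal is then written at matoffset[i] + scaled_offset + j * (stride(dim1) + stride(dim2)). The NumPy lines below are an illustrative sketch of that formula for the simple 2-D, non-negative-offset case; they are not part of the patch, and the helper name fill_diagonal_ref is made up for illustration.

import numpy as np

def fill_diagonal_ref(x, y, offset=0):
    # 2-D, C-contiguous reference for the kernels' index formula:
    # diagonal element j lands at flat index
    # offset * col_stride + j * (row_stride + col_stride) = offset + j * (cols + 1).
    out = x.copy()
    rows, cols = out.shape
    diaglen = min(rows, cols - offset)   # same rule as CalMatDims' diagdim
    flat = out.reshape(-1)               # view of `out`, so writes land in place
    for j in range(diaglen):
        flat[offset + j * (cols + 1)] = y[j]
    return out

x = np.full((4, 3), 2.0)
print(fill_diagonal_ref(x, np.ones(3)))
# [[1. 2. 2.] [2. 1. 2.] [2. 2. 1.] [2. 2. 2.]] -- matches the fill_diagonal_tensor docstring example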