From 9f1616a0ec71f4f1578f363be3c50734d29affe8 Mon Sep 17 00:00:00 2001
From: zhiboniu <31800336+zhiboniu@users.noreply.github.com>
Date: Wed, 3 Aug 2022 19:05:24 +0800
Subject: [PATCH] Phi fill diagonal (#44453)

* phi_fill_diagonal

* remove old kernels

* update

* update attr args

* refix

* update
---
 paddle/fluid/operators/fill_diagonal_op.cc    | 143 +++--------------
 paddle/fluid/operators/fill_diagonal_op.cu    | 133 ----------------
 paddle/fluid/operators/fill_diagonal_op.h     |  25 ---
 paddle/phi/api/yaml/legacy_api.yaml           |  10 ++
 paddle/phi/api/yaml/legacy_backward.yaml      |   9 ++
 paddle/phi/infermeta/backward.cc              |  12 ++
 paddle/phi/infermeta/backward.h               |   3 +
 paddle/phi/infermeta/unary.cc                 |  11 ++
 paddle/phi/infermeta/unary.h                  |   3 +
 .../kernels/cpu/fill_diagonal_grad_kernel.cc  |  63 ++++++++
 .../phi/kernels/cpu/fill_diagonal_kernel.cc   |  67 ++++++++
 .../phi/kernels/fill_diagonal_grad_kernel.h   |  31 ++++
 paddle/phi/kernels/fill_diagonal_kernel.h     |  31 ++++
 .../kernels/gpu/fill_diagonal_grad_kernel.cu  |  88 +++++++++++
 .../phi/kernels/gpu/fill_diagonal_kernel.cu   |  90 +++++++++++
 .../kernels/impl/fill_diagonal_kernel_impl.h  |  32 ++++
 paddle/phi/ops/compat/fill_diagonal_sig.cc    |  37 +++++
 python/paddle/tensor/manipulation.py          |   6 +
 18 files changed, 511 insertions(+), 283 deletions(-)
 delete mode 100644 paddle/fluid/operators/fill_diagonal_op.cu
 delete mode 100644 paddle/fluid/operators/fill_diagonal_op.h
 create mode 100644 paddle/phi/kernels/cpu/fill_diagonal_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/fill_diagonal_kernel.cc
 create mode 100644 paddle/phi/kernels/fill_diagonal_grad_kernel.h
 create mode 100644 paddle/phi/kernels/fill_diagonal_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/fill_diagonal_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/fill_diagonal_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h
 create mode 100644 paddle/phi/ops/compat/fill_diagonal_sig.cc

diff --git a/paddle/fluid/operators/fill_diagonal_op.cc b/paddle/fluid/operators/fill_diagonal_op.cc
index 200fb34bf08..4bf9635ae45 100644
--- a/paddle/fluid/operators/fill_diagonal_op.cc
+++ b/paddle/fluid/operators/fill_diagonal_op.cc
@@ -12,22 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/fill_diagonal_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
 
-int64_t CalStride(framework::DDim dim) {
-  int rank = dim.size();
-  int64_t dimsum = 1;
-  int64_t strides = 0;
-  for (int i = rank - 1; i >= 0; i--) {
-    strides += dimsum;
-    dimsum *= dim[i];
-  }
-  return strides;
-}
-
 class FillIDiagonalOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -57,13 +49,6 @@ class FillIDiagonalOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *context) const override {
-    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillIDiagonal");
-    OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillIDiagonal");
-    auto x_dims = context->GetInputDim("X");
-    context->SetOutputDim("Out", x_dims);
-  }
-
 protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -82,61 +67,10 @@ class FillIDiagonalOpVarTypeInference : public framework::VarTypeInference {
   }
 };
 
-template <typename T>
-class FillIDiagonalKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
-    auto fill_val = ctx.template Attr<float>("value");
-    auto *out = ctx.Output<framework::Tensor>("Out");
-    auto offset = ctx.Attr<int>("offset");
-    auto wrap = ctx.Attr<bool>("wrap");
-
-    auto *xin = ctx.Input<framework::Tensor>("X");
-
-    T temp_var = static_cast<T>(fill_val);
-
-    T *out_data = out->mutable_data<T>(ctx.GetPlace());
-    framework::TensorCopy(*xin, ctx.GetPlace(), out);
-
-    auto out_dims = out->dims();
-    auto strides = CalStride(out_dims);
-    auto size = out->numel();
-
-    // The wrap mode supported only the dims equels to 2; In wrap mode, the
-    // value will be filled in cycles
-    if (!wrap) {
-      size = std::min(size, out_dims[1] * out_dims[1]);
-    }
-
-    for (int64_t i = 0; i < size; i += strides) {
-      // to check if the new position with offset is still in the same line;
-      // this modify should not affect across lines.
-      // out_dims[1] is also work for tensor with dim>2, for which the dims must
-      // be the same number
-      if (i % out_dims[1] + offset >= 0 &&
-          i % out_dims[1] + offset < out_dims[1]) {
-        out_data[i + offset] = temp_var;
-      }
-    }
-  }
-};
-
 class FillIDiagonalGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@GRAD",
-                   "mul");
-    auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     // Note: don't get data type from ctx.Input<framework::Tensor>("Input");
@@ -160,41 +94,6 @@ class FillIDiagonalGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
-template <typename T>
-class FillIDiagonalGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
-    auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-
-    auto offset = ctx.Attr<int>("offset");
-    auto wrap = ctx.Attr<bool>("wrap");
-
-    if (dx) {
-      auto *data = dx->mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopy(*dout, ctx.GetPlace(), dx);
-
-      auto dx_dims = dx->dims();
-      auto strides = CalStride(dx_dims);
-      auto size = dx->numel();
-      auto wrapsize = std::min(size, dx_dims[1] * dx_dims[1]);
-
-      // The wrap mode supported only the dims equels to 2; In wrap mode, the
-      // value will be filled in cycles
-      if (wrap) {
-        wrapsize = size;
-      }
-
-      for (int64_t i = 0; i < wrapsize; i += strides) {
-        if (i % dx_dims[1] + offset >= 0 &&
-            i % dx_dims[1] + offset < dx_dims[1]) {
-          data[i + offset] = T(0);
-        }
-      }
-    }
-  }
-};
-
 DECLARE_INPLACE_OP_INFERER(FillIDiagonalOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FillIDiagonalGradOpInplaceInferer,
                            {framework::GradVarName("Out"),
@@ -204,30 +103,24 @@ DECLARE_INPLACE_OP_INFERER(FillIDiagonalGradOpInplaceInferer,
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal,
+                            FillDiagonalShapeFunctor,
+                            PD_INFER_META(phi::FillDiagonalInferMeta));
+
+DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal_grad,
+                            FillDiagonalGradShapeFunctor,
+                            PD_INFER_META(phi::FillDiagonalGradInferMeta));
+
 REGISTER_OPERATOR(fill_diagonal,
                   ops::FillIDiagonalOp,
-                  ops::FillIDiagonalOpMaker,
-                  ops::FillIDiagonalOpVarTypeInference,
                   ops::FillIDiagonalGradOpMaker<paddle::framework::OpDesc>,
                   ops::FillIDiagonalGradOpMaker<paddle::imperative::OpBase>,
-                  ops::FillIDiagonalOpInplaceInferer);
+                  ops::FillIDiagonalOpMaker,
+                  ops::FillIDiagonalOpInplaceInferer,
+                  ops::FillIDiagonalOpVarTypeInference,
+                  FillDiagonalShapeFunctor);
 
 REGISTER_OPERATOR(fill_diagonal_grad,
                   ops::FillIDiagonalGradOp,
-                  ops::FillIDiagonalGradOpInplaceInferer);
-REGISTER_OP_CPU_KERNEL(fill_diagonal,
-                       ops::FillIDiagonalKernel<float>,
-                       ops::FillIDiagonalKernel<double>,
-                       ops::FillIDiagonalKernel<paddle::platform::float16>,
-                       ops::FillIDiagonalKernel<int>,
-                       ops::FillIDiagonalKernel<int64_t>,
-                       ops::FillIDiagonalKernel<bool>);
-
-REGISTER_OP_CPU_KERNEL(fill_diagonal_grad,
-                       ops::FillIDiagonalGradKernel<float>,
-                       ops::FillIDiagonalGradKernel<double>,
-                       ops::FillIDiagonalGradKernel<paddle::platform::float16>,
-                       ops::FillIDiagonalGradKernel<int>,
-                       ops::FillIDiagonalGradKernel<int64_t>,
-                       ops::FillIDiagonalGradKernel<bool>);
+                  ops::FillIDiagonalGradOpInplaceInferer,
+                  FillDiagonalGradShapeFunctor);
diff --git a/paddle/fluid/operators/fill_diagonal_op.cu b/paddle/fluid/operators/fill_diagonal_op.cu
deleted file mode 100644
index 105b207636c..00000000000
--- a/paddle/fluid/operators/fill_diagonal_op.cu
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/fill_diagonal_op.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-__global__ void fill_constant_kernel(const int64_t featuresize,
-                                     T* in_data,
-                                     int64_t strides,
-                                     int offset,
-                                     T fillvar,
-                                     int dims) {
-  for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
-       idx * strides + offset < (blockIdx.x + 1) * featuresize;
-       idx += blockDim.x) {
-    // to check if the new position with offset is still in the same line;
-    // this modify should not affect across lines.
-    // out_dims[1] is also work for tensor with dim>2, for which the dims must
-    // be the same number
-    if ((idx * strides) % dims + offset < dims &&
-        (idx * strides) % dims + offset >= 0) {
-      in_data[idx * strides + offset] = fillvar;
-    }
-  }
-}
-
-template <typename T>
-class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-#ifdef __HIPCC__
-    const int64_t kMaxBlockDim = 256;
-#else
-    const int64_t kMaxBlockDim = 512;
-#endif
-    auto* out = ctx.Output<Tensor>("Out");
-    auto offset = ctx.Attr<int>("offset");
-    auto wrap = ctx.Attr<bool>("wrap");
-
-    auto* xin = ctx.Input<Tensor>("X");
-    framework::TensorCopy(*xin, ctx.GetPlace(), out);
-
-    T* out_data = out->mutable_data<T>(ctx.GetPlace());
-    auto fill_val = static_cast<T>(ctx.template Attr<float>("value"));
-    T temp_var = static_cast<T>(fill_val);
-
-    auto size = out->numel();
-    auto out_dims = out->dims();
-    auto strides = CalStride(out_dims);
-
-    // The wrap mode supported only the dims equels to 2; In wrap mode, the
-    // value will be filled in cycles
-    if (!wrap) {
-      size = std::min(size, out_dims[1] * out_dims[1]);
-    }
-
-    int64_t kBlockDim = std::min(int64_t(size / strides), kMaxBlockDim);
-    fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
-        size, out_data, strides, offset, temp_var, out_dims[1]);
-  }
-};
-
-template <typename T>
-class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-#ifdef __HIPCC__
-    const int64_t kMaxBlockDim = 256;
-#else
-    const int64_t kMaxBlockDim = 512;
-#endif
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* in_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto offset = ctx.Attr<int>("offset");
-    auto wrap = ctx.Attr<bool>("wrap");
-
-    framework::TensorCopy(*dout, ctx.GetPlace(), dx);
-
-    auto size = dx->numel();
-    auto out_dims = dx->dims();
-    auto strides = CalStride(out_dims);
-
-    auto wrapsize = std::min(size, out_dims[1] * out_dims[1]);
-    // The wrap mode supported only the dims equels to 2; In wrap mode, the
-    // value will be filled in cycles
-    if (wrap) {
-      wrapsize = size;
-    }
-
-    int64_t kBlockDim = std::min(int64_t(size), kMaxBlockDim);
-    fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
-        wrapsize, in_data, strides, offset, T(0), out_dims[1]);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(fill_diagonal,
-                        ops::FillIDiagonalCUDAKernel<float>,
-                        ops::FillIDiagonalCUDAKernel<double>,
-                        ops::FillIDiagonalCUDAKernel<plat::float16>,
-                        ops::FillIDiagonalCUDAKernel<int>,
-                        ops::FillIDiagonalCUDAKernel<int64_t>,
-                        ops::FillIDiagonalCUDAKernel<bool>);
-
-REGISTER_OP_CUDA_KERNEL(fill_diagonal_grad,
-                        ops::FillIDiagonalGradCUDAKernel<float>,
-                        ops::FillIDiagonalGradCUDAKernel<double>,
-                        ops::FillIDiagonalGradCUDAKernel<plat::float16>,
-                        ops::FillIDiagonalGradCUDAKernel<int>,
-                        ops::FillIDiagonalGradCUDAKernel<int64_t>,
-                        ops::FillIDiagonalGradCUDAKernel<bool>);
diff --git a/paddle/fluid/operators/fill_diagonal_op.h b/paddle/fluid/operators/fill_diagonal_op.h
deleted file mode 100644
index 4531503e30d..00000000000
--- a/paddle/fluid/operators/fill_diagonal_op.h
+++ /dev/null
@@ -1,25 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-int64_t CalStride(framework::DDim dim);
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index 018f20a6087..b4ca7148a40 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -864,6 +864,16 @@
     data_type : dtype
     backend : place
 
+- api : fill_diagonal
+  args : (Tensor x, float value, int offset, bool wrap)
+  output : Tensor(out)
+  infer_meta :
+    func : FillDiagonalInferMeta
+  kernel :
+    func : fill_diagonal
+  inplace : (x -> out)
+  backward : fill_diagonal_grad
+
 - api : flatten
   args : (Tensor x, int start_axis, int stop_axis)
   output : Tensor(out), Tensor(xshape)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 000f88979f3..8b43f7643c7 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -811,6 +811,15 @@
   infer_meta :
     func : UnchangedInferMeta
   invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
+
+- backward_api : fill_diagonal_grad
+  forward : fill_diagonal (Tensor x, float value, int offset, bool wrap) -> Tensor(out)
+  args : (Tensor out_grad, float value, int offset, bool wrap)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : FillDiagonalGradInferMeta
+  kernel :
+    func : fill_diagonal_grad
   inplace : (out_grad -> x_grad)
 
 - backward_api : flatten_grad
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 5395b4e23dc..87640cdddbf 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -285,6 +285,18 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
   }
 }
 
+void FillDiagonalGradInferMeta(const MetaTensor& dout,
+                               float value,
+                               int offset,
+                               bool wrap,
+                               MetaTensor* dx) {
+  auto x_dims = dout.dims();
+  if (dx) {
+    dx->set_dims(x_dims);
+    dx->set_dtype(dout.dtype());
+  }
+}
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index a0e79cfaf04..1ada2c80157 100755
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -137,6 +137,9 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
                            bool is_test,
                            MetaTensor* x_grad);
 
+void FillDiagonalGradInferMeta(
+    const MetaTensor& dout, float value, int offset, bool wrap, MetaTensor* dx);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 8cc7f75533c..74705c3759d 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -855,6 +855,17 @@ void ExpandInferMeta(const MetaTensor& x,
   }
 }
 
+void FillDiagonalInferMeta(
+    const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out) {
+  PADDLE_ENFORCE_NE(
+      out,
+      nullptr,
+      phi::errors::InvalidArgument("Tensor out should not be null."));
+  auto x_dims = x.dims();
+  out->set_dims(x_dims);
+  out->set_dtype(x.dtype());
+}
+
 void FlattenInferMeta(const MetaTensor& x,
                       int start_axis,
                       int stop_axis,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index ae88ecb40c3..bd35855a431 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -132,6 +132,9 @@ void ExpandInferMeta(const MetaTensor& x,
                      const IntArray& shape,
                      MetaTensor* out);
 
+void FillDiagonalInferMeta(
+    const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out);
+
 void FlattenInferMeta(const MetaTensor& x,
                       int start_axis,
                       int stop_axis,
diff --git a/paddle/phi/kernels/cpu/fill_diagonal_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_diagonal_grad_kernel.cc
new file mode 100644
index 00000000000..1291a677bf9
--- /dev/null
+++ b/paddle/phi/kernels/cpu/fill_diagonal_grad_kernel.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/fill_diagonal_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FillDiagonalGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float value, + int offset, + bool wrap, + DenseTensor* x_grad) { + if (x_grad) { + T* data = ctx.template Alloc(x_grad); + phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + + auto dx_dims = x_grad->dims(); + auto strides = CalStride(dx_dims); + auto size = x_grad->numel(); + auto wrapsize = std::min(size, dx_dims[1] * dx_dims[1]); + + // The wrap mode supported only the dims equels to 2; In wrap mode, the + // value will be filled in cycles + if (wrap) { + wrapsize = size; + } + + for (int64_t i = 0; i < wrapsize; i += strides) { + if (i % dx_dims[1] + offset >= 0 && + i % dx_dims[1] + offset < dx_dims[1]) { + data[i + offset] = T(0); + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fill_diagonal_grad, + CPU, + ALL_LAYOUT, + phi::FillDiagonalGradKernel, + float, + double, + int64_t, + int, + phi::dtype::float16, + bool) {} diff --git a/paddle/phi/kernels/cpu/fill_diagonal_kernel.cc b/paddle/phi/kernels/cpu/fill_diagonal_kernel.cc new file mode 100644 index 00000000000..232f2444cf4 --- /dev/null +++ b/paddle/phi/kernels/cpu/fill_diagonal_kernel.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fill_diagonal_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { + +template +void FillDiagonalKernel(const Context& ctx, + const DenseTensor& x, + float value, + int offset, + bool wrap, + DenseTensor* out) { + T temp_var = static_cast(value); + + T* out_data = ctx.template Alloc(out); + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + + auto out_dims = out->dims(); + auto strides = CalStride(out_dims); + auto size = out->numel(); + + // The wrap mode supported only the dims equels to 2; In wrap mode, the + // value will be filled in cycles + if (!wrap) { + size = std::min(size, out_dims[1] * out_dims[1]); + } + + for (int64_t i = 0; i < size; i += strides) { + // to check if the new position with offset is still in the same line; + // this modify should not affect across lines. + // out_dims[1] is also work for tensor with dim>2, for which the dims must + // be the same number + if (i % out_dims[1] + offset >= 0 && + i % out_dims[1] + offset < out_dims[1]) { + out_data[i + offset] = temp_var; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fill_diagonal, + CPU, + ALL_LAYOUT, + phi::FillDiagonalKernel, + float, + double, + int64_t, + int, + phi::dtype::float16, + bool) {} diff --git a/paddle/phi/kernels/fill_diagonal_grad_kernel.h b/paddle/phi/kernels/fill_diagonal_grad_kernel.h new file mode 100644 index 00000000000..23f2ae577c2 --- /dev/null +++ b/paddle/phi/kernels/fill_diagonal_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +#include "paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h" + +namespace phi { + +template +void FillDiagonalGradKernel(const Context& ctx, + const DenseTensor& out_grad, + float value, + int offset, + bool wrap, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/fill_diagonal_kernel.h b/paddle/phi/kernels/fill_diagonal_kernel.h new file mode 100644 index 00000000000..ecd3ffbe5cc --- /dev/null +++ b/paddle/phi/kernels/fill_diagonal_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +#include "paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h" + +namespace phi { + +template +void FillDiagonalKernel(const Context& ctx, + const DenseTensor& x, + float value, + int offset, + bool wrap, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/fill_diagonal_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_diagonal_grad_kernel.cu new file mode 100644 index 00000000000..8884dfae178 --- /dev/null +++ b/paddle/phi/kernels/gpu/fill_diagonal_grad_kernel.cu @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fill_diagonal_grad_kernel.h" + +#include +#include + +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void fill_constant_kernel(const int64_t featuresize, + T* in_data, + int64_t strides, + int offset, + T fillvar, + int dims) { + for (int64_t idx = blockIdx.x * featuresize + threadIdx.x; + idx * strides + offset < (blockIdx.x + 1) * featuresize; + idx += blockDim.x) { + // to check if the new position with offset is still in the same line; + // this modify should not affect across lines. 
+    // out_dims[1] also works for tensors with dim > 2, whose dims must
+    // all be the same
+    if ((idx * strides) % dims + offset < dims &&
+        (idx * strides) % dims + offset >= 0) {
+      in_data[idx * strides + offset] = fillvar;
+    }
+  }
+}
+
+template <typename T, typename Context>
+void FillDiagonalGradKernel(const Context& ctx,
+                            const DenseTensor& out_grad,
+                            float value,
+                            int offset,
+                            bool wrap,
+                            DenseTensor* x_grad) {
+#ifdef __HIPCC__
+  const int64_t kMaxBlockDim = 256;
+#else
+  const int64_t kMaxBlockDim = 512;
+#endif
+  auto* in_data = ctx.template Alloc<T>(x_grad);
+
+  phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
+
+  auto size = x_grad->numel();
+  auto out_dims = x_grad->dims();
+  auto strides = CalStride(out_dims);
+
+  auto wrapsize = std::min(size, out_dims[1] * out_dims[1]);
+  // Wrap mode is supported only when the tensor has exactly 2 dims; in
+  // wrap mode the value is filled in cycles
+  if (wrap) {
+    wrapsize = size;
+  }
+
+  int64_t kBlockDim = std::min(int64_t(size), kMaxBlockDim);
+  fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
+      wrapsize, in_data, strides, offset, T(0), out_dims[1]);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fill_diagonal_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FillDiagonalGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   phi::dtype::float16,
+                   bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_diagonal_kernel.cu b/paddle/phi/kernels/gpu/fill_diagonal_kernel.cu
new file mode 100644
index 00000000000..3116842002a
--- /dev/null
+++ b/paddle/phi/kernels/gpu/fill_diagonal_kernel.cu
@@ -0,0 +1,90 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fill_diagonal_kernel.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void fill_constant_kernel(const int64_t featuresize,
+                                     T* in_data,
+                                     int64_t strides,
+                                     int offset,
+                                     T fillvar,
+                                     int dims) {
+  for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
+       idx * strides + offset < (blockIdx.x + 1) * featuresize;
+       idx += blockDim.x) {
+    // check that the position shifted by offset is still on the same line;
+    // the write must not cross into another line.
+    // out_dims[1] also works for tensors with dim > 2, whose dims must
+    // all be the same
+    if ((idx * strides) % dims + offset < dims &&
+        (idx * strides) % dims + offset >= 0) {
+      in_data[idx * strides + offset] = fillvar;
+    }
+  }
+}
+
+template <typename T, typename Context>
+void FillDiagonalKernel(const Context& ctx,
+                        const DenseTensor& x,
+                        float value,
+                        int offset,
+                        bool wrap,
+                        DenseTensor* out) {
+#ifdef __HIPCC__
+  const int64_t kMaxBlockDim = 256;
+#else
+  const int64_t kMaxBlockDim = 512;
+#endif
+  phi::Copy(ctx, x, ctx.GetPlace(), false, out);
+
+  T* out_data = ctx.template Alloc<T>(out);
+  auto fill_val = static_cast<T>(value);
+  T temp_var = static_cast<T>(fill_val);
+
+  auto size = out->numel();
+  auto out_dims = out->dims();
+  auto strides = CalStride(out_dims);
+
+  // Wrap mode is supported only when the tensor has exactly 2 dims; in
+  // wrap mode the value is filled in cycles
+  if (!wrap) {
+    size = std::min(size, out_dims[1] * out_dims[1]);
+  }
+
+  int64_t kBlockDim = std::min(int64_t(size / strides), kMaxBlockDim);
+  fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
+      size, out_data, strides, offset, temp_var, out_dims[1]);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fill_diagonal,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FillDiagonalKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   phi::dtype::float16,
+                   bool) {}
diff --git a/paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h b/paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h
new file mode 100644
index 00000000000..65383176a0f
--- /dev/null
+++ b/paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+inline int64_t CalStride(phi::DDim dim) {
+  int rank = dim.size();
+  int64_t dimsum = 1;
+  int64_t strides = 0;
+  for (int i = rank - 1; i >= 0; i--) {
+    strides += dimsum;
+    dimsum *= dim[i];
+  }
+  return strides;
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/fill_diagonal_sig.cc b/paddle/phi/ops/compat/fill_diagonal_sig.cc
new file mode 100644
index 00000000000..81a0faf6458
--- /dev/null
+++ b/paddle/phi/ops/compat/fill_diagonal_sig.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FillDiagonalOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "fill_diagonal", {"X"}, {"value", "offset", "wrap"}, {"Out"}); +} + +KernelSignature FillDiagonalGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fill_diagonal_grad", + {"Out@GRAD"}, + {"value", "offset", "wrap"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fill_diagonal, phi::FillDiagonalOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_grad, + phi::FillDiagonalGradOpArgumentMapping); diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 9da7f76e702..ea7ef4ff9d7 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -835,6 +835,7 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None): x.fill_diagonal_(1.0) print(x.tolist()) #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]] """ + helper = LayerHelper("fill_diagonal_", **locals()) check_type(x, 'X', (Variable), 'fill_diagonal_') dtype = helper.input_dtype('x') @@ -851,6 +852,11 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None): assert len(inshapeset) == 1, ( 'Tensor dims should be equal while input dims > 2 in fill_diagonal_ API' ) + if in_dygraph_mode(): + if len(inshape) == 2: + return _C_ops.final_state_fill_diagonal_(x, value, offset, wrap) + return _C_ops.final_state_fill_diagonal_(x, value, offset, True) + if len(inshape) == 2: return _C_ops.fill_diagonal_(x, 'value', value, 'offset', offset, 'wrap', wrap) -- GitLab