diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 6b544f785bbbec4fcfa1468e2109d436cfa196b4..578827f56cbc031c50ba59c68f6bc89fe957ab4b 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -50,12 +50,11 @@ elseif(WITH_MLU) elseif(WITH_ASCEND_CL) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_npu.cc) - detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu - prior_box_op_npu.cc) + detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_npu.cc) else() detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu) - detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) + detection_library(prior_box_op SRCS prior_box_op.cc) # detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) endif() diff --git a/paddle/fluid/operators/detection/prior_box_op.cu b/paddle/fluid/operators/detection/prior_box_op.cu deleted file mode 100644 index 1cdf7691338294a2787f4aae11bc6bf2178f4e20..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detection/prior_box_op.cu +++ /dev/null @@ -1,195 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/detection/prior_box_op.h" - -namespace paddle { -namespace operators { - -template -__device__ inline T clip(T in) { - return min(max(in, 0.), 1.); -} - -template -__global__ void GenPriorBox(T* out, - const T* aspect_ratios, - const int height, - const int width, - const int im_height, - const int im_width, - const int as_num, - const T offset, - const T step_width, - const T step_height, - const T* min_sizes, - const T* max_sizes, - const int min_num, - bool is_clip, - bool min_max_aspect_ratios_order) { - int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num; - int box_num = height * width * num_priors; - CUDA_KERNEL_LOOP(i, box_num) { - int h = i / (num_priors * width); - int w = (i / num_priors) % width; - int p = i % num_priors; - int m = max_sizes ? 
p / (as_num + 1) : p / as_num; - T cx = (w + offset) * step_width; - T cy = (h + offset) * step_height; - T bw, bh; - T min_size = min_sizes[m]; - if (max_sizes) { - int s = p % (as_num + 1); - if (!min_max_aspect_ratios_order) { - if (s < as_num) { - T ar = aspect_ratios[s]; - bw = min_size * sqrt(ar) / 2.; - bh = min_size / sqrt(ar) / 2.; - } else { - T max_size = max_sizes[m]; - bw = sqrt(min_size * max_size) / 2.; - bh = bw; - } - } else { - if (s == 0) { - bw = bh = min_size / 2.; - } else if (s == 1) { - T max_size = max_sizes[m]; - bw = sqrt(min_size * max_size) / 2.; - bh = bw; - } else { - T ar = aspect_ratios[s - 1]; - bw = min_size * sqrt(ar) / 2.; - bh = min_size / sqrt(ar) / 2.; - } - } - } else { - int s = p % as_num; - T ar = aspect_ratios[s]; - bw = min_size * sqrt(ar) / 2.; - bh = min_size / sqrt(ar) / 2.; - } - T xmin = (cx - bw) / im_width; - T ymin = (cy - bh) / im_height; - T xmax = (cx + bw) / im_width; - T ymax = (cy + bh) / im_height; - out[i * 4] = is_clip ? clip(xmin) : xmin; - out[i * 4 + 1] = is_clip ? clip(ymin) : ymin; - out[i * 4 + 2] = is_clip ? clip(xmax) : xmax; - out[i * 4 + 3] = is_clip ? clip(ymax) : ymax; - } -} - -template -__global__ void SetVariance(T* out, - const T* var, - const int vnum, - const int num) { - CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; } -} - -template -class PriorBoxOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("Input"); - auto* image = ctx.Input("Image"); - auto* boxes = ctx.Output("Boxes"); - auto* vars = ctx.Output("Variances"); - - auto min_sizes = ctx.Attr>("min_sizes"); - auto max_sizes = ctx.Attr>("max_sizes"); - auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); - auto variances = ctx.Attr>("variances"); - auto flip = ctx.Attr("flip"); - auto clip = ctx.Attr("clip"); - auto min_max_aspect_ratios_order = - ctx.Attr("min_max_aspect_ratios_order"); - - std::vector aspect_ratios; - ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); - - T step_w = static_cast(ctx.Attr("step_w")); - T step_h = static_cast(ctx.Attr("step_h")); - T offset = static_cast(ctx.Attr("offset")); - - auto im_width = image->dims()[3]; - auto im_height = image->dims()[2]; - - auto width = input->dims()[3]; - auto height = input->dims()[2]; - - T step_width, step_height; - if (step_w == 0 || step_h == 0) { - step_width = static_cast(im_width) / width; - step_height = static_cast(im_height) / height; - } else { - step_width = step_w; - step_height = step_h; - } - - int num_priors = aspect_ratios.size() * min_sizes.size(); - if (max_sizes.size() > 0) { - num_priors += max_sizes.size(); - } - int min_num = static_cast(min_sizes.size()); - int box_num = width * height * num_priors; - - int block = 512; - int grid = (box_num + block - 1) / block; - - auto stream = ctx.template device_context().stream(); - - boxes->mutable_data(ctx.GetPlace()); - vars->mutable_data(ctx.GetPlace()); - - framework::Tensor r; - framework::TensorFromVector(aspect_ratios, ctx.device_context(), &r); - - framework::Tensor min; - framework::TensorFromVector(min_sizes, ctx.device_context(), &min); - - T* max_data = nullptr; - framework::Tensor max; - if (max_sizes.size() > 0) { - framework::TensorFromVector(max_sizes, ctx.device_context(), &max); - max_data = max.data(); - } - - GenPriorBox<<>>(boxes->data(), - r.data(), - height, - width, - im_height, - im_width, - aspect_ratios.size(), - offset, - step_width, - step_height, - min.data(), - max_data, - min_num, 
- clip, - min_max_aspect_ratios_order); - - framework::Tensor v; - framework::TensorFromVector(variances, ctx.device_context(), &v); - grid = (box_num * 4 + block - 1) / block; - SetVariance<<>>( - vars->data(), v.data(), variances.size(), box_num * 4); - } -}; // namespace operators - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fill_diagonal_tensor_op.cc b/paddle/fluid/operators/fill_diagonal_tensor_op.cc index d2e248cffd44c91d667425886b8147f1b9867bc4..ccf9b7aa359389e65ed0c58c9e8ceec072ab1ad5 100644 --- a/paddle/fluid/operators/fill_diagonal_tensor_op.cc +++ b/paddle/fluid/operators/fill_diagonal_tensor_op.cc @@ -12,64 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fill_diagonal_tensor_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { -// calculate the offset\new_dims\(strides of dim1/dim2)\matoffset -void CalMatDims(framework::DDim out_dims, - int dim1, - int dim2, - int64_t *offset, - int64_t *new_dims, - int64_t *strides, - int64_t *matoffset) { - int64_t dimprod = 1, batchdim = 1; - int rank = out_dims.size(); - int matoffidx = 0; - for (int i = rank - 1; i >= 0; i--) { - if (i == dim2) { - strides[0] = dimprod; - } else if (i == dim1) { - strides[1] = dimprod; - } else { - batchdim *= out_dims[i]; - // matoffset calculate the offset position of the diagonal defined by dim1 - // and dim2 - // the first circle calculate the final free dimension - // and then calculate the front free dim one by one - if (matoffidx == 0) { - for (int64_t j = 0; j < out_dims[i]; j++) { - matoffset[matoffidx] = dimprod * j; - matoffidx++; - } - } else { - auto size = matoffidx; - for (int64_t j = 1; j < out_dims[i]; j++) { - for (int64_t k = 0; k < size; k++) { - matoffset[matoffidx] = matoffset[k] + dimprod * j; - matoffidx++; - } - } - } - } - dimprod *= out_dims[i]; - } - - auto diagdim = dim1; - if (*offset >= 0) { - diagdim = std::min(out_dims[dim1], out_dims[dim2] - *offset); - *offset *= strides[0]; - } else { - diagdim = std::min(out_dims[dim1] + *offset, out_dims[dim2]); - *offset *= -strides[1]; - } - new_dims[0] = batchdim; - new_dims[1] = diagdim; - return; -} - class FillDiagonalTensorOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -97,14 +47,6 @@ class FillDiagonalTensorOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillDiagonalTensor"); - OP_INOUT_CHECK( - context->HasOutput("Out"), "Output", "Out", "FillDiagonalTensor"); - auto x_dims = context->GetInputDim("X"); - context->SetOutputDim("Out", x_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -124,77 +66,10 @@ class FillDiagonalTensorOpVarTypeInference } }; -template -class FillDiagonalTensorKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext &ctx) const override { - auto *out = ctx.Output("Out"); - auto *srctensor = ctx.Input("Y"); - auto dim1 = ctx.Attr("dim1"); - auto dim2 = 
ctx.Attr("dim2"); - auto offset = ctx.Attr("offset"); - auto *xin = ctx.Input("X"); - - T *out_data = out->mutable_data(ctx.GetPlace()); - const T *fill_data = srctensor->data(); - - framework::TensorCopy(*xin, ctx.GetPlace(), out); - auto out_dims = out->dims(); - auto matdims = srctensor->dims(); - auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1); - - int64_t new_dims[2], strides[2]; - std::vector matdim; - matdim.resize(fill_dims[0]); - CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim.data()); - PADDLE_ENFORCE_EQ( - new_dims[0], - fill_dims[0], - platform::errors::InvalidArgument("The dims should be %d x %d, but get " - "%d x %d in fill tensor Y", - new_dims[0], - new_dims[1], - fill_dims[0], - fill_dims[1])); - PADDLE_ENFORCE_EQ( - new_dims[1], - fill_dims[1], - platform::errors::InvalidArgument("The dims should be %d x %d, but get " - "%d x %d in fill tensor Y", - new_dims[0], - new_dims[1], - fill_dims[0], - fill_dims[1])); - - auto size = out->numel(); - for (int64_t i = 0; i < fill_dims[0]; i += 1) { - auto sumoff = matdim[i] + offset; - for (int64_t j = 0; j < fill_dims[1]; j += 1) { - auto fill_index = j * (strides[1] + strides[0]) + sumoff; - if (fill_index < size) { - out_data[fill_index] = fill_data[i * fill_dims[1] + j]; - } - } - } - } -}; - class FillDiagonalTensorGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "mul"); - auto x_dims = ctx->GetInputDim(framework::GradVarName("Out")); - auto x_grad_name = framework::GradVarName("X"); - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { // Note: don't get data type from ctx.Input("Input"); @@ -219,50 +94,6 @@ class FillDiagonalTensorGradOpMaker : public framework::SingleGradOpMaker { } }; -template -class FillDiagonalTensorGradKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext &ctx) const override { - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *dout = ctx.Input(framework::GradVarName("Out")); - - auto dim1 = ctx.Attr("dim1"); - auto dim2 = ctx.Attr("dim2"); - auto offset = ctx.Attr("offset"); - auto matrows = 1; - - if (dx) { - auto *data = dx->mutable_data(ctx.GetPlace()); - - auto dx_dims = dx->dims(); - for (int i = 0; i < dx_dims.size(); i++) { - if (i != dim1 && i != dim2) { - matrows *= dx_dims[i]; - } - } - - int64_t new_dims[2], strides[2]; - std::vector matdim; - matdim.resize(matrows); - CalMatDims( - dx_dims, dim1, dim2, &offset, new_dims, strides, matdim.data()); - - auto size = dx->numel(); - framework::TensorCopy(*dout, ctx.GetPlace(), dx); - - for (int64_t i = 0; i < new_dims[0]; i += 1) { - auto sumoff = matdim[i] + offset; - for (int64_t j = 0; j < new_dims[1]; j += 1) { - auto fill_index = j * (strides[1] + strides[0]) + sumoff; - if (fill_index < size) { - data[fill_index] = 0; - } - } - } - } - } -}; - DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorOpInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorGradOpInplaceInferer, {framework::GradVarName("Out"), @@ -272,41 +103,25 @@ DECLARE_INPLACE_OP_INFERER(FillDiagonalTensorGradOpInplaceInferer, } // namespace paddle namespace ops = paddle::operators; 
+DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal_tensor, + FillDiagonalTensorInferShapeFunctor, + PD_INFER_META(phi::FillDiagonalTensorInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR( + fill_diagonal_tensor_grad, + FillDiagonalTensorGradInferShapeFunctor, + PD_INFER_META(phi::FillDiagonalTensorGradInferMeta)); + REGISTER_OPERATOR( fill_diagonal_tensor, ops::FillDiagonalTensorOp, - ops::FillDiagonalTensorOpMaker, - ops::FillDiagonalTensorOpVarTypeInference, ops::FillDiagonalTensorGradOpMaker, ops::FillDiagonalTensorGradOpMaker, - ops::FillDiagonalTensorOpInplaceInferer); + ops::FillDiagonalTensorOpMaker, + ops::FillDiagonalTensorOpInplaceInferer, + ops::FillDiagonalTensorOpVarTypeInference, + FillDiagonalTensorInferShapeFunctor); REGISTER_OPERATOR(fill_diagonal_tensor_grad, ops::FillDiagonalTensorGradOp, - ops::FillDiagonalTensorGradOpInplaceInferer); - -REGISTER_OP_CPU_KERNEL( - fill_diagonal_tensor, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel, - ops::FillDiagonalTensorKernel>, - ops::FillDiagonalTensorKernel>, - ops::FillDiagonalTensorKernel); - -REGISTER_OP_CPU_KERNEL( - fill_diagonal_tensor_grad, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel, - ops::FillDiagonalTensorGradKernel>, - ops::FillDiagonalTensorGradKernel>, - ops::FillDiagonalTensorGradKernel); + ops::FillDiagonalTensorGradOpInplaceInferer, + FillDiagonalTensorGradInferShapeFunctor); diff --git a/paddle/fluid/operators/fill_diagonal_tensor_op.cu b/paddle/fluid/operators/fill_diagonal_tensor_op.cu deleted file mode 100644 index 1b6ab71386b3b1cd5e5db4043f25ef78a13d07a4..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/fill_diagonal_tensor_op.cu +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/fill_diagonal_tensor_op.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -__global__ void fill_diagonal_tensor_kernel(int64_t size, - T *out_data, - const T *fill_data, - int64_t *strides, - int64_t *matdim, - int64_t offset, - int64_t fill_dims0, - int64_t fill_dims1) { - int64_t i = blockIdx.x; - auto sumoff = matdim[i] + offset; - for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) { - auto fill_index = j * (strides[1] + strides[0]) + sumoff; - if (fill_index < size) { - out_data[fill_index] = fill_data[i * fill_dims1 + j]; - } - } -} - -template -__global__ void fill_grad_kernel(int64_t size, - T *out_data, - int64_t *strides, - int64_t *matdim, - int64_t offset, - int64_t fill_dims0, - int64_t fill_dims1) { - int64_t i = blockIdx.x; - auto sumoff = matdim[i] + offset; - for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) { - auto fill_index = j * (strides[1] + strides[0]) + sumoff; - if (fill_index < size) { - out_data[fill_index] = T(0); - } - } -} - -template -class FillDiagonalTensorCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#ifdef __HIPCC__ - const int64_t kMaxBlockDim = 256; -#else - const int64_t kMaxBlockDim = 512; -#endif - auto *out = ctx.Output("Out"); - auto *srctensor = ctx.Input("Y"); - auto dim1 = ctx.Attr("dim1"); - auto dim2 = ctx.Attr("dim2"); - auto offset = ctx.Attr("offset"); - - auto *xin = ctx.Input("X"); - framework::TensorCopy(*xin, ctx.GetPlace(), out); - - T *out_data = out->mutable_data(ctx.GetPlace()); - const T *fill_data = srctensor->data(); - - auto out_dims = out->dims(); - auto matdims = srctensor->dims(); - auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1); - - int64_t new_dims[2]; - std::vector memory_block; - memory_block.resize(2 + fill_dims[0]); - int64_t *strides = &(memory_block[0]); - int64_t *matdim = &(memory_block[2]); - CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim); - PADDLE_ENFORCE_EQ( - new_dims[0], - fill_dims[0], - platform::errors::InvalidArgument("The dims should be %d x %d, but get " - "%d x %d in fill tensor Y", - new_dims[0], - new_dims[1], - fill_dims[0], - fill_dims[1])); - PADDLE_ENFORCE_EQ( - new_dims[1], - fill_dims[1], - platform::errors::InvalidArgument("The dims should be %d x %d, but get " - "%d x %d in fill tensor Y", - new_dims[0], - new_dims[1], - fill_dims[0], - fill_dims[1])); - - auto size = out->numel(); - - auto &dev_ctx = ctx.template device_context(); - auto stream = dev_ctx.stream(); - Tensor tensor_tmp; - int64_t *memory_block_cu = - tensor_tmp.mutable_data({2 + fill_dims[0]}, ctx.GetPlace()); - const auto gpu_place = ctx.GetPlace(); - memory::Copy(gpu_place, - memory_block_cu, - platform::CPUPlace(), - memory_block.data(), - sizeof(int64_t) * (2 + fill_dims[0]), - stream); - - int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2]; - - auto kGridDim = new_dims[0]; - auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim); - fill_diagonal_tensor_kernel - <<>>(size, - out_data, - fill_data, - strides_cu, - matdim_cu, - offset, - fill_dims[0], - fill_dims[1]); - } -}; - -template -class FillDiagonalTensorGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#ifdef __HIPCC__ - const int64_t kMaxBlockDim = 256; -#else - const int64_t kMaxBlockDim = 512; -#endif - auto *dx = 
ctx.Output(framework::GradVarName("X")); - auto *dout = ctx.Input(framework::GradVarName("Out")); - - auto dim1 = ctx.Attr("dim1"); - auto dim2 = ctx.Attr("dim2"); - auto offset = ctx.Attr("offset"); - auto matrows = 1; - - if (dx) { - auto *data = dx->mutable_data(ctx.GetPlace()); - auto dx_dims = dx->dims(); - framework::TensorCopy(*dout, ctx.GetPlace(), dx); - - for (int i = 0; i < dx_dims.size(); i++) { - if (i != dim1 && i != dim2) { - matrows *= dx_dims[i]; - } - } - - int64_t new_dims[2]; - std::vector memory_block; - memory_block.resize(2 + matrows); - int64_t *strides = &memory_block[0]; - int64_t *matdim = &memory_block[2]; - CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim); - - auto size = dx->numel(); - - auto &dev_ctx = ctx.template device_context(); - auto stream = dev_ctx.stream(); - Tensor tensor_tmp; - int64_t *memory_block_cu = - tensor_tmp.mutable_data({2 + matrows}, ctx.GetPlace()); - const auto gpu_place = ctx.GetPlace(); - memory::Copy(gpu_place, - memory_block_cu, - platform::CPUPlace(), - memory_block.data(), - sizeof(int64_t) * (2 + matrows), - stream); - - int64_t *strides_cu = &memory_block_cu[0], - *matdim_cu = &memory_block_cu[2]; - - auto kGridDim = new_dims[0]; - auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim); - fill_grad_kernel<<>>( - size, data, strides_cu, matdim_cu, offset, new_dims[0], new_dims[1]); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - fill_diagonal_tensor, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel, - ops::FillDiagonalTensorCUDAKernel>, - ops::FillDiagonalTensorCUDAKernel>, - ops::FillDiagonalTensorCUDAKernel); - -REGISTER_OP_CUDA_KERNEL( - fill_diagonal_tensor_grad, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel, - ops::FillDiagonalTensorGradCUDAKernel>, - ops::FillDiagonalTensorGradCUDAKernel>, - ops::FillDiagonalTensorGradCUDAKernel); diff --git a/paddle/fluid/operators/fill_diagonal_tensor_op.h b/paddle/fluid/operators/fill_diagonal_tensor_op.h deleted file mode 100644 index f3e41a9c9332ca7df8a7c871914cafe69edab4cb..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/fill_diagonal_tensor_op.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -void CalMatDims(framework::DDim out_dims, - int dim1, - int dim2, - int64_t *offset, - int64_t *new_dims, - int64_t *strides, - int64_t *matoffset); - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 979c944a7306c6b65ce38914cdb649a882495e6a..b1ae62603d6a04ac63ea99af0937371657ca8df0 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -874,6 +874,16 @@ inplace : (x -> out) backward : fill_diagonal_grad +- api : fill_diagonal_tensor + args : (Tensor x, Tensor y, int64_t offset, int dim1, int dim2) + output : Tensor(out) + infer_meta : + func : FillDiagonalTensorInferMeta + kernel : + func : fill_diagonal_tensor + inplace : (x -> out) + backward : fill_diagonal_tensor_grad + - api : flatten args : (Tensor x, int start_axis, int stop_axis) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 892e824e60a13d3d4cf2d6463774137d4acde53f..2a5d8bff70cd9427baf2984ce414af38942c3d2b 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -820,6 +820,15 @@ func : FillDiagonalGradInferMeta kernel : func : fill_diagonal_grad + +- backward_api : fill_diagonal_tensor_grad + forward : fill_diagonal_tensor (Tensor x, Tensor y, int64_t offset, int dim1, int dim2) -> Tensor(out) + args : (Tensor out_grad, int64_t offset, int dim1, int dim2) + output : Tensor(x_grad) + infer_meta : + func : FillDiagonalTensorGradInferMeta + kernel : + func : fill_diagonal_tensor_grad inplace : (out_grad -> x_grad) - backward_api : flatten_grad diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 87640cdddbfc9353c7cbd5f6872b687c75dd0ebe..e3898adf56c55803ab41d1e7b4853e5e644ac50c 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -297,6 +297,17 @@ void FillDiagonalGradInferMeta(const MetaTensor& dout, } } +void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad, + int64_t offset, + int dim1, + int dim2, + MetaTensor* x_grad) { + if (x_grad != nullptr) { + x_grad->set_dims(out_grad.dims()); + x_grad->set_dtype(out_grad.dtype()); + } +} + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 1ada2c80157942c0a15df9d388322f39a49e590b..15ab16eff1c19d55619a05ddafa23865804784ba 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -140,6 +140,12 @@ void EigvalshGradInferMeta(const MetaTensor& out_v, void FillDiagonalGradInferMeta( const MetaTensor& dout, float value, int offset, bool wrap, MetaTensor* dx); +void FillDiagonalTensorGradInferMeta(const MetaTensor& out_grad, + int64_t offset, + int dim1, + int dim2, + MetaTensor* x_grad); + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 48b76f96c019fff9403be9271e6af9a00e3454e3..5211e7e10e671c79f77fa53ee249ef7c2e794af1 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1174,6 +1174,19 @@ void ExpandAsInferMeta(const MetaTensor& x, #undef MAX_RANK_SUPPORTED } +void FillDiagonalTensorInferMeta(const 
MetaTensor& x, + const MetaTensor& y, + int64_t offset, + int dim1, + int dim2, + MetaTensor* out) { + PADDLE_ENFORCE_NOT_NULL(out, + phi::errors::InvalidArgument( + "Output Tensor (out) should not be nullptr.")); + out->set_dims(x.dims()); + out->set_dtype(x.dtype()); +} + void GatherInferMeta(const MetaTensor& x, const MetaTensor& index, const Scalar& axis, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 68c020cc68a61499f6a6c659c0e1bf9d519f5f0a..ab4d5b03cae20cd9fcb9dbcbf11f8987deec2e9b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -182,6 +182,13 @@ void ExpandAsInferMeta(const MetaTensor& x, const std::vector& target_shape, MetaTensor* out); +void FillDiagonalTensorInferMeta(const MetaTensor& x, + const MetaTensor& y, + int64_t offset, + int dim1, + int dim2, + MetaTensor* out); + void GatherInferMeta(const MetaTensor& x, const MetaTensor& index, const Scalar& axis, diff --git a/paddle/phi/kernels/cpu/fill_diagonal_tensor_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_diagonal_tensor_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..318e2016097fe828954d3fca282d77af7b45085c --- /dev/null +++ b/paddle/phi/kernels/cpu/fill_diagonal_tensor_grad_kernel.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void FillDiagonalTensorGradKernel(const Context& ctx, + const DenseTensor& out_grad, + int64_t offset, + int dim1, + int dim2, + DenseTensor* x_grad) { + auto matrows = 1; + + if (x_grad) { + auto* data = ctx.template Alloc(x_grad); + + auto dx_dims = x_grad->dims(); + for (int i = 0; i < dx_dims.size(); i++) { + if (i != dim1 && i != dim2) { + matrows *= dx_dims[i]; + } + } + + int64_t new_dims[2], strides[2]; + std::vector matdim; + matdim.resize(matrows); + CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim.data()); + + auto size = x_grad->numel(); + phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + + for (int64_t i = 0; i < new_dims[0]; i += 1) { + auto sumoff = matdim[i] + offset; + for (int64_t j = 0; j < new_dims[1]; j += 1) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + data[fill_index] = 0; + } + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fill_diagonal_tensor_grad, + CPU, + ALL_LAYOUT, + phi::FillDiagonalTensorGradKernel, + float, + double, + int64_t, + int, + int8_t, + uint8_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex, + bool) {} diff --git a/paddle/phi/kernels/cpu/fill_diagonal_tensor_kernel.cc b/paddle/phi/kernels/cpu/fill_diagonal_tensor_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e8030199d16a418e271d7566a6dc586c9e6977e --- /dev/null +++ b/paddle/phi/kernels/cpu/fill_diagonal_tensor_kernel.cc @@ -0,0 +1,140 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { + +void CalMatDims(phi::DDim out_dims, + int dim1, + int dim2, + int64_t *offset, + int64_t *new_dims, + int64_t *strides, + int64_t *matoffset) { + int64_t dimprod = 1, batchdim = 1; + int rank = out_dims.size(); + int matoffidx = 0; + for (int i = rank - 1; i >= 0; i--) { + if (i == dim2) { + strides[0] = dimprod; + } else if (i == dim1) { + strides[1] = dimprod; + } else { + batchdim *= out_dims[i]; + // matoffset calculate the offset position of the diagonal defined by dim1 + // and dim2 + // the first circle calculate the final free dimension + // and then calculate the front free dim one by one + if (matoffidx == 0) { + for (int64_t j = 0; j < out_dims[i]; j++) { + matoffset[matoffidx] = dimprod * j; + matoffidx++; + } + } else { + auto size = matoffidx; + for (int64_t j = 1; j < out_dims[i]; j++) { + for (int64_t k = 0; k < size; k++) { + matoffset[matoffidx] = matoffset[k] + dimprod * j; + matoffidx++; + } + } + } + } + dimprod *= out_dims[i]; + } + + auto diagdim = dim1; + if (*offset >= 0) { + diagdim = std::min(out_dims[dim1], out_dims[dim2] - *offset); + *offset *= strides[0]; + } else { + diagdim = std::min(out_dims[dim1] + *offset, out_dims[dim2]); + *offset *= -strides[1]; + } + new_dims[0] = batchdim; + new_dims[1] = diagdim; + return; +} + +template +void FillDiagonalTensorKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &y, + int64_t offset, + int dim1, + int dim2, + DenseTensor *out) { + T *out_data = ctx.template Alloc(out); + const T *fill_data = y.data(); + + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + auto out_dims = out->dims(); + auto matdims = y.dims(); + auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1); + + int64_t new_dims[2], strides[2]; + std::vector matdim; + matdim.resize(fill_dims[0]); + CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim.data()); + PADDLE_ENFORCE_EQ( + new_dims[0], + fill_dims[0], + errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], + new_dims[1], + fill_dims[0], + fill_dims[1])); + PADDLE_ENFORCE_EQ( + new_dims[1], + fill_dims[1], + errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], + new_dims[1], + fill_dims[0], + fill_dims[1])); + + auto size = out->numel(); + for (int64_t i = 0; i < fill_dims[0]; i += 1) { + auto sumoff = matdim[i] + offset; + for (int64_t j = 0; j < fill_dims[1]; j += 1) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + out_data[fill_index] = fill_data[i * fill_dims[1] + j]; + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fill_diagonal_tensor, + CPU, + ALL_LAYOUT, + phi::FillDiagonalTensorKernel, + float, + double, + int64_t, + int, + int8_t, + uint8_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex, + bool) {} diff --git a/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc b/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc index ff6e2a372791e6fd47ff4394bdd78fe21e555d84..dc82ffbea8791c46190c7d381e5d35507e9f9693 100644 --- a/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc +++ b/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc @@ -252,7 +252,6 @@ void SliceOneClass(const Context& ctx, const DenseTensor& items, const int class_id, DenseTensor* one_class_item) { - // T* item_data = 
one_class_item->mutable_data(ctx.GetPlace()); T* item_data = ctx.template Alloc(one_class_item); const T* items_data = items.data(); const int64_t num_item = items.dims()[0]; diff --git a/paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h b/paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c44d782593d9dd6dc5bf2693e4b5b4864d9b5228 --- /dev/null +++ b/paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void FillDiagonalTensorGradKernel(const Context &ctx, + const DenseTensor &out_grad, + int64_t offset, + int dim1, + int dim2, + DenseTensor *x_grad); + +void CalMatDims(phi::DDim out_dims, + int dim1, + int dim2, + int64_t *offset, + int64_t *new_dims, + int64_t *strides, + int64_t *matoffset); + +} // namespace phi diff --git a/paddle/phi/kernels/fill_diagonal_tensor_kernel.h b/paddle/phi/kernels/fill_diagonal_tensor_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9d6c8da93edb52af20b152ed8a08c896ec4ef800 --- /dev/null +++ b/paddle/phi/kernels/fill_diagonal_tensor_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void FillDiagonalTensorKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int64_t offset, + int dim1, + int dim2, + DenseTensor* out); + +void CalMatDims(phi::DDim out_dims, + int dim1, + int dim2, + int64_t* offset, + int64_t* new_dims, + int64_t* strides, + int64_t* matoffset); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0e302b23ee98cec7d2d76e1f6e213c340053c1b6 --- /dev/null +++ b/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu @@ -0,0 +1,114 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fill_diagonal_tensor_grad_kernel.h" + +#include +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void fill_grad_kernel(int64_t size, + T *out_data, + int64_t *strides, + int64_t *matdim, + int64_t offset, + int64_t fill_dims0, + int64_t fill_dims1) { + int64_t i = blockIdx.x; + auto sumoff = matdim[i] + offset; + for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + out_data[fill_index] = T(0); + } + } +} + +template +void FillDiagonalTensorGradKernel(const Context &ctx, + const DenseTensor &out_grad, + int64_t offset, + int dim1, + int dim2, + DenseTensor *x_grad) { +#ifdef __HIPCC__ + const int64_t kMaxBlockDim = 256; +#else + const int64_t kMaxBlockDim = 512; +#endif + auto matrows = 1; + + if (x_grad) { + auto *data = ctx.template Alloc(x_grad); + auto dx_dims = x_grad->dims(); + phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad); + + for (int i = 0; i < dx_dims.size(); i++) { + if (i != dim1 && i != dim2) { + matrows *= dx_dims[i]; + } + } + + int64_t new_dims[2]; + std::vector memory_block; + memory_block.resize(2 + matrows); + int64_t *strides = &memory_block[0]; + int64_t *matdim = &memory_block[2]; + CalMatDims(dx_dims, dim1, dim2, &offset, new_dims, strides, matdim); + + auto size = x_grad->numel(); + + auto stream = ctx.stream(); + DenseTensor tensor_tmp; + tensor_tmp.Resize(phi::make_ddim({2 + matrows})); + int64_t *memory_block_cu = ctx.template Alloc(&tensor_tmp); + const auto gpu_place = ctx.GetPlace(); + paddle::memory::Copy(gpu_place, + memory_block_cu, + CPUPlace(), + memory_block.data(), + sizeof(int64_t) * (2 + matrows), + stream); + + int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2]; + + auto kGridDim = new_dims[0]; + auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim); + fill_grad_kernel<<>>( + size, data, strides_cu, matdim_cu, offset, new_dims[0], new_dims[1]); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fill_diagonal_tensor_grad, + GPU, + ALL_LAYOUT, + phi::FillDiagonalTensorGradKernel, + float, + double, + int64_t, + int, + int8_t, + uint8_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex, + bool) {} diff --git a/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu b/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..739a8666e314364f632de8ec8cde6b3de24ef044 --- /dev/null +++ b/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu @@ -0,0 +1,136 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h" + +#include +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { + +template +__global__ void fill_diagonal_tensor_kernel(int64_t size, + T *out_data, + const T *fill_data, + int64_t *strides, + int64_t *matdim, + int64_t offset, + int64_t fill_dims0, + int64_t fill_dims1) { + int64_t i = blockIdx.x; + auto sumoff = matdim[i] + offset; + for (int64_t j = threadIdx.x; j < fill_dims1; j += blockDim.x) { + auto fill_index = j * (strides[1] + strides[0]) + sumoff; + if (fill_index < size) { + out_data[fill_index] = fill_data[i * fill_dims1 + j]; + } + } +} + +template +void FillDiagonalTensorKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &y, + int64_t offset, + int dim1, + int dim2, + DenseTensor *out) { +#ifdef __HIPCC__ + const int64_t kMaxBlockDim = 256; +#else + const int64_t kMaxBlockDim = 512; +#endif + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + + T *out_data = ctx.template Alloc(out); + const T *fill_data = y.data(); + + auto out_dims = out->dims(); + auto matdims = y.dims(); + auto fill_dims = phi::flatten_to_2d(matdims, matdims.size() - 1); + + int64_t new_dims[2]; + std::vector memory_block; + memory_block.resize(2 + fill_dims[0]); + int64_t *strides = &(memory_block[0]); + int64_t *matdim = &(memory_block[2]); + CalMatDims(out_dims, dim1, dim2, &offset, new_dims, strides, matdim); + PADDLE_ENFORCE_EQ( + new_dims[0], + fill_dims[0], + errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], + new_dims[1], + fill_dims[0], + fill_dims[1])); + PADDLE_ENFORCE_EQ( + new_dims[1], + fill_dims[1], + errors::InvalidArgument("The dims should be %d x %d, but get " + "%d x %d in fill tensor Y", + new_dims[0], + new_dims[1], + fill_dims[0], + fill_dims[1])); + + auto size = out->numel(); + + auto stream = ctx.stream(); + DenseTensor tensor_tmp; + tensor_tmp.Resize(phi::make_ddim({2 + fill_dims[0]})); + int64_t *memory_block_cu = ctx.template Alloc(&tensor_tmp); + const auto gpu_place = ctx.GetPlace(); + paddle::memory::Copy(gpu_place, + memory_block_cu, + CPUPlace(), + memory_block.data(), + sizeof(int64_t) * (2 + fill_dims[0]), + stream); + + int64_t *strides_cu = &memory_block_cu[0], *matdim_cu = &memory_block_cu[2]; + + auto kGridDim = new_dims[0]; + auto kBlockDim = std::min(int64_t(new_dims[1]), kMaxBlockDim); + fill_diagonal_tensor_kernel + <<>>(size, + out_data, + fill_data, + strides_cu, + matdim_cu, + offset, + fill_dims[0], + fill_dims[1]); +} + +} // namespace phi + +PD_REGISTER_KERNEL(fill_diagonal_tensor, + GPU, + ALL_LAYOUT, + phi::FillDiagonalTensorKernel, + float, + double, + int64_t, + int, + int8_t, + uint8_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex, + bool) {} diff --git a/paddle/phi/ops/compat/fill_diagonal_tensor_sig.cc b/paddle/phi/ops/compat/fill_diagonal_tensor_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..56b3c2ab81a9bc219a65051eaeebdbc8bb4b0cd8 --- 
/dev/null +++ b/paddle/phi/ops/compat/fill_diagonal_tensor_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FillDiagonalTensorOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "fill_diagonal_tensor", {"X", "Y"}, {"offset", "dim1", "dim2"}, {"Out"}); +} + +KernelSignature FillDiagonalTensorGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fill_diagonal_tensor_grad", + {"Out@GRAD"}, + {"offset", "dim1", "dim2"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_tensor, + phi::FillDiagonalTensorOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_tensor_grad, + phi::FillDiagonalTensorGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 160b4e2e6857df522a128790e744e63cb9f0041c..b0274431d453afdd090943abadb17f1c8123e27e 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1457,7 +1457,6 @@ class OpTest(unittest.TestCase): # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng if expect_np.size == 0: self.op_test.assertTrue(actual_np.size == 0) # }}} - # print("actual_np, expect_np", actual_np, expect_np) self._compare_numpy(name, actual_np, expect_np) if isinstance(expect, tuple): self._compare_list(name, actual, expect) diff --git a/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py b/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py index c1a187d7bbaafd945ea5aebc7ef98121e4ce5a00..a35dd611cb2e9ce0803b0fdded40ccead1d77227 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_diagonal_tensor_op.py @@ -86,6 +86,7 @@ class TensorFillDiagTensor_Test(OpTest): def setUp(self): self.op_type = "fill_diagonal_tensor" + self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor self.init_kernel_type() x = np.random.random((10, 10)).astype(self.dtype) y = np.random.random((10, )).astype(self.dtype) @@ -96,22 +97,23 @@ class TensorFillDiagTensor_Test(OpTest): self.inputs = {"X": x, "Y": y} self.outputs = {'Out': out} - self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset} + self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2} def init_kernel_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test): def setUp(self): self.op_type = "fill_diagonal_tensor" + self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor 
self.init_kernel_type() x = np.random.random((2, 20, 25)).astype(self.dtype) y = np.random.random((2, 20)).astype(self.dtype) @@ -122,7 +124,7 @@ class TensorFillDiagTensor_Test2(TensorFillDiagTensor_Test): self.inputs = {"X": x, "Y": y} self.outputs = {'Out': out} - self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset} + self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2} def init_kernel_type(self): self.dtype = np.float32 @@ -132,6 +134,7 @@ class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test): def setUp(self): self.op_type = "fill_diagonal_tensor" + self.python_api = paddle.tensor.manipulation.fill_diagonal_tensor self.init_kernel_type() x = np.random.random((2, 20, 20, 3)).astype(self.dtype) y = np.random.random((2, 3, 18)).astype(self.dtype) @@ -142,11 +145,12 @@ class TensorFillDiagTensor_Test3(TensorFillDiagTensor_Test): self.inputs = {"X": x, "Y": y} self.outputs = {'Out': out} - self.attrs = {"dim1": dim1, "dim2": dim2, "offset": offset} + self.attrs = {"offset": offset, "dim1": dim1, "dim2": dim2} def init_kernel_type(self): self.dtype = np.float16 if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py index e71cc3b7239f1cf70c7bb74316914bdeffaa803e..37fee3a380fbdcae570f829496513b33b4ddf123 100644 --- a/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_tensor_fill_diagonal_tensor.py @@ -18,6 +18,7 @@ import unittest import numpy as np import six import paddle +from paddle.fluid.framework import _enable_legacy_dygraph class TensorFillDiagTensor_Test(unittest.TestCase): @@ -183,5 +184,9 @@ class TensorFillDiagTensor_Test(unittest.TestCase): fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) +class TensorFillDiagTensor_Test_legacy(TensorFillDiagTensor_Test): + _enable_legacy_dygraph() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ea7ef4ff9d724aa9e2e0d06f806c3b4ed462cda1..a280baa17b1a5433e8d8b18478859b83581892ab 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -888,10 +888,17 @@ def _fill_diagonal_tensor_impl(x, y, offset=0, dim1=0, dim2=1, inplace=False): y = y.reshape([1, -1]) if inplace: - return _C_ops.fill_diagonal_tensor_(x, y, 'dim1', dim1, 'dim2', dim2, - 'offset', offset) - return _C_ops.fill_diagonal_tensor(x, y, 'dim1', dim1, 'dim2', dim2, - 'offset', offset) + if in_dygraph_mode(): + return _C_ops.final_state_fill_diagonal_tensor_( + x, y, offset, dim1, dim2) + else: + return _C_ops.fill_diagonal_tensor_(x, y, 'offset', offset, 'dim1', + dim1, 'dim2', dim2) + if in_dygraph_mode(): + return _C_ops.final_state_fill_diagonal_tensor(x, y, offset, dim1, dim2) + else: + return _C_ops.fill_diagonal_tensor(x, y, 'offset', offset, 'dim1', dim1, + 'dim2', dim2) def fill_diagonal_tensor_(x, y, offset=0, dim1=0, dim2=1, name=None):
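# A minimal usage sketch of the public Python API exercised by the updated tests
# above; it is not part of the patch itself. It assumes a Paddle build that
# includes this change; shapes, dtypes, and values are illustrative only.
import paddle

x = paddle.ones((4, 3), dtype='float64')
y = paddle.full((3,), 2.0, dtype='float64')

# Fill the diagonal of x defined by dim1=0 and dim2=1 (offset 0) with the values of y.
# In dynamic graph mode this call now dispatches to the new phi kernel via
# _C_ops.final_state_fill_diagonal_tensor; the in-place variant is
# paddle.tensor.manipulation.fill_diagonal_tensor_.
out = paddle.tensor.manipulation.fill_diagonal_tensor(x, y, offset=0, dim1=0, dim2=1)
print(out.numpy())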