/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_diagonal_op.h"

#include <algorithm>

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;

// Writes `fillvar` to every flat position `idx * strides + offset` that lies
// inside this block's window of `featuresize` elements.
//
//   featuresize  number of elements (from the start of the block's window)
//                that may be written
//   in_data      device buffer that is modified in place
//   strides      flat distance between two consecutive diagonal elements
//   offset       column shift applied to every diagonal position
//   fillvar      value stored at each selected position
//   dims         row length used to reject positions that `offset` would
//                push onto a neighbouring row
//
// Each thread starts at `blockIdx.x * featuresize + threadIdx.x` and advances
// by `blockDim.x`, so any positive number of threads covers all candidates.
template <typename T>
__global__ void fill_constant_kernel(const int64_t featuresize, T* in_data,
                                     int64_t strides, int offset, T fillvar,
                                     int dims) {
  for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
       idx * strides + offset < (blockIdx.x + 1) * featuresize;
       idx += blockDim.x) {
    // Only write when the shifted column is still inside the same row, so the
    // offset never spills onto another row. Using out_dims[1] as `dims` also
    // works for tensors with more than two dimensions, because all of their
    // dimensions must be equal.
    const int64_t column = (idx * strides) % dims + offset;
    if (column >= 0 && column < dims) {
      in_data[idx * strides + offset] = fillvar;
    }
  }
}

// Forward GPU kernel: copies input "X" into output "Out" and then overwrites
// the (optionally shifted) diagonal of "Out" with attribute "value".
template <typename T>
class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
    const int64_t kMaxBlockDim = 256;
#else
    const int64_t kMaxBlockDim = 512;
#endif
    auto* out = ctx.Output<Tensor>("Out");
    const auto offset = ctx.Attr<int>("offset");
    const auto wrap = ctx.Attr<bool>("wrap");

    auto* xin = ctx.Input<Tensor>("X");
    framework::TensorCopy(*xin, ctx.GetPlace(), out);

    T* out_data = out->mutable_data<T>(ctx.GetPlace());
    const T fill_val = static_cast<T>(ctx.template Attr<float>("value"));

    auto size = out->numel();
    const auto out_dims = out->dims();
    const auto strides = CalStride(out_dims);

    // Wrap mode is supported only for 2-D tensors; in wrap mode the value is
    // filled in cycles down the whole tensor. Otherwise only the leading
    // square part is touched.
    if (!wrap) {
      size = std::min(size, out_dims[1] * out_dims[1]);
    }

    // Nothing to fill. Returning here also avoids a zero-thread launch and a
    // modulo by zero inside the kernel when out_dims[1] == 0.
    if (size <= 0) {
      return;
    }

    // Round up so that inputs with size < strides (e.g. a 1xN matrix) still
    // get one thread; a block dimension of 0 is an invalid launch
    // configuration and would leave element [0, 0] unfilled.
    const int64_t diag_count =
        std::max<int64_t>(1, (size + strides - 1) / strides);
    const int64_t kBlockDim = std::min(diag_count, kMaxBlockDim);

    // NOTE(review): the launch uses the default stream; confirm whether it
    // should use the stream of the op's device context instead.
    fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
        size, out_data, strides, offset, fill_val,
        static_cast<int>(out_dims[1]));
  }
};

// Backward GPU kernel: copies the gradient of "Out" into the gradient of "X"
// and zeroes the positions that the forward pass overwrote.
template <typename T>
class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
    const int64_t kMaxBlockDim = 256;
#else
    const int64_t kMaxBlockDim = 512;
#endif
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    const auto offset = ctx.Attr<int>("offset");
    const auto wrap = ctx.Attr<bool>("wrap");

    framework::TensorCopy(*dout, ctx.GetPlace(), dx);
    // Take the data pointer after the copy, as the forward kernel does.
    T* in_data = dx->mutable_data<T>(ctx.GetPlace());

    const auto size = dx->numel();
    // Empty gradient: skip the launch (a block dimension of 0 is invalid).
    if (size <= 0) {
      return;
    }

    const auto out_dims = dx->dims();
    const auto strides = CalStride(out_dims);

    // Wrap mode is supported only for 2-D tensors; in wrap mode the whole
    // tensor is processed, otherwise only the leading square part.
    const auto wrapsize =
        wrap ? size : std::min(size, out_dims[1] * out_dims[1]);

    const int64_t kBlockDim =
        std::min(static_cast<int64_t>(size), kMaxBlockDim);
    fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
        wrapsize, in_data, strides, offset, T(0),
        static_cast<int>(out_dims[1]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// NOTE(review): confirm that the registered element types below match the
// CPU registration of fill_diagonal / fill_diagonal_grad.
REGISTER_OP_CUDA_KERNEL(fill_diagonal, ops::FillIDiagonalCUDAKernel<float>,
                        ops::FillIDiagonalCUDAKernel<double>,
                        ops::FillIDiagonalCUDAKernel<int64_t>,
                        ops::FillIDiagonalCUDAKernel<int>,
                        ops::FillIDiagonalCUDAKernel<plat::float16>,
                        ops::FillIDiagonalCUDAKernel<bool>);

REGISTER_OP_CUDA_KERNEL(fill_diagonal_grad,
                        ops::FillIDiagonalGradCUDAKernel<float>,
                        ops::FillIDiagonalGradCUDAKernel<double>,
                        ops::FillIDiagonalGradCUDAKernel<int64_t>,
                        ops::FillIDiagonalGradCUDAKernel<int>,
                        ops::FillIDiagonalGradCUDAKernel<plat::float16>,
                        ops::FillIDiagonalGradCUDAKernel<bool>);