Unverified commit 9f1616a0 authored by zhiboniu, committed by GitHub

Phi fill diagonal (#44453)

* phi_fill_diagonal

* remove old kernels

* update

* update attr args

* refix

* update
Parent 80ca78a2
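This commit ports the fill_diagonal forward and backward kernels from the fluid operator framework to the phi kernel library (CPU and GPU), wires them up through new api/backward yaml entries, infer-meta functions, and an argument-mapping file, and leaves the Python API unchanged. For orientation, a minimal sketch of the user-facing semantics (illustrative only; assumes a square 2-D tensor):

import paddle

x = paddle.ones((3, 3))
x.fill_diagonal_(2.0)   # in-place: sets x[i, i] = 2.0
print(x.numpy())
# [[2. 1. 1.]
#  [1. 2. 1.]
#  [1. 1. 2.]]
# For tensors with more than 2 dims, all dims must be equal and the elements
# x[i, i, ..., i] are filled; the `wrap` flag only matters for 2-D inputs.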
@@ -12,22 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_diagonal_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
int64_t CalStride(framework::DDim dim) {
int rank = dim.size();
int64_t dimsum = 1;
int64_t strides = 0;
for (int i = rank - 1; i >= 0; i--) {
strides += dimsum;
dimsum *= dim[i];
}
return strides;
}
class FillIDiagonalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -57,13 +49,6 @@ class FillIDiagonalOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillIDiagonal");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillIDiagonal");
auto x_dims = context->GetInputDim("X");
context->SetOutputDim("Out", x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -82,61 +67,10 @@ class FillIDiagonalOpVarTypeInference : public framework::VarTypeInference {
}
};
template <typename T>
class FillIDiagonalKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto fill_val = ctx.template Attr<float>("value");
auto *out = ctx.Output<framework::Tensor>("Out");
auto offset = ctx.Attr<int>("offset");
auto wrap = ctx.Attr<bool>("wrap");
auto *xin = ctx.Input<framework::Tensor>("X");
T temp_var = static_cast<T>(fill_val);
T *out_data = out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*xin, ctx.GetPlace(), out);
auto out_dims = out->dims();
auto strides = CalStride(out_dims);
auto size = out->numel();
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (!wrap) {
size = std::min(size, out_dims[1] * out_dims[1]);
}
for (int64_t i = 0; i < size; i += strides) {
// Check whether the position shifted by offset is still on the same row;
// the modification must not spill across rows.
// out_dims[1] also works for tensors with dim > 2, since in that case all
// dims must be equal.
if (i % out_dims[1] + offset >= 0 &&
i % out_dims[1] + offset < out_dims[1]) {
out_data[i + offset] = temp_var;
}
}
}
};
class FillIDiagonalGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"mul");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
// Note: don't get data type from ctx.Input<framework::Tensor>("Input");
@@ -160,41 +94,6 @@ class FillIDiagonalGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class FillIDiagonalGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto offset = ctx.Attr<int>("offset");
auto wrap = ctx.Attr<bool>("wrap");
if (dx) {
auto *data = dx->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*dout, ctx.GetPlace(), dx);
auto dx_dims = dx->dims();
auto strides = CalStride(dx_dims);
auto size = dx->numel();
auto wrapsize = std::min(size, dx_dims[1] * dx_dims[1]);
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (wrap) {
wrapsize = size;
}
for (int64_t i = 0; i < wrapsize; i += strides) {
if (i % dx_dims[1] + offset >= 0 &&
i % dx_dims[1] + offset < dx_dims[1]) {
data[i + offset] = T(0);
}
}
}
}
};
DECLARE_INPLACE_OP_INFERER(FillIDiagonalOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(FillIDiagonalGradOpInplaceInferer,
{framework::GradVarName("Out"),
@@ -204,30 +103,24 @@ DECLARE_INPLACE_OP_INFERER(FillIDiagonalGradOpInplaceInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal,
FillDiagonalShapeFunctor,
PD_INFER_META(phi::FillDiagonalInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(fill_diagonal_grad,
FillDiagonalGradShapeFunctor,
PD_INFER_META(phi::FillDiagonalGradInferMeta));
REGISTER_OPERATOR(fill_diagonal,
ops::FillIDiagonalOp,
ops::FillIDiagonalOpMaker,
ops::FillIDiagonalOpVarTypeInference,
ops::FillIDiagonalGradOpMaker<paddle::framework::OpDesc>,
ops::FillIDiagonalGradOpMaker<paddle::imperative::OpBase>,
ops::FillIDiagonalOpInplaceInferer);
ops::FillIDiagonalOpMaker,
ops::FillIDiagonalOpInplaceInferer,
ops::FillIDiagonalOpVarTypeInference,
FillDiagonalShapeFunctor);
REGISTER_OPERATOR(fill_diagonal_grad,
ops::FillIDiagonalGradOp,
ops::FillIDiagonalGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(fill_diagonal,
ops::FillIDiagonalKernel<float>,
ops::FillIDiagonalKernel<double>,
ops::FillIDiagonalKernel<int64_t>,
ops::FillIDiagonalKernel<int>,
ops::FillIDiagonalKernel<paddle::platform::float16>,
ops::FillIDiagonalKernel<bool>);
REGISTER_OP_CPU_KERNEL(fill_diagonal_grad,
ops::FillIDiagonalGradKernel<float>,
ops::FillIDiagonalGradKernel<double>,
ops::FillIDiagonalGradKernel<int64_t>,
ops::FillIDiagonalGradKernel<int>,
ops::FillIDiagonalGradKernel<paddle::platform::float16>,
ops::FillIDiagonalGradKernel<bool>);
ops::FillIDiagonalGradOpInplaceInferer,
FillDiagonalGradShapeFunctor);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_diagonal_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void fill_constant_kernel(const int64_t featuresize,
T* in_data,
int64_t strides,
int offset,
T fillvar,
int dims) {
for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
idx * strides + offset < (blockIdx.x + 1) * featuresize;
idx += blockDim.x) {
// Check whether the position shifted by offset is still on the same row;
// the modification must not spill across rows.
// out_dims[1] also works for tensors with dim > 2, since in that case all
// dims must be equal.
if ((idx * strides) % dims + offset < dims &&
(idx * strides) % dims + offset >= 0) {
in_data[idx * strides + offset] = fillvar;
}
}
}
template <typename T>
class FillIDiagonalCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* out = ctx.Output<Tensor>("Out");
auto offset = ctx.Attr<int>("offset");
auto wrap = ctx.Attr<bool>("wrap");
auto* xin = ctx.Input<framework::Tensor>("X");
framework::TensorCopy(*xin, ctx.GetPlace(), out);
T* out_data = out->mutable_data<T>(ctx.GetPlace());
auto fill_val = static_cast<T>(ctx.template Attr<float>("value"));
T temp_var = static_cast<T>(fill_val);
auto size = out->numel();
auto out_dims = out->dims();
auto strides = CalStride(out_dims);
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (!wrap) {
size = std::min(size, out_dims[1] * out_dims[1]);
}
int64_t kBlockDim = std::min(int64_t(size / strides), kMaxBlockDim);
fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
size, out_data, strides, offset, temp_var, out_dims[1]);
}
};
template <typename T>
class FillIDiagonalGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* in_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto offset = ctx.Attr<int>("offset");
auto wrap = ctx.Attr<bool>("wrap");
framework::TensorCopy(*dout, ctx.GetPlace(), dx);
auto size = dx->numel();
auto out_dims = dx->dims();
auto strides = CalStride(out_dims);
auto wrapsize = std::min(size, out_dims[1] * out_dims[1]);
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (wrap) {
wrapsize = size;
}
int64_t kBlockDim = std::min(int64_t(size), kMaxBlockDim);
fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
wrapsize, in_data, strides, offset, T(0), out_dims[1]);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fill_diagonal,
ops::FillIDiagonalCUDAKernel<float>,
ops::FillIDiagonalCUDAKernel<double>,
ops::FillIDiagonalCUDAKernel<plat::float16>,
ops::FillIDiagonalCUDAKernel<int>,
ops::FillIDiagonalCUDAKernel<int64_t>,
ops::FillIDiagonalCUDAKernel<bool>);
REGISTER_OP_CUDA_KERNEL(fill_diagonal_grad,
ops::FillIDiagonalGradCUDAKernel<float>,
ops::FillIDiagonalGradCUDAKernel<double>,
ops::FillIDiagonalGradCUDAKernel<int>,
ops::FillIDiagonalGradCUDAKernel<int64_t>,
ops::FillIDiagonalGradCUDAKernel<plat::float16>,
ops::FillIDiagonalGradCUDAKernel<bool>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
int64_t CalStride(framework::DDim dim);
} // namespace operators
} // namespace paddle
@@ -864,6 +864,16 @@
data_type : dtype
backend : place
- api : fill_diagonal
args : (Tensor x, float value, int offset, bool wrap)
output : Tensor(out)
infer_meta :
func : FillDiagonalInferMeta
kernel :
func : fill_diagonal
inplace : (x -> out)
backward : fill_diagonal_grad
- api : flatten
args : (Tensor x, int start_axis, int stop_axis)
output : Tensor(out), Tensor(xshape)
......
@@ -811,6 +811,15 @@
infer_meta :
func : UnchangedInferMeta
invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
- backward_api : fill_diagonal_grad
forward : fill_diagonal (Tensor x, float value, int offset, bool wrap) -> Tensor(out)
args : (Tensor out_grad, float value, int offset, bool wrap)
output : Tensor(x_grad)
infer_meta :
func : FillDiagonalGradInferMeta
kernel :
func : fill_diagonal_grad
inplace : (out_grad -> x_grad)
- backward_api : flatten_grad
......
@@ -285,6 +285,18 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
}
}
void FillDiagonalGradInferMeta(const MetaTensor& dout,
float value,
int offset,
bool wrap,
MetaTensor* dx) {
auto x_dims = dout.dims();
if (dx) {
dx->set_dims(x_dims);
dx->set_dtype(dout.dtype());
}
}
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
@@ -137,6 +137,9 @@ void EigvalshGradInferMeta(const MetaTensor& out_v,
bool is_test,
MetaTensor* x_grad);
void FillDiagonalGradInferMeta(
const MetaTensor& dout, float value, int offset, bool wrap, MetaTensor* dx);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
@@ -855,6 +855,17 @@ void ExpandInferMeta(const MetaTensor& x,
}
}
void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out) {
PADDLE_ENFORCE_NE(
out,
nullptr,
phi::errors::InvalidArgument("Tensor out should not be null."));
auto x_dims = x.dims();
out->set_dims(x_dims);
out->set_dtype(x.dtype());
}
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
@@ -132,6 +132,9 @@ void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);
void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out);
void FlattenInferMeta(const MetaTensor& x,
int start_axis,
int stop_axis,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalGradKernel(const Context& ctx,
const DenseTensor& out_grad,
float value,
int offset,
bool wrap,
DenseTensor* x_grad) {
if (x_grad) {
T* data = ctx.template Alloc<T>(x_grad);
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
auto dx_dims = x_grad->dims();
auto strides = CalStride(dx_dims);
auto size = x_grad->numel();
auto wrapsize = std::min(size, dx_dims[1] * dx_dims[1]);
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (wrap) {
wrapsize = size;
}
for (int64_t i = 0; i < wrapsize; i += strides) {
if (i % dx_dims[1] + offset >= 0 &&
i % dx_dims[1] + offset < dx_dims[1]) {
data[i + offset] = T(0);
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_grad,
CPU,
ALL_LAYOUT,
phi::FillDiagonalGradKernel,
float,
double,
int64_t,
int,
phi::dtype::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalKernel(const Context& ctx,
const DenseTensor& x,
float value,
int offset,
bool wrap,
DenseTensor* out) {
T temp_var = static_cast<T>(value);
T* out_data = ctx.template Alloc<T>(out);
phi::Copy(ctx, x, ctx.GetPlace(), false, out);
auto out_dims = out->dims();
auto strides = CalStride(out_dims);
auto size = out->numel();
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (!wrap) {
size = std::min(size, out_dims[1] * out_dims[1]);
}
for (int64_t i = 0; i < size; i += strides) {
// Check whether the position shifted by offset is still on the same row;
// the modification must not spill across rows.
// out_dims[1] also works for tensors with dim > 2, since in that case all
// dims must be equal.
if (i % out_dims[1] + offset >= 0 &&
i % out_dims[1] + offset < out_dims[1]) {
out_data[i + offset] = temp_var;
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal,
CPU,
ALL_LAYOUT,
phi::FillDiagonalKernel,
float,
double,
int64_t,
int,
phi::dtype::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalGradKernel(const Context& ctx,
const DenseTensor& out_grad,
float value,
int offset,
bool wrap,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/impl/fill_diagonal_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void FillDiagonalKernel(const Context& ctx,
const DenseTensor& x,
float value,
int offset,
bool wrap,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_grad_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__global__ void fill_constant_kernel(const int64_t featuresize,
T* in_data,
int64_t strides,
int offset,
T fillvar,
int dims) {
for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
idx * strides + offset < (blockIdx.x + 1) * featuresize;
idx += blockDim.x) {
// Check whether the position shifted by offset is still on the same row;
// the modification must not spill across rows.
// out_dims[1] also works for tensors with dim > 2, since in that case all
// dims must be equal.
if ((idx * strides) % dims + offset < dims &&
(idx * strides) % dims + offset >= 0) {
in_data[idx * strides + offset] = fillvar;
}
}
}
template <typename T, typename Context>
void FillDiagonalGradKernel(const Context& ctx,
const DenseTensor& out_grad,
float value,
int offset,
bool wrap,
DenseTensor* x_grad) {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* in_data = ctx.template Alloc<T>(x_grad);
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
auto size = x_grad->numel();
auto out_dims = x_grad->dims();
auto strides = CalStride(out_dims);
auto wrapsize = std::min(size, out_dims[1] * out_dims[1]);
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (wrap) {
wrapsize = size;
}
int64_t kBlockDim = std::min(int64_t(size), kMaxBlockDim);
fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
wrapsize, in_data, strides, offset, T(0), out_dims[1]);
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal_grad,
GPU,
ALL_LAYOUT,
phi::FillDiagonalGradKernel,
float,
double,
int64_t,
int,
phi::dtype::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fill_diagonal_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
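// Each thread of the single-block launch below walks diagonal positions
// idx = threadIdx.x, threadIdx.x + blockDim.x, ... and writes
// in_data[idx * strides + offset], as long as the flat index stays below
// featuresize and the shifted position remains on the same row.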
template <typename T>
__global__ void fill_constant_kernel(const int64_t featuresize,
T* in_data,
int64_t strides,
int offset,
T fillvar,
int dims) {
for (int64_t idx = blockIdx.x * featuresize + threadIdx.x;
idx * strides + offset < (blockIdx.x + 1) * featuresize;
idx += blockDim.x) {
// Check whether the position shifted by offset is still on the same row;
// the modification must not spill across rows.
// out_dims[1] also works for tensors with dim > 2, since in that case all
// dims must be equal.
if ((idx * strides) % dims + offset < dims &&
(idx * strides) % dims + offset >= 0) {
in_data[idx * strides + offset] = fillvar;
}
}
}
template <typename T, typename Context>
void FillDiagonalKernel(const Context& ctx,
const DenseTensor& x,
float value,
int offset,
bool wrap,
DenseTensor* out) {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
phi::Copy(ctx, x, ctx.GetPlace(), false, out);
T* out_data = ctx.template Alloc<T>(out);
auto fill_val = static_cast<T>(value);
T temp_var = static_cast<T>(fill_val);
auto size = out->numel();
auto out_dims = out->dims();
auto strides = CalStride(out_dims);
// Wrap mode is supported only when the tensor has exactly 2 dims; in wrap
// mode the value is filled in cycles.
if (!wrap) {
size = std::min(size, out_dims[1] * out_dims[1]);
}
int64_t kBlockDim = std::min(int64_t(size / strides), kMaxBlockDim);
fill_constant_kernel<T><<<1, kBlockDim, 0>>>(
size, out_data, strides, offset, temp_var, out_dims[1]);
}
} // namespace phi
PD_REGISTER_KERNEL(fill_diagonal,
GPU,
ALL_LAYOUT,
phi::FillDiagonalKernel,
float,
double,
int64_t,
int,
phi::dtype::float16,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
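// CalStride returns the distance, in the flattened row-major layout, between
// consecutive elements of the main diagonal:
// 1 + d[n-1] + d[n-1]*d[n-2] + ...; for an (r, c) matrix this is c + 1.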
inline int64_t CalStride(phi::DDim dim) {
int rank = dim.size();
int64_t dimsum = 1;
int64_t strides = 0;
for (int i = rank - 1; i >= 0; i--) {
strides += dimsum;
dimsum *= dim[i];
}
return strides;
}
} // namespace phi
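To make the stride formula concrete, a small NumPy check (illustrative only; cal_stride below is a hypothetical Python mirror of the CalStride helper above):

import numpy as np

def cal_stride(shape):
    # 1 + d[-1] + d[-1]*d[-2] + ..., as in CalStride above
    stride, dimsum = 0, 1
    for d in reversed(shape):
        stride += dimsum
        dimsum *= d
    return stride

x = np.arange(27).reshape(3, 3, 3)
stride = cal_stride(x.shape)             # 1 + 3 + 9 = 13
flat_diag = x.reshape(-1)[::stride]      # visits x[0,0,0], x[1,1,1], x[2,2,2]
assert (flat_diag == np.array([x[i, i, i] for i in range(3)])).all()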
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
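// Maps the legacy fluid fill_diagonal op (input X, attrs value/offset/wrap,
// output Out) and its grad op onto the phi kernel signatures added in this
// change.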
KernelSignature FillDiagonalOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"fill_diagonal", {"X"}, {"value", "offset", "wrap"}, {"Out"});
}
KernelSignature FillDiagonalGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fill_diagonal_grad",
{"Out@GRAD"},
{"value", "offset", "wrap"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fill_diagonal, phi::FillDiagonalOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(fill_diagonal_grad,
phi::FillDiagonalGradOpArgumentMapping);
@@ -835,6 +835,7 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
x.fill_diagonal_(1.0)
print(x.tolist()) #[[1.0, 2.0, 2.0], [2.0, 1.0, 2.0], [2.0, 2.0, 1.0], [2.0, 2.0, 2.0]]
"""
helper = LayerHelper("fill_diagonal_", **locals())
check_type(x, 'X', (Variable), 'fill_diagonal_')
dtype = helper.input_dtype('x')
@@ -851,6 +852,11 @@ def fill_diagonal_(x, value, offset=0, wrap=False, name=None):
assert len(inshapeset) == 1, (
'Tensor dims should be equal while input dims > 2 in fill_diagonal_ API'
)
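# `wrap` only affects 2-D inputs; for tensors with more than 2 dims it is
# forced to True below so that the whole main diagonal is filled.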
if in_dygraph_mode():
if len(inshape) == 2:
return _C_ops.final_state_fill_diagonal_(x, value, offset, wrap)
return _C_ops.final_state_fill_diagonal_(x, value, offset, True)
if len(inshape) == 2:
return _C_ops.fill_diagonal_(x, 'value', value, 'offset', offset,
'wrap', wrap)
......