未验证 提交 e157f2af 编写于 作者: S Siming Dai 提交者: GitHub

[Phi]Add diag_v2 grad kernel (#40447)

* Add diag grad kernel

* fix unittest case

* add float16, remove const &

* delete diag_grad in op_utils.h
上级 3149e399
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
......@@ -58,15 +56,56 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
}
};
// Gradient operator of diag_v2. It validates the variables it needs, gives the
// gradient of X the same shape as X, and selects the kernel data type from the
// gradient of Out.
class DiagV2GradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that input "X" and the output gradient of "X" exist, then copies
  // the dims of "X" onto its gradient.
  void InferShape(framework::InferShapeContext *ctx) const override {
    // The second macro argument is the role label of the variable; it must be
    // "Input" here, matching the "Output" label used for the check below.
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DiagV2Grad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   framework::GradVarName("X"), "DiagV2Grad");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  // The kernel type is built from the data type of the gradient of "Out" and
  // the place of the execution context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }
};
// Describes how the diag_v2_grad op is assembled from the forward op: it
// receives the forward input X and the gradient of Out, produces the gradient
// of X, and inherits the forward attributes.
template <typename T>
class DiagV2GradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    // Names of the gradient variables attached to the grad op.
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    grad_op->SetType("diag_v2_grad");
    grad_op->SetInput("X", this->Input("X"));
    grad_op->SetInput(out_grad_name, this->OutputGrad("Out"));
    grad_op->SetOutput(x_grad_name, this->InputGrad("X"));
    // The forward attribute map is passed through unchanged.
    grad_op->SetAttrMap(this->Attrs());
  }
};
// Declares DiagGradV2NoNeedBufferVarsInferer for the variable "X"; it is
// attached to the diag_v2_grad registration.
// NOTE(review): presumably this marks X as needed only for its metadata, not
// its data buffer, in the backward pass -- confirm against the macro docs.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(DiagGradV2NoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
// Short alias for the operator namespace used by the registrations below.
namespace ops = paddle::operators;
// Declares DiagInferShapeFunctor for diag_v2 on top of phi::DiagInferMeta; the
// functor is handed to the diag_v2 registration.
DECLARE_INFER_SHAPE_FUNCTOR(diag_v2, DiagInferShapeFunctor,
PD_INFER_META(phi::DiagInferMeta));
// NOTE(review): the next four lines form an unterminated REGISTER_OPERATOR
// call that uses EmptyGradOpMaker. They look like the pre-change side of the
// diff shown next to its replacement -- confirm and keep only one version.
REGISTER_OPERATOR(
diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
// Registers diag_v2 with DiagV2GradOpMaker instantiated for both
// framework::OpDesc and imperative::OpBase, plus the infer-shape functor.
REGISTER_OPERATOR(diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
ops::DiagV2GradOpMaker<paddle::framework::OpDesc>,
ops::DiagV2GradOpMaker<paddle::imperative::OpBase>,
DiagInferShapeFunctor);
// Registers the gradient operator together with its no-need-buffer inferer.
REGISTER_OPERATOR(diag_v2_grad, ops::DiagV2GradOp,
ops::DiagGradV2NoNeedBufferVarsInferer);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void DiagGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
const T* dout_data = out_grad.data<T>();
auto dx_dims = x_grad->dims();
auto dout_dims = out_grad.dims();
if (dx_dims.size() == 1) {
auto dx_length = dx_dims[0];
int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
dout_data +=
(offset >= 0 ? offset * dout_stride_1 : -offset * dout_stride_0);
for (int i = 0; i < dx_length; i++) {
dx_data[i * dx_stride] = dout_data[i * (dout_stride_0 + dout_stride_1)];
}
} else {
phi::funcs::SetConstant<Context, T> set_padding_value;
set_padding_value(dev_ctx, x_grad, static_cast<T>(0));
int dx_stride_0 = phi::funcs::ComputeStride(0, dx_dims);
int dx_stride_1 = phi::funcs::ComputeStride(1, dx_dims);
auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
dx_data += (offset >= 0 ? offset * dx_stride_1 : -offset * dx_stride_0);
auto dout_length = dout_dims[0];
for (int i = 0; i < dout_length; i++) {
dx_data[i * (dx_stride_0 + dx_stride_1)] = dout_data[i * dout_stride_0];
}
}
}
} // namespace phi
// Registers phi::DiagGradKernel as the CPU "diag_grad" kernel (ALL_LAYOUT) for
// float16, int, int64_t, float and double. The trailing empty braces are the
// body the macro expects; nothing extra is configured here.
PD_REGISTER_KERNEL(diag_grad,
CPU,
ALL_LAYOUT,
phi::DiagGradKernel,
phi::dtype::float16,
int,
int64_t,
float,
double) {}
......@@ -62,5 +62,12 @@ void DiagKernel(const Context& dev_ctx,
} // namespace phi
// NOTE(review): this first registration (no float16) looks like the pre-change
// side of the diff; the registration after it adds phi::dtype::float16.
// Confirm that only one of the two is kept.
PD_REGISTER_KERNEL(
diag, CPU, ALL_LAYOUT, phi::DiagKernel, int, float, double, int64_t) {}
// Registers phi::DiagKernel as the CPU "diag" kernel (ALL_LAYOUT) for float16,
// int, float, double and int64_t.
PD_REGISTER_KERNEL(diag,
CPU,
ALL_LAYOUT,
phi::DiagKernel,
phi::dtype::float16,
int,
float,
double,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Backward kernel of diag; definitions are provided per backend.
//
// dev_ctx  : device context of the backend executing the kernel.
// x        : forward input tensor.
// out_grad : gradient with respect to the forward output.
// offset   : index of the diagonal the forward op operated on.
// x_grad   : output; receives the gradient with respect to x.
template <typename T, typename Context>
void DiagGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
// Gathers one diagonal of 'dout' into the strided vector 'dx':
//   dx[xStride * i] = dout[start + sumStride * i]   for i in [0, dx_length).
// 'start' is the offset of the first diagonal element inside 'dout' and
// 'sumStride' is the distance between two consecutive diagonal elements.
template <typename T>
__global__ void ExtractDiagonalKernel(const T* dout,
                                      T* dx,
                                      std::ptrdiff_t start,
                                      std::ptrdiff_t dx_length,
                                      const std::ptrdiff_t sumStride,
                                      const std::ptrdiff_t xStride) {
  // Each thread begins at its global index and then jumps ahead by the total
  // number of launched threads, so any launch size covers all elements.
  const std::ptrdiff_t thread_span = gridDim.x * blockDim.x;
  std::ptrdiff_t i = blockIdx.x * blockDim.x + threadIdx.x;
  while (i < dx_length) {
    dx[xStride * i] = dout[start + sumStride * i];
    i += thread_span;
  }
}
// Scatters the strided vector 'dout' onto one diagonal of 'dx':
//   dx[start + sumStride * i] = dout[outStride * i]   for i in [0, size).
// 'start' is the offset of the first diagonal element inside 'dx' and
// 'sumStride' is the distance between two consecutive diagonal elements.
template <typename T>
__global__ void PasteDiagonalKernel(const T* dout,
                                    T* dx,
                                    std::ptrdiff_t start,
                                    std::ptrdiff_t size,
                                    const std::ptrdiff_t sumStride,
                                    const std::ptrdiff_t outStride) {
  // Each thread begins at its global index and then jumps ahead by the total
  // number of launched threads, so any launch size covers all elements.
  const std::ptrdiff_t thread_span = gridDim.x * blockDim.x;
  std::ptrdiff_t i = blockIdx.x * blockDim.x + threadIdx.x;
  while (i < size) {
    dx[start + sumStride * i] = dout[outStride * i];
    i += thread_span;
  }
}
template <typename T, typename Context>
void DiagGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int offset,
DenseTensor* x_grad) {
T* dx_data = dev_ctx.template Alloc<T>(x_grad);
auto* dout_data = out_grad.data<T>();
auto dx_dims = x_grad->dims();
auto dout_dims = out_grad.dims();
auto GetBlockGridSize = [&dev_ctx](int64_t size) {
const int64_t block_size =
std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
return std::tuple<int64_t, int64_t>{block_size, grid_size};
};
if (dx_dims.size() == 1) {
auto dx_length = dx_dims[0];
auto size = (offset > 0) ? dx_length + offset : dx_length - offset;
int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
if (size > 0) {
auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
auto start =
(offset >= 0 ? offset * dout_stride_1 : -offset * dout_stride_0);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
ExtractDiagonalKernel<T><<<std::get<1>(block_grid_size),
std::get<0>(block_grid_size),
0,
dev_ctx.stream()>>>(
dout_data,
dx_data,
start,
dx_length,
dout_stride_0 + dout_stride_1,
dx_stride);
}
} else {
phi::funcs::SetConstant<Context, T> set_padding_value;
set_padding_value(dev_ctx, x_grad, static_cast<T>(0));
int dx_stride_0 = phi::funcs::ComputeStride(0, dx_dims);
int dx_stride_1 = phi::funcs::ComputeStride(1, dx_dims);
int64_t size;
if (offset > 0) {
size = std::min(dx_dims[0], dx_dims[1] - offset);
} else {
size = std::min(dx_dims[0] + offset, dx_dims[1]);
}
if (size > 0) {
auto start = (offset >= 0 ? offset * dx_stride_1 : -offset * dx_stride_0);
auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
PasteDiagonalKernel<T><<<std::get<1>(block_grid_size),
std::get<0>(block_grid_size),
0,
dev_ctx.stream()>>>(dout_data,
dx_data,
start,
size,
dx_stride_0 + dx_stride_1,
dout_stride_0);
}
}
}
} // namespace phi
// Registers phi::DiagGradKernel as the GPU "diag_grad" kernel (ALL_LAYOUT) for
// float16, int, int64_t, float and double. The trailing empty braces are the
// body the macro expects; nothing extra is configured here.
PD_REGISTER_KERNEL(diag_grad,
GPU,
ALL_LAYOUT,
phi::DiagGradKernel,
phi::dtype::float16,
int,
int64_t,
float,
double) {}
......@@ -130,5 +130,12 @@ void DiagKernel(const Context& dev_ctx,
} // namespace phi
// NOTE(review): this first registration (no float16) looks like the pre-change
// side of the diff; the registration after it adds phi::dtype::float16.
// Confirm that only one of the two is kept.
PD_REGISTER_KERNEL(
diag, GPU, ALL_LAYOUT, phi::DiagKernel, int, int64_t, float, double) {}
// Registers phi::DiagKernel as the GPU "diag" kernel (ALL_LAYOUT) for float16,
// int, int64_t, float and double.
PD_REGISTER_KERNEL(diag,
GPU,
ALL_LAYOUT,
phi::DiagKernel,
phi::dtype::float16,
int,
int64_t,
float,
double) {}
......@@ -20,8 +20,15 @@ KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("diag", {"X"}, {"offset", "padding_value"}, {"Out"});
}
// Builds the signature that maps the grad op's variables onto the "diag_grad"
// kernel: inputs are X and the gradient of Out, the only attribute is
// "offset", and the output is the gradient of X. `ctx` is not consulted.
KernelSignature DiagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"diag_grad", {"X", GradVarName("Out")}, {"offset"}, {GradVarName("X")});
}
} // namespace phi
// Associate the op names diag_v2 / diag_v2_grad with the kernel names
// diag / diag_grad.
PD_REGISTER_BASE_KERNEL_NAME(diag_v2, diag);
PD_REGISTER_BASE_KERNEL_NAME(diag_v2_grad, diag_grad);
// Register the argument-mapping function used for each of the two ops.
PD_REGISTER_ARG_MAPPING_FN(diag_v2, phi::DiagOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(diag_v2_grad, phi::DiagGradOpArgumentMapping);
......@@ -44,6 +44,10 @@ class TestDiagV2Op(OpTest):
paddle.enable_static()
self.check_output(check_eager=True)
# Gradient check: switches Paddle to static mode, then checks the gradient of
# 'Out' with respect to input 'X', with the eager-mode check enabled.
def test_check_grad(self):
paddle.enable_static()
self.check_grad(['X'], 'Out', check_eager=True)
# Configuration hook; does nothing here. Subclasses override it to set their
# own input and expected output.
def init_config(self):
pass
......@@ -62,14 +66,14 @@ class TestDiagV2OpCase2(TestDiagV2Op):
# Case 3: a 10x10 input of random integers in [-10, 10); the expected output
# is the diagonal selected by self.offset, computed with np.diag.
class TestDiagV2OpCase3(TestDiagV2Op):
def init_config(self):
# NOTE(review): the next two assignments look like the old and new side of the
# diff; the second adds a cast to float64 and overrides the first -- confirm
# that only the float64 version is kept.
self.x = np.random.randint(-10, 10, size=(10, 10))
self.x = np.random.randint(-10, 10, size=(10, 10)).astype("float64")
self.out = np.diag(self.x, self.offset)
class TestDiagV2OpCase4(TestDiagV2Op):
def init_config(self):
self.x = np.random.rand(100)
self.padding_value = 8
self.padding_value = 2
n = self.x.size
self.out = self.padding_value * np.ones((n, n)) + np.diag(
self.x, self.offset) - np.diag(self.padding_value * np.ones(n))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册