Unverified commit 783c4aba authored by Linjie Chen, committed by GitHub

move diag_v2 to phi (#39914)

Parent 2533cac6
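For orientation: this patch deletes the fluid-side InferShape plus the CPU/CUDA DiagV2 kernels and re-registers the op against the phi library (DiagInferMeta, phi::DiagKernel, and a diag_v2 -> diag argument mapping), leaving the output-shape rule unchanged. The snippet below is a standalone sketch of that shape rule for illustration only; it is not part of the patch, and the helper name DiagOutShape is made up for this note.

// Illustrative sketch only (not from the patch): the output-shape rule that
// DiagV2Op::InferShape used to implement and phi::DiagInferMeta now implements.
#include <cstdint>
#include <cstdlib>
#include <vector>

// For a 1-D input of length n, the output is an (n + |offset|) x (n + |offset|)
// matrix with the input placed on the offset-th diagonal; for a 2-D input, the
// output is a vector holding the offset-th diagonal.
inline std::vector<int64_t> DiagOutShape(const std::vector<int64_t>& x_dims,
                                         int offset) {
  if (x_dims.size() == 1) {
    int64_t n = x_dims[0] + std::abs(offset);
    return {n, n};
  }
  int64_t rows = x_dims[0], cols = x_dims[1];
  // Mirrors the explicit comparisons in the patch (std::min is avoided there
  // because of the Windows / Python 3.8 issue noted in the code comments).
  int64_t size = (offset >= 0) ? (rows < cols - offset ? rows : cols - offset)
                               : (rows + offset < cols ? rows + offset : cols);
  return {size};
}

For example, a length-3 vector with offset 1 maps to a 4 x 4 output, which matches what the migrated kernels below produce.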
......@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/diag_v2_op.h"
#include <algorithm>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......@@ -23,44 +25,6 @@ namespace operators {
class DiagV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "diag_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "diag_v2");
auto x_dims = ctx->GetInputDim("X");
auto offset = ctx->Attrs().Get<int>("offset");
if (x_dims.size() == 1UL) {
int64_t size_ = x_dims[0] + std::abs(offset);
ctx->SetOutputDim("Out", {size_, size_});
} else if (x_dims.size() == 2UL) {
int64_t size_ = 0;
if (offset >= 0) {
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if (x_dims[0] < x_dims[1] - offset) {
size_ = x_dims[0];
} else {
size_ = x_dims[1] - offset;
}
} else {
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if (x_dims[0] + offset < x_dims[1]) {
size_ = x_dims[0] + offset;
} else {
size_ = x_dims[1];
}
}
ctx->SetOutputDim("Out", {size_});
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d.",
x_dims.size()));
}
}
};
class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -94,59 +58,15 @@ class DiagV2OpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename DeviceContext, typename T>
class DiagV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* x_data = X->data<T>();
auto x_dims = X->dims();
int offset = context.Attr<int>("offset");
auto* out = context.Output<framework::Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
int64_t i;
if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
phi::funcs::SetConstant<DeviceContext, T> set_padding_value;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
auto x_length = x_dims[0];
const int& x_stride = ComputeStride(0, x_dims);
auto out_stride_0 = ComputeStride(0, out_dims);
auto out_stride_1 = ComputeStride(1, out_dims);
out_data +=
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
for (i = 0; i < x_length; i++) {
out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride];
}
} else {
auto out_length = out_dims[0];
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);
auto out_stride_0 = ComputeStride(0, out_dims);
x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
for (i = 0; i < out_length; i++) {
out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)];
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(diag_v2, DiagInferShapeFunctor,
PT_INFER_META(phi::DiagInferMeta));
REGISTER_OPERATOR(
diag_v2, ops::DiagV2Op, ops::DiagV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
diag_v2, ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::DiagV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
DiagInferShapeFunctor);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <tuple>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_v2_op.h"
namespace paddle {
namespace operators {
// Extract the diagonal of a matrix 'x' to a vector 'out'.
template <typename T>
__global__ void ExtractDiagonalKernel(T* out, const T* x, std::ptrdiff_t start,
std::ptrdiff_t size,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t outStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t xOffset = start + sumStride * idx;
out[outStride * idx] = x[xOffset];
}
}
// Paste a vector 'x' to the diagonal of a matrix 'out'
template <typename T>
__global__ void PasteDiagonalKernel(T* out, const T* x, std::ptrdiff_t start,
std::ptrdiff_t x_length,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t xStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < x_length; idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t outOffset = start + sumStride * idx;
out[outOffset] = x[xStride * idx];
}
}
template <typename DeviceContext, typename T>
class DiagV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* x_data = X->data<T>();
auto x_dims = X->dims();
int offset = context.Attr<int>("offset");
auto* out = context.Output<framework::Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_dims = out->dims();
auto& dev_ctx = context.template device_context<DeviceContext>();
auto GetBlockGridSize = [&dev_ctx](int64_t size) {
const int64_t block_size =
std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
return std::tuple<int64_t, int64_t>{block_size, grid_size};
};
if (x_dims.size() == 1) {
float padding_value = context.Attr<float>("padding_value");
phi::funcs::SetConstant<DeviceContext, T> set_padding_value;
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
auto x_length = x_dims[0];
auto size = (offset > 0) ? x_length + offset : x_length - offset;
const int& x_stride = ComputeStride(0, x_dims);
if (size > 0) {
const auto& out_stride_0 = ComputeStride(0, out_dims);
const auto& out_stride_1 = ComputeStride(1, out_dims);
auto start =
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
PasteDiagonalKernel<
T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
dev_ctx.stream()>>>(out_data, x_data, start, x_length,
out_stride_0 + out_stride_1, x_stride);
}
} else {
const int& x_stride_0 = ComputeStride(0, x_dims);
const int& x_stride_1 = ComputeStride(1, x_dims);
int64_t size;
if (offset > 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
size = std::min(x_dims[0] + offset, x_dims[1]);
}
if (size > 0) {
auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
const auto& out_stride_0 = ComputeStride(0, out_dims);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
ExtractDiagonalKernel<
T><<<std::get<1>(block_grid_size), std::get<0>(block_grid_size), 0,
dev_ctx.stream()>>>(out_data, x_data, start, size,
x_stride_0 + x_stride_1, out_stride_0);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
diag_v2, ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::DiagV2CUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using DDim = framework::DDim;
static inline int ComputeStride(int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
} // namespace operators
} // namespace paddle
......@@ -37,7 +37,8 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly.
*/
const std::unordered_set<std::string> deprecated_op_names({"flatten",
const std::unordered_set<std::string> deprecated_op_names({"diag",
"flatten",
"flatten_grad",
"matmul",
"matmul_grad",
......
......@@ -310,6 +310,7 @@ void BCELossInferMeta(const MetaTensor& input,
}
out->set_dims(input_dims);
out->set_dtype(input.dtype());
out->share_lod(input);
}
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/phi/infermeta/unary.h"
#include <algorithm>
#include <set>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
......@@ -715,6 +716,45 @@ void UnfoldInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims));
}
void DiagInferMeta(const MetaTensor& x,
int offset,
float padding_value,
MetaTensor* out) {
auto x_dims = x.dims();
if (x_dims.size() == 1UL) {
int64_t size_ = x_dims[0] + std::abs(offset);
out->set_dims({size_, size_});
out->set_dtype(x.dtype());
} else if (x_dims.size() == 2UL) {
int64_t size_ = 0;
if (offset >= 0) {
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if (x_dims[0] < x_dims[1] - offset) {
size_ = x_dims[0];
} else {
size_ = x_dims[1] - offset;
}
} else {
// Note(LutaoChu): Do not use std::min here, otherwise the calculation
// of `size_` will have unexpected result on Windows Python3.8
if (x_dims[0] + offset < x_dims[1]) {
size_ = x_dims[0] + offset;
} else {
size_ = x_dims[1];
}
}
out->set_dims({size_});
out->set_dtype(x.dtype());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d.",
x_dims.size()));
}
}
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
......@@ -104,4 +104,9 @@ void UnfoldInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void DiagInferMeta(const MetaTensor& x,
int offset,
float padding_value,
MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void DiagKernel(const Context& dev_ctx,
const DenseTensor& x,
int offset,
float padding_value,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto x_dims = x.dims();
T* out_data = dev_ctx.template Alloc<T>(out);
auto out_dims = out->dims();
int64_t i;
if (x_dims.size() == 1) {
phi::funcs::SetConstant<Context, T> set_padding_value;
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
auto x_length = x_dims[0];
const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
auto out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
out_data += (offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
for (i = 0; i < x_length; i++) {
out_data[i * (out_stride_0 + out_stride_1)] = x_data[i * x_stride];
}
} else {
auto out_length = out_dims[0];
const int& x_stride_0 = phi::funcs::ComputeStride(0, x_dims);
const int& x_stride_1 = phi::funcs::ComputeStride(1, x_dims);
auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
x_data += (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
for (i = 0; i < out_length; i++) {
out_data[i * out_stride_0] = x_data[i * (x_stride_0 + x_stride_1)];
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
diag, CPU, ALL_LAYOUT, phi::DiagKernel, int, float, double, int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DiagKernel(const Context& dev_ctx,
const DenseTensor& x,
int offset,
float padding_value,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
inline int ComputeStride(int axis, phi::DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/diag_kernel.h"
#include <algorithm>
#include <tuple>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
// Extract the diagonal of a matrix 'x' to a vector 'out'.
template <typename T>
__global__ void ExtractDiagonalKernel(T* out,
const T* x,
std::ptrdiff_t start,
std::ptrdiff_t size,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t outStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t xOffset = start + sumStride * idx;
out[outStride * idx] = x[xOffset];
}
}
// Paste a vector 'x' to the diagonal of a matrix 'out'
template <typename T>
__global__ void PasteDiagonalKernel(T* out,
const T* x,
std::ptrdiff_t start,
std::ptrdiff_t x_length,
const std::ptrdiff_t sumStride,
const std::ptrdiff_t xStride) {
for (std::ptrdiff_t idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < x_length;
idx += gridDim.x * blockDim.x) {
const std::ptrdiff_t outOffset = start + sumStride * idx;
out[outOffset] = x[xStride * idx];
}
}
template <typename T, typename Context>
void DiagKernel(const Context& dev_ctx,
const DenseTensor& x,
int offset,
float padding_value,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto x_dims = x.dims();
T* out_data = dev_ctx.template Alloc<T>(out);
auto out_dims = out->dims();
auto GetBlockGridSize = [&dev_ctx](int64_t size) {
const int64_t block_size =
std::min(size, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock()));
int64_t max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (size + block_size - 1) / block_size);
return std::tuple<int64_t, int64_t>{block_size, grid_size};
};
if (x_dims.size() == 1) {
phi::funcs::SetConstant<Context, T> set_padding_value;
set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
auto x_length = x_dims[0];
auto size = (offset > 0) ? x_length + offset : x_length - offset;
const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
if (size > 0) {
const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
const auto& out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
auto start =
(offset >= 0 ? offset * out_stride_1 : -offset * out_stride_0);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
PasteDiagonalKernel<T><<<std::get<1>(block_grid_size),
std::get<0>(block_grid_size),
0,
dev_ctx.stream()>>>(out_data,
x_data,
start,
x_length,
out_stride_0 + out_stride_1,
x_stride);
}
} else {
const int& x_stride_0 = phi::funcs::ComputeStride(0, x_dims);
const int& x_stride_1 = phi::funcs::ComputeStride(1, x_dims);
int64_t size;
if (offset > 0) {
size = std::min(x_dims[0], x_dims[1] - offset);
} else {
size = std::min(x_dims[0] + offset, x_dims[1]);
}
if (size > 0) {
auto start = (offset >= 0 ? offset * x_stride_1 : -offset * x_stride_0);
const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
std::tuple<int64_t, int64_t> block_grid_size = GetBlockGridSize(size);
ExtractDiagonalKernel<T><<<std::get<1>(block_grid_size),
std::get<0>(block_grid_size),
0,
dev_ctx.stream()>>>(
out_data, x_data, start, size, x_stride_0 + x_stride_1, out_stride_0);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
diag, GPU, ALL_LAYOUT, phi::DiagKernel, int, int64_t, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DiagOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("diag", {"X"}, {"offset", "padding_value"}, {"Out"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(diag_v2, diag);
PD_REGISTER_ARG_MAPPING_FN(diag_v2, phi::DiagOpArgumentMapping);