Unverified commit 41f11d29, authored by Zhong Hui, committed by GitHub

[PHI] move diag_embed op to phi. (#44408)

* move diag_embed to phi.
Parent 889bdde3
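For reference while reading the diff: diag_embed takes the last axis of the input and lays it out along the offset-th diagonal of two new axes, dim1 and dim2, of the output. A minimal NumPy sketch of the semantics (reference only, not Paddle code; the helper name `diag_embed_ref` is made up):

```python
import numpy as np

def diag_embed_ref(x, offset=0, dim1=-2, dim2=-1):
    """NumPy reference for diag_embed: place the last axis of `x`
    on the offset-th diagonal of two new axes dim1 and dim2."""
    x = np.asarray(x)
    assert x.ndim >= 1
    out_rank = x.ndim + 1
    dim1 = dim1 % out_rank
    dim2 = dim2 % out_rank
    assert dim1 != dim2, "dim1 and dim2 must differ"
    new_len = abs(offset) + x.shape[-1]
    # Output shape: drop the last input axis, insert the two diagonal axes.
    sizes = list(x.shape[:-1])
    sizes.insert(min(dim1, dim2), new_len)
    sizes.insert(max(dim1, dim2), new_len)
    out = np.zeros(sizes, dtype=x.dtype)
    # Fill the diagonal by viewing the two new axes as the trailing ones.
    moved = np.moveaxis(out, (dim1, dim2), (-2, -1))
    rows = np.arange(x.shape[-1]) + max(-offset, 0)
    cols = np.arange(x.shape[-1]) + max(offset, 0)
    moved[..., rows, cols] = x
    return out

x = np.arange(1, 7, dtype=np.float32).reshape(2, 3)
print(diag_embed_ref(x).shape)            # (2, 3, 3)
print(diag_embed_ref(x, offset=1).shape)  # (2, 4, 4)
```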
paddle/fluid/operators/diag_embed_op.cc:

```diff
@@ -12,7 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/diag_embed_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -20,81 +23,6 @@ namespace operators {
 class DiagEmbedOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Input"),
-        true,
-        platform::errors::NotFound("Input of DiagEmbedOp is not found."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("Out"),
-        true,
-        platform::errors::NotFound("Output of DiagEmbedOp is not found."));
-
-    int offset = ctx->Attrs().Get<int>("offset");
-    int dim1 = ctx->Attrs().Get<int>("dim1");
-    int dim2 = ctx->Attrs().Get<int>("dim2");
-
-    auto x_dims = ctx->GetInputDim("Input");
-
-    PADDLE_ENFORCE_GE(
-        dim1,
-        -(x_dims.size() + 1),
-        platform::errors::OutOfRange(
-            "Dim1 is out of range (expected to be in range of [%ld, "
-            "%ld], but got %ld).",
-            -(x_dims.size() + 1),
-            x_dims.size(),
-            dim1));
-    PADDLE_ENFORCE_LE(
-        dim1,
-        x_dims.size(),
-        platform::errors::OutOfRange(
-            "Dim1 is out of range (expected to be in range of [%ld, "
-            "%ld], but got %ld).",
-            -(x_dims.size() + 1),
-            x_dims.size(),
-            dim1));
-    PADDLE_ENFORCE_GE(
-        dim2,
-        -(x_dims.size() + 1),
-        platform::errors::OutOfRange(
-            "Dim2 is out of range (expected to be in range of [%ld, "
-            "%ld], but got %ld).",
-            -(x_dims.size() + 1),
-            x_dims.size(),
-            dim2));
-    PADDLE_ENFORCE_LE(
-        dim2,
-        x_dims.size(),
-        platform::errors::OutOfRange(
-            "Dim2 is out of range (expected to be in range of [%ld, "
-            "%ld], but got %ld).",
-            -(x_dims.size() + 1),
-            x_dims.size(),
-            dim2));
-
-    int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
-    int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
-    int offset_ = std::abs(offset);
-
-    PADDLE_ENFORCE_NE(dim1_,
-                      dim2_,
-                      platform::errors::InvalidArgument(
-                          "diagonal dimensions should not be identical "
-                          "%ld vs %ld.",
-                          dim1,
-                          dim2));
-
-    int new_dim_len = offset_ + x_dims[x_dims.size() - 1];
-    auto sizes = vectorize(x_dims);
-
-    sizes.pop_back();
-    sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
-    sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len);
-    ctx->SetOutputDim("Out", phi::make_ddim(sizes));
-  }
 };
 
 class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -131,15 +59,14 @@ class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-namespace platform = paddle::platform;
+
+DECLARE_INFER_SHAPE_FUNCTOR(diag_embed,
+                            DiagEmbedInferShapeFunctor,
+                            PD_INFER_META(phi::DiagEmbedInferMeta));
 
 REGISTER_OPERATOR(
     diag_embed,
     ops::DiagEmbedOp,
     ops::DiagEmbedOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(diag_embed,
-                       ops::DiagEmbedKernel<phi::CPUContext, int>,
-                       ops::DiagEmbedKernel<phi::CPUContext, float>,
-                       ops::DiagEmbedKernel<phi::CPUContext, double>,
-                       ops::DiagEmbedKernel<phi::CPUContext, int64_t>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    DiagEmbedInferShapeFunctor);
```
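With the hand-written InferShape gone, static-graph shape inference now flows through PD_INFER_META(phi::DiagEmbedInferMeta). A quick hedged way to observe the inferred shape from Python (assumes the public F.diag_embed wrapper shown later in this diff):

```python
import paddle
import paddle.nn.functional as F

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    out = F.diag_embed(x, offset=1)  # last dim 3 -> diagonal axes of length 4
    print(out.shape)                 # expected: [2, 4, 4]
```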
api yaml (eager op definition):

```diff
@@ -524,6 +524,14 @@
     func : determinant
   backward : det_grad
 
+- api : diag_embed
+  args : (Tensor x, int offset, int dim1, int dim2)
+  output : Tensor
+  infer_meta :
+    func : DiagEmbedInferMeta
+  kernel :
+    func : diag_embed
+
 - api : divide
   args : (Tensor x, Tensor y)
   output : Tensor
...
```
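This yaml entry is what the eager code generator consumes: `args` fixes the positional signature of the generated `final_state_diag_embed` entry point, while `infer_meta` and `kernel` bind it to the functions added below. A hedged sketch of the call it generates (internal entry point of that development branch; user code would normally go through F.diag_embed):

```python
# Sketch only: the generated eager API mirrors the yaml signature
#   diag_embed(Tensor x, int offset, int dim1, int dim2) -> Tensor
import paddle
from paddle import _C_ops

x = paddle.randn([2, 3])
out = _C_ops.final_state_diag_embed(x, 0, -2, -1)  # positional, per yaml order
print(out.shape)  # [2, 3, 3]
```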
paddle/phi/infermeta/unary.cc:

```diff
@@ -288,6 +288,69 @@ void CumInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void DiagEmbedInferMeta(
+    const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) {
+  auto x_dims = x.dims();
+
+  PADDLE_ENFORCE_GE(
+      dim1,
+      -(x_dims.size() + 1),
+      phi::errors::OutOfRange(
+          "Dim1 is out of range (expected to be in range of [%ld, "
+          "%ld], but got %ld).",
+          -(x_dims.size() + 1),
+          x_dims.size(),
+          dim1));
+  PADDLE_ENFORCE_LE(
+      dim1,
+      x_dims.size(),
+      phi::errors::OutOfRange(
+          "Dim1 is out of range (expected to be in range of [%ld, "
+          "%ld], but got %ld).",
+          -(x_dims.size() + 1),
+          x_dims.size(),
+          dim1));
+  PADDLE_ENFORCE_GE(
+      dim2,
+      -(x_dims.size() + 1),
+      phi::errors::OutOfRange(
+          "Dim2 is out of range (expected to be in range of [%ld, "
+          "%ld], but got %ld).",
+          -(x_dims.size() + 1),
+          x_dims.size(),
+          dim2));
+  PADDLE_ENFORCE_LE(
+      dim2,
+      x_dims.size(),
+      phi::errors::OutOfRange(
+          "Dim2 is out of range (expected to be in range of [%ld, "
+          "%ld], but got %ld).",
+          -(x_dims.size() + 1),
+          x_dims.size(),
+          dim2));
+
+  int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
+  int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
+  int offset_ = std::abs(offset);
+
+  PADDLE_ENFORCE_NE(dim1_,
+                    dim2_,
+                    phi::errors::InvalidArgument(
+                        "diagonal dimensions should not be identical "
+                        "%ld vs %ld.",
+                        dim1,
+                        dim2));
+
+  int new_dim_len = offset_ + x_dims[x_dims.size() - 1];
+  auto sizes = vectorize(x_dims);
+
+  sizes.pop_back();
+  sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
+  sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len);
+  out->set_dims(phi::make_ddim(sizes));
+  out->set_dtype(x.dtype());
+}
+
 void DiagInferMeta(const MetaTensor& x,
                    int offset,
                    float padding_value,
...
```
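The output-shape rule in DiagEmbedInferMeta is unchanged from the deleted InferShape: drop the input's last axis and insert two axes of length |offset| + x_dims[-1] at the normalized dim1 and dim2 positions. A small Python mirror of that arithmetic, with worked examples:

```python
def diag_embed_out_shape(x_dims, offset=0, dim1=-2, dim2=-1):
    """Mirror of the shape arithmetic in DiagEmbedInferMeta (sketch)."""
    ndim = len(x_dims)
    # Negative dims index into the *output* rank (ndim + 1).
    dim1_ = ndim + dim1 + 1 if dim1 < 0 else dim1
    dim2_ = ndim + dim2 + 1 if dim2 < 0 else dim2
    assert dim1_ != dim2_, "diagonal dimensions should not be identical"
    new_dim_len = abs(offset) + x_dims[-1]
    sizes = list(x_dims[:-1])
    sizes.insert(min(dim1_, dim2_), new_dim_len)
    sizes.insert(max(dim1_, dim2_), new_dim_len)
    return sizes

print(diag_embed_out_shape([2, 3]))                     # [2, 3, 3]
print(diag_embed_out_shape([2, 3], offset=-2))          # [2, 5, 5]
print(diag_embed_out_shape([4, 2, 3], dim1=0, dim2=2))  # [3, 4, 3, 2]
```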
paddle/phi/infermeta/unary.h:

```diff
@@ -71,6 +71,9 @@ void CumInferMeta(const MetaTensor& x,
                   bool reverse,
                   MetaTensor* out);
 
+void DiagEmbedInferMeta(
+    const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out);
+
 void DiagInferMeta(const MetaTensor& x,
                    int offset,
                    float padding_value,
...
```
paddle/fluid/operators/diag_embed_op.cu → paddle/phi/kernels/cpu/diag_embed_kernel.cc:

```diff
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,19 +12,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/diag_embed_op.h"
-
-namespace ops = paddle::operators;
-namespace platform = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    diag_embed,
-    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext,
-                         platform::float16>,
-    ops::DiagEmbedKernel<paddle::platform::CUDADeviceContext, double>);
+#include "paddle/phi/kernels/diag_embed_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/diag_embed_impl.h"
+
+PD_REGISTER_KERNEL(diag_embed,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::DiagEmbedKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double) {}
```
paddle/phi/kernels/diag_embed_kernel.h (new file):

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void DiagEmbedKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     int offset,
                     int dim1,
                     int dim2,
                     DenseTensor* out);

}  // namespace phi
```
paddle/phi/kernels/gpu/diag_embed_kernel.cu (new file):

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/diag_embed_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/diag_embed_impl.h"

PD_REGISTER_KERNEL(diag_embed,
                   GPU,
                   ALL_LAYOUT,
                   phi::DiagEmbedKernel,
                   int,
                   int64_t,
                   float,
                   double) {}
```
paddle/fluid/operators/diag_embed_op.h → paddle/phi/kernels/impl/diag_embed_impl.h:

```diff
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,15 +14,19 @@
 
 #pragma once
 
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+#endif
+
+#include "paddle/phi/kernels/diag_embed_kernel.h"
+
 #include <algorithm>
 
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
 
 template <typename T>
 struct DiagEmbedFunctor {
@@ -62,69 +66,64 @@ struct DiagEmbedFunctor {
   const int64_t* strides_;
 };
 
-template <typename DeviceContext, typename T>
-class DiagEmbedKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<framework::Tensor>("Input");
-    auto* out = context.Output<framework::Tensor>("Out");
-
-    const int64_t offset = context.Attr<int>("offset");
-    const int64_t dim1 = context.Attr<int>("dim1");
-    const int64_t dim2 = context.Attr<int>("dim2");
-
-    auto* input_data = input->data<T>();
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    phi::funcs::SetConstant<DeviceContext, T> set_zero;
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    set_zero(dev_ctx, out, static_cast<T>(0.0));
+template <typename T, typename Context>
+void DiagEmbedKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     int offset,
+                     int dim1,
+                     int dim2,
+                     DenseTensor* out) {
+  auto* input_data = x.data<T>();
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  phi::funcs::SetConstant<Context, T> set_zero;
+  set_zero(dev_ctx, out, static_cast<T>(0.0));
 
-    auto out_dims = out->dims();
-    int dim1_ = dim1 < 0 ? out_dims.size() + dim1 : dim1;
-    int dim2_ = dim2 < 0 ? out_dims.size() + dim2 : dim2;
-    auto stride = phi::stride(out_dims);
-    int64_t diag_size;
-    int64_t storage_offset = 0;
-    if (offset >= 0) {
-      int64_t dim = out_dims[dim2_] - offset;
-      diag_size = std::max<int64_t>(std::min(out_dims[dim1_], dim), 0);
-    } else {
-      int64_t dim = out_dims[dim1_] + offset;
-      diag_size = std::max<int64_t>(std::min(dim, out_dims[dim2_]), 0);
-    }
-    if (diag_size == 0) {
-      // skip
-    } else if (offset >= 0) {
-      storage_offset += offset * stride[dim2_];
-    } else {
-      storage_offset -= offset * stride[dim1_];
-    }
-    auto strides = vectorize(stride);
-    strides.erase(strides.begin() + std::max(dim1_, dim2_));
-    strides.erase(strides.begin() + std::min(dim1_, dim2_));
-    strides.push_back(stride[dim1_] + stride[dim2_]);
-    const auto dims = vectorize(input->dims());
+  auto out_dims = out->dims();
+  int dim1_ = dim1 < 0 ? out_dims.size() + dim1 : dim1;
+  int dim2_ = dim2 < 0 ? out_dims.size() + dim2 : dim2;
+  auto stride = phi::stride(out_dims);
+  int64_t diag_size;
+  int64_t storage_offset = 0;
+  if (offset >= 0) {
+    int64_t dim = out_dims[dim2_] - offset;
+    diag_size = std::max<int64_t>(std::min(out_dims[dim1_], dim), 0);
+  } else {
+    int64_t dim = out_dims[dim1_] + offset;
+    diag_size = std::max<int64_t>(std::min(dim, out_dims[dim2_]), 0);
+  }
+  if (diag_size == 0) {
+    // skip
+  } else if (offset >= 0) {
+    storage_offset += offset * stride[dim2_];
+  } else {
+    storage_offset -= offset * stride[dim1_];
+  }
+  auto strides = vectorize(stride);
+  strides.erase(strides.begin() + std::max(dim1_, dim2_));
+  strides.erase(strides.begin() + std::min(dim1_, dim2_));
+  strides.push_back(stride[dim1_] + stride[dim2_]);
+  const auto dims = vectorize(x.dims());
 
 #if defined(__NVCC__) || defined(__HIPCC__)
-    thrust::device_vector<int64_t> dims_vec(dims);
-    const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data());
-    thrust::device_vector<int64_t> strides_vec(strides);
-    const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data());
+  thrust::device_vector<int64_t> dims_vec(dims);
+  const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data());
+  thrust::device_vector<int64_t> strides_vec(strides);
+  const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data());
 #else
-    const int64_t* dims_arr = dims.data();
-    const int64_t* strides_arr = strides.data();
+  const int64_t* dims_arr = dims.data();
+  const int64_t* strides_arr = strides.data();
 #endif
 
-    platform::ForRange<DeviceContext> for_range(dev_ctx, input->numel());
-    DiagEmbedFunctor<T> functor(input_data,
-                                input->numel(),
-                                dims_arr,
-                                storage_offset,
-                                dims.size(),
-                                out_data,
-                                strides_arr);
-    for_range(functor);
-  }
-};
+  phi::funcs::ForRange<Context> for_range(dev_ctx, x.numel());
+  DiagEmbedFunctor<T> functor(input_data,
+                              x.numel(),
+                              dims_arr,
+                              storage_offset,
+                              dims.size(),
+                              out_data,
+                              strides_arr);
+  for_range(functor);
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
```
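DiagEmbedFunctor itself is untouched by the port: input element i lands at flat output position storage_offset plus the dot product of i's input coordinates with the rebuilt strides, where the appended stride stride[dim1_] + stride[dim2_] makes one step along the input's last axis advance both diagonal axes at once. A NumPy mirror of that index math (a sketch; it assumes a non-empty diagonal, whereas the kernel clamps diag_size to zero):

```python
import numpy as np

def diag_embed_flat(x, offset=0, dim1=-2, dim2=-1):
    """NumPy mirror of DiagEmbedFunctor's flat-index arithmetic (sketch)."""
    x = np.asarray(x)
    out_rank = x.ndim + 1
    dim1_ = dim1 % out_rank
    dim2_ = dim2 % out_rank
    # Output shape, as computed by the infer-meta step.
    new_len = abs(offset) + x.shape[-1]
    sizes = list(x.shape[:-1])
    sizes.insert(min(dim1_, dim2_), new_len)
    sizes.insert(max(dim1_, dim2_), new_len)
    out = np.zeros(sizes, dtype=x.dtype)
    # Row-major element strides of the output.
    stride = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        stride[d] = stride[d + 1] * sizes[d + 1]
    storage_offset = offset * stride[dim2_] if offset >= 0 else -offset * stride[dim1_]
    # Strides seen by the *input*: drop the two diagonal axes, then append
    # stride[dim1_] + stride[dim2_] so the last input axis walks the diagonal.
    strides = [s for d, s in enumerate(stride) if d not in (dim1_, dim2_)]
    strides.append(stride[dim1_] + stride[dim2_])
    flat = out.ravel()
    for i, v in enumerate(x.ravel()):
        # Decompose i into input coordinates (row-major), dot with strides.
        idx, pos = i, storage_offset
        for d in range(x.ndim - 1, -1, -1):
            idx, r = divmod(idx, x.shape[d])
            pos += r * strides[d]
        flat[pos] = v
    return out

x = np.arange(1, 7, dtype=np.float32).reshape(2, 3)
print(diag_embed_flat(x, offset=1)[0])
# [[0. 1. 0. 0.]
#  [0. 0. 2. 0.]
#  [0. 0. 0. 3.]
#  [0. 0. 0. 0.]]
```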
python/paddle/fluid/tests/unittests/test_diag_embed.py:

```diff
@@ -27,11 +27,12 @@ class TestDiagEmbedOp(OpTest):
     def setUp(self):
         self.op_type = "diag_embed"
+        self.python_api = F.diag_embed
         self.init_config()
         self.outputs = {'Out': self.target}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def init_config(self):
         self.case = np.random.randn(2, 3).astype('float32')
...
```
python/paddle/nn/functional/extension.py:

```diff
@@ -98,12 +98,18 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
           #        [[ 0.        ,  0.        ,  0.        ,  0.        ],
           #         [ 0.        ,  0.        ,  0.        ,  0.        ]]]
     """
-    inputs = {'Input': [input]}
-    attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2}
-
     if not isinstance(input, Variable):
         input = assign(input)
 
+    if in_dygraph_mode():
+        return _C_ops.final_state_diag_embed(input, offset, dim1, dim2)
+    elif in_dynamic_mode():
+        return _C_ops.diag_embed(input, "offset", offset, "dim1", dim1,
+                                 "dim2", dim2)
+
+    inputs = {'Input': [input]}
+    attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2}
+
     def __check_input(input, offset, dim1, dim2):
         check_dtype(input.dtype, 'Input',
                     ['int32', 'int64', 'float16', 'float32', 'float64'],
@@ -129,8 +135,7 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
                 "dim1 and dim2 cannot be the same dimension." \
                 "But received dim1 = %d, dim2 = %d\n"%(dim1, dim2)
 
-    if not in_dynamic_mode():
-        __check_input(input, offset, dim1, dim2)
+    __check_input(input, offset, dim1, dim2)
 
     helper = LayerHelper("diag_embed", **locals())
     out = helper.create_variable_for_type_inference(dtype=input.dtype)
...
```
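End to end, the public API is unchanged by the migration; a quick eager-mode check (hedged; the expected values follow the diag_embed semantics sketched at the top):

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]])
out = F.diag_embed(x)   # shape [2, 3, 3]; out[i] = diag(x[i])
print(out.numpy()[0])
# [[1. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 3.]]
```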