From 41f11d29526b2a3827a1a5224bc00ebe540e34d4 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Wed, 20 Jul 2022 17:15:04 +0800 Subject: [PATCH] [PHI] move diag_embed op to phi. (#44408) * move diag_embed to phi. --- paddle/fluid/operators/diag_embed_op.cc | 93 ++----------- paddle/fluid/operators/diag_embed_op.cu | 30 ---- paddle/fluid/operators/diag_embed_op.h | 130 ------------------ paddle/phi/api/yaml/legacy_api.yaml | 8 ++ paddle/phi/infermeta/unary.cc | 63 +++++++++ paddle/phi/infermeta/unary.h | 3 + paddle/phi/kernels/cpu/diag_embed_kernel.cc | 28 ++++ paddle/phi/kernels/diag_embed_kernel.h | 29 ++++ paddle/phi/kernels/gpu/diag_embed_kernel.cu | 28 ++++ paddle/phi/kernels/impl/diag_embed_impl.h | 129 +++++++++++++++++ .../fluid/tests/unittests/test_diag_embed.py | 3 +- python/paddle/nn/functional/extension.py | 15 +- 12 files changed, 310 insertions(+), 249 deletions(-) delete mode 100644 paddle/fluid/operators/diag_embed_op.cu delete mode 100644 paddle/fluid/operators/diag_embed_op.h create mode 100644 paddle/phi/kernels/cpu/diag_embed_kernel.cc create mode 100644 paddle/phi/kernels/diag_embed_kernel.h create mode 100644 paddle/phi/kernels/gpu/diag_embed_kernel.cu create mode 100644 paddle/phi/kernels/impl/diag_embed_impl.h diff --git a/paddle/fluid/operators/diag_embed_op.cc b/paddle/fluid/operators/diag_embed_op.cc index 531d6f92d88..0dc5d024ec4 100644 --- a/paddle/fluid/operators/diag_embed_op.cc +++ b/paddle/fluid/operators/diag_embed_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/diag_embed_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,81 +23,6 @@ namespace operators { class DiagEmbedOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("Input"), - true, - platform::errors::NotFound("Input of DiagEmbedOp is not found.")); - - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), - true, - platform::errors::NotFound("Output of DiagEmbedOp is not found.")); - - int offset = ctx->Attrs().Get("offset"); - int dim1 = ctx->Attrs().Get("dim1"); - int dim2 = ctx->Attrs().Get("dim2"); - - auto x_dims = ctx->GetInputDim("Input"); - - PADDLE_ENFORCE_GE( - dim1, - -(x_dims.size() + 1), - platform::errors::OutOfRange( - "Dim1 is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size() + 1), - x_dims.size(), - dim1)); - PADDLE_ENFORCE_LE( - dim1, - x_dims.size(), - platform::errors::OutOfRange( - "Dim1 is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size() + 1), - x_dims.size(), - dim1)); - - PADDLE_ENFORCE_GE( - dim2, - -(x_dims.size() + 1), - platform::errors::OutOfRange( - "Dim2 is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size() + 1), - x_dims.size(), - dim2)); - PADDLE_ENFORCE_LE( - dim2, - x_dims.size(), - platform::errors::OutOfRange( - "Dim2 is out of range (expected to be in range of [%ld, " - "%ld], but got %ld).", - -(x_dims.size() + 1), - x_dims.size(), - dim2)); - - int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; - int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; - int offset_ = std::abs(offset); - - PADDLE_ENFORCE_NE(dim1_, - dim2_, - platform::errors::InvalidArgument( - "diagonal dimensions should not be identical " - "%ld vs %ld.", - dim1, - dim2)); - - int new_dim_len = offset_ + x_dims[x_dims.size() - 1]; - auto sizes = vectorize(x_dims); - sizes.pop_back(); - sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len); - sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len); - ctx->SetOutputDim("Out", phi::make_ddim(sizes)); - } }; class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker { @@ -131,15 +59,14 @@ class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -namespace platform = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(diag_embed, + DiagEmbedInferShapeFunctor, + PD_INFER_META(phi::DiagEmbedInferMeta)); + REGISTER_OPERATOR( diag_embed, ops::DiagEmbedOp, ops::DiagEmbedOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(diag_embed, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel); + paddle::framework::EmptyGradOpMaker, + DiagEmbedInferShapeFunctor); diff --git a/paddle/fluid/operators/diag_embed_op.cu b/paddle/fluid/operators/diag_embed_op.cu deleted file mode 100644 index e0f8c16731f..00000000000 --- a/paddle/fluid/operators/diag_embed_op.cu +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/diag_embed_op.h" - -namespace ops = paddle::operators; -namespace platform = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - diag_embed, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel, - ops::DiagEmbedKernel); diff --git a/paddle/fluid/operators/diag_embed_op.h b/paddle/fluid/operators/diag_embed_op.h deleted file mode 100644 index 94c479bb452..00000000000 --- a/paddle/fluid/operators/diag_embed_op.h +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -struct DiagEmbedFunctor { - DiagEmbedFunctor(const T* input, - int64_t numel, - const int64_t* dim, - int64_t offset, - int64_t dims_size, - T* output, - const int64_t* strides) - : input_(input), - numel_(numel), - dim_(dim), - offset_(offset), - dims_size_(dims_size), - output_(output), - strides_(strides) {} - - HOSTDEVICE void operator()(size_t idx) const { - int64_t position = 0; - auto numel = numel_; - int64_t num = idx; - for (int64_t i = 0; i < dims_size_; i++) { - numel = numel / dim_[i]; - position += num / numel * strides_[i]; - num = num % numel; - } - output_[position + offset_] = input_[idx]; - } - - const T* input_; - int64_t numel_; - const int64_t* dim_; - int64_t offset_; - int64_t dims_size_; - T* output_; - const int64_t* strides_; -}; - -template -class DiagEmbedKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* out = context.Output("Out"); - - const int64_t offset = context.Attr("offset"); - const int64_t dim1 = context.Attr("dim1"); - const int64_t dim2 = context.Attr("dim2"); - auto* input_data = input->data(); - - T* out_data = out->mutable_data(context.GetPlace()); - phi::funcs::SetConstant set_zero; - auto& dev_ctx = context.template device_context(); - set_zero(dev_ctx, out, static_cast(0.0)); - - auto out_dims = out->dims(); - int dim1_ = dim1 < 0 ? out_dims.size() + dim1 : dim1; - int dim2_ = dim2 < 0 ? out_dims.size() + dim2 : dim2; - auto stride = phi::stride(out_dims); - int64_t diag_size; - int64_t storage_offset = 0; - if (offset >= 0) { - int64_t dim = out_dims[dim2_] - offset; - diag_size = std::max(std::min(out_dims[dim1_], dim), 0); - } else { - int64_t dim = out_dims[dim1_] + offset; - diag_size = std::max(std::min(dim, out_dims[dim2_]), 0); - } - if (diag_size == 0) { - // skip - } else if (offset >= 0) { - storage_offset += offset * stride[dim2_]; - } else { - storage_offset -= offset * stride[dim1_]; - } - auto strides = vectorize(stride); - strides.erase(strides.begin() + std::max(dim1_, dim2_)); - strides.erase(strides.begin() + std::min(dim1_, dim2_)); - strides.push_back(stride[dim1_] + stride[dim2_]); - const auto dims = vectorize(input->dims()); - -#if defined(__NVCC__) || defined(__HIPCC__) - thrust::device_vector dims_vec(dims); - const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data()); - thrust::device_vector strides_vec(strides); - const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data()); -#else - const int64_t* dims_arr = dims.data(); - const int64_t* strides_arr = strides.data(); -#endif - - platform::ForRange for_range(dev_ctx, input->numel()); - DiagEmbedFunctor functor(input_data, - input->numel(), - dims_arr, - storage_offset, - dims.size(), - out_data, - strides_arr); - for_range(functor); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index f60309985a6..40fbdc9a917 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -524,6 +524,14 @@ func : determinant backward : det_grad +- api : diag_embed + args : (Tensor x, int offset, int dim1, int dim2) + output : Tensor + infer_meta : + func : DiagEmbedInferMeta + kernel : + func : diag_embed + - api : divide args : (Tensor x, Tensor y) output : Tensor diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c39fb96430f..7b1c6dfe65a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -288,6 +288,69 @@ void CumInferMeta(const MetaTensor& x, out->share_lod(x); } +void DiagEmbedInferMeta( + const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) { + auto x_dims = x.dims(); + + PADDLE_ENFORCE_GE( + dim1, + -(x_dims.size() + 1), + phi::errors::OutOfRange( + "Dim1 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), + x_dims.size(), + dim1)); + PADDLE_ENFORCE_LE( + dim1, + x_dims.size(), + phi::errors::OutOfRange( + "Dim1 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), + x_dims.size(), + dim1)); + + PADDLE_ENFORCE_GE( + dim2, + -(x_dims.size() + 1), + phi::errors::OutOfRange( + "Dim2 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), + x_dims.size(), + dim2)); + PADDLE_ENFORCE_LE( + dim2, + x_dims.size(), + phi::errors::OutOfRange( + "Dim2 is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size() + 1), + x_dims.size(), + dim2)); + + int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; + int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; + int offset_ = std::abs(offset); + + PADDLE_ENFORCE_NE(dim1_, + dim2_, + phi::errors::InvalidArgument( + "diagonal dimensions should not be identical " + "%ld vs %ld.", + dim1, + dim2)); + + int new_dim_len = offset_ + x_dims[x_dims.size() - 1]; + auto sizes = vectorize(x_dims); + sizes.pop_back(); + sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len); + sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len); + out->set_dims(phi::make_ddim(sizes)); + out->set_dtype(x.dtype()); +} + void DiagInferMeta(const MetaTensor& x, int offset, float padding_value, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 691fc8ff41c..e825ba98f44 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -71,6 +71,9 @@ void CumInferMeta(const MetaTensor& x, bool reverse, MetaTensor* out); +void DiagEmbedInferMeta( + const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out); + void DiagInferMeta(const MetaTensor& x, int offset, float padding_value, diff --git a/paddle/phi/kernels/cpu/diag_embed_kernel.cc b/paddle/phi/kernels/cpu/diag_embed_kernel.cc new file mode 100644 index 00000000000..714b53c6919 --- /dev/null +++ b/paddle/phi/kernels/cpu/diag_embed_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/diag_embed_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/diag_embed_impl.h" + +PD_REGISTER_KERNEL(diag_embed, + CPU, + ALL_LAYOUT, + phi::DiagEmbedKernel, + int, + int64_t, + float, + double) {} diff --git a/paddle/phi/kernels/diag_embed_kernel.h b/paddle/phi/kernels/diag_embed_kernel.h new file mode 100644 index 00000000000..e47eab82474 --- /dev/null +++ b/paddle/phi/kernels/diag_embed_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DiagEmbedKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + int dim1, + int dim2, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/diag_embed_kernel.cu b/paddle/phi/kernels/gpu/diag_embed_kernel.cu new file mode 100644 index 00000000000..ece0f012e62 --- /dev/null +++ b/paddle/phi/kernels/gpu/diag_embed_kernel.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/diag_embed_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/diag_embed_impl.h" + +PD_REGISTER_KERNEL(diag_embed, + GPU, + ALL_LAYOUT, + phi::DiagEmbedKernel, + int, + int64_t, + float, + double) {} diff --git a/paddle/phi/kernels/impl/diag_embed_impl.h b/paddle/phi/kernels/impl/diag_embed_impl.h new file mode 100644 index 00000000000..a4430fde923 --- /dev/null +++ b/paddle/phi/kernels/impl/diag_embed_impl.h @@ -0,0 +1,129 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#if defined(__NVCC__) || defined(__HIPCC__) +#include +#include +#endif + +#include "paddle/phi/kernels/diag_embed_kernel.h" + +#include + +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +struct DiagEmbedFunctor { + DiagEmbedFunctor(const T* input, + int64_t numel, + const int64_t* dim, + int64_t offset, + int64_t dims_size, + T* output, + const int64_t* strides) + : input_(input), + numel_(numel), + dim_(dim), + offset_(offset), + dims_size_(dims_size), + output_(output), + strides_(strides) {} + + HOSTDEVICE void operator()(size_t idx) const { + int64_t position = 0; + auto numel = numel_; + int64_t num = idx; + for (int64_t i = 0; i < dims_size_; i++) { + numel = numel / dim_[i]; + position += num / numel * strides_[i]; + num = num % numel; + } + output_[position + offset_] = input_[idx]; + } + + const T* input_; + int64_t numel_; + const int64_t* dim_; + int64_t offset_; + int64_t dims_size_; + T* output_; + const int64_t* strides_; +}; + +template +void DiagEmbedKernel(const Context& dev_ctx, + const DenseTensor& x, + int offset, + int dim1, + int dim2, + DenseTensor* out) { + auto* input_data = x.data(); + T* out_data = dev_ctx.template Alloc(out); + phi::funcs::SetConstant set_zero; + + set_zero(dev_ctx, out, static_cast(0.0)); + + auto out_dims = out->dims(); + int dim1_ = dim1 < 0 ? out_dims.size() + dim1 : dim1; + int dim2_ = dim2 < 0 ? out_dims.size() + dim2 : dim2; + auto stride = phi::stride(out_dims); + int64_t diag_size; + int64_t storage_offset = 0; + if (offset >= 0) { + int64_t dim = out_dims[dim2_] - offset; + diag_size = std::max(std::min(out_dims[dim1_], dim), 0); + } else { + int64_t dim = out_dims[dim1_] + offset; + diag_size = std::max(std::min(dim, out_dims[dim2_]), 0); + } + if (diag_size == 0) { + // skip + } else if (offset >= 0) { + storage_offset += offset * stride[dim2_]; + } else { + storage_offset -= offset * stride[dim1_]; + } + auto strides = vectorize(stride); + strides.erase(strides.begin() + std::max(dim1_, dim2_)); + strides.erase(strides.begin() + std::min(dim1_, dim2_)); + strides.push_back(stride[dim1_] + stride[dim2_]); + const auto dims = vectorize(x.dims()); + +#if defined(__NVCC__) || defined(__HIPCC__) + thrust::device_vector dims_vec(dims); + const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data()); + thrust::device_vector strides_vec(strides); + const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data()); +#else + const int64_t* dims_arr = dims.data(); + const int64_t* strides_arr = strides.data(); +#endif + + phi::funcs::ForRange for_range(dev_ctx, x.numel()); + DiagEmbedFunctor functor(input_data, + x.numel(), + dims_arr, + storage_offset, + dims.size(), + out_data, + strides_arr); + for_range(functor); +} + +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_diag_embed.py b/python/paddle/fluid/tests/unittests/test_diag_embed.py index c7f933d23ea..546247167b8 100644 --- a/python/paddle/fluid/tests/unittests/test_diag_embed.py +++ b/python/paddle/fluid/tests/unittests/test_diag_embed.py @@ -27,11 +27,12 @@ class TestDiagEmbedOp(OpTest): def setUp(self): self.op_type = "diag_embed" + self.python_api = F.diag_embed self.init_config() self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def init_config(self): self.case = np.random.randn(2, 3).astype('float32') diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 27bc2ef70bc..1bfa7f14883 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -98,12 +98,18 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1): # [[ 0. , 0. , 0. , 0. ], # [ 0. , 0. , 0. , 0. ]]] """ - inputs = {'Input': [input]} - attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} - if not isinstance(input, Variable): input = assign(input) + if in_dygraph_mode(): + return _C_ops.final_state_diag_embed(input, offset, dim1, dim2) + elif in_dynamic_mode(): + return _C_ops.diag_embed(input, "offset", offset, "dim1", dim1, "dim2", + dim2) + + inputs = {'Input': [input]} + attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} + def __check_input(input, offset, dim1, dim2): check_dtype(input.dtype, 'Input', ['int32', 'int64', 'float16', 'float32', 'float64'], @@ -129,8 +135,7 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1): "dim1 and dim2 cannot be the same dimension." \ "But received dim1 = %d, dim2 = %d\n"%(dim1, dim2) - if not in_dynamic_mode(): - __check_input(input, offset, dim1, dim2) + __check_input(input, offset, dim1, dim2) helper = LayerHelper("diag_embed", **locals()) out = helper.create_variable_for_type_inference(dtype=input.dtype) -- GitLab