diff --git a/paddle/fluid/operators/inverse_op.cc b/paddle/fluid/operators/inverse_op.cc index e93ca5ad54035d16e64a85ebb0ddc1147e9be6b5..5d0c7c754b26cd8ceff404bfb9669a883fc6152b 100644 --- a/paddle/fluid/operators/inverse_op.cc +++ b/paddle/fluid/operators/inverse_op.cc @@ -12,57 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/inverse_op.h" - #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/matrix_inverse.h" + namespace paddle { namespace operators { class InverseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Inverse"); - OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Inverse"); - - auto input_dims = ctx->GetInputDim("Input"); - int64_t input_rank = input_dims.size(); - PADDLE_ENFORCE_GE( - input_rank, - 2, - platform::errors::InvalidArgument( - "The dimension of Input(Input) is expected to be no less than 2. " - "But received: Input(Input)'s dimension = %d, shape = [%s].", - input_rank, - input_dims)); - for (int64_t i = 0; i < input_rank; ++i) { - PADDLE_ENFORCE_EQ( - (input_dims[i] == -1) || (input_dims[i] > 0), - true, - platform::errors::InvalidArgument( - "Each dimension of input tensor is expected to be -1 or a " - "positive number, but received %d. Input's shape is [%s].", - input_dims[i], - input_dims)); - } - if (input_dims[input_rank - 2] > 0 && input_dims[input_rank - 1] > 0) { - PADDLE_ENFORCE_EQ(input_dims[input_rank - 2], - input_dims[input_rank - 1], - platform::errors::InvalidArgument( - "The last two dimensions are expected to be equal. " - "But received: %d and %d; " - "Input(Input)'s shape = [%s].", - input_dims[input_rank - 2], - input_dims[input_rank - 1], - input_dims)); - } - - ctx->SetOutputDim("Output", input_dims); - ctx->ShareLoD("Input", /*->*/ "Output"); - } }; class InverseOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { @@ -78,19 +44,6 @@ class InverseOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class InverseGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - auto input_grad = framework::GradVarName("Input"); - auto output_grad = framework::GradVarName("Output"); - - OP_INOUT_CHECK(ctx->HasInput("Output"), "Input", "Output", "InverseGrad"); - OP_INOUT_CHECK( - ctx->HasInput(output_grad), "Input", output_grad, "InverseGrad"); - - if (ctx->HasOutput(input_grad)) { - ctx->SetOutputDim(input_grad, ctx->GetInputDim(output_grad)); - } - } }; class InverseOpMaker : public framework::OpProtoAndCheckerMaker { @@ -128,18 +81,23 @@ class InverseGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(inverse, + InverseInferShapeFunctor, + PD_INFER_META(phi::InverseInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(inverse_grad, + InverseGradInferShapeFunctor, + PD_INFER_META(phi::InverseGradInferMeta)); + REGISTER_OPERATOR(inverse, ops::InverseOp, ops::InverseOpMaker, ops::InverseOpInferVarType, ops::InverseGradOpMaker, - ops::InverseGradOpMaker); - -REGISTER_OPERATOR(inverse_grad, ops::InverseGradOp); + ops::InverseGradOpMaker, + InverseInferShapeFunctor); -REGISTER_OP_CPU_KERNEL(inverse, - ops::InverseKernel, - ops::InverseKernel); -REGISTER_OP_CPU_KERNEL(inverse_grad, - ops::InverseGradKernel, - ops::InverseGradKernel); +REGISTER_OPERATOR(inverse_grad, + ops::InverseGradOp, + InverseGradInferShapeFunctor); diff --git a/paddle/fluid/operators/inverse_op.cu.cc b/paddle/fluid/operators/inverse_op.cu.cc deleted file mode 100644 index e50476fd23398ea32ef1587e37658fcd107225b7..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/inverse_op.cu.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/inverse_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - inverse, - ops::InverseKernel, - ops::InverseKernel); -REGISTER_OP_CUDA_KERNEL( - inverse_grad, - ops::InverseGradKernel, - ops::InverseGradKernel); diff --git a/paddle/fluid/operators/inverse_op.h b/paddle/fluid/operators/inverse_op.h deleted file mode 100644 index 61f41e9761c62f9878d26da7c3d181fa05f2e666..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/inverse_op.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/matrix_inverse.h" - -namespace paddle { -namespace operators { - -template -class InverseKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* output = context.Output("Output"); - output->mutable_data(context.GetPlace()); - - auto& dev_ctx = context.template device_context(); - phi::funcs::MatrixInverseFunctor mat_inv; - mat_inv(dev_ctx, *input, output); - } -}; - -template -class InverseGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* a_inv = context.Input("Output"); - auto* a_inv_grad = - context.Input(framework::GradVarName("Output")); - auto* a_grad = - context.Output(framework::GradVarName("Input")); - - if (a_grad) { - a_grad->mutable_data(context.GetPlace()); - - auto blas = phi::funcs::GetBlas(context); - auto& dev_ctx = context.template device_context(); - framework::Tensor tmp_out = - context.AllocateTmpTensor(a_inv->dims(), dev_ctx); - - auto mat_dim_a0 = - phi::funcs::CreateMatrixDescriptor(a_inv_grad->dims(), 0, false); - auto mat_dim_b0 = - phi::funcs::CreateMatrixDescriptor(a_inv->dims(), 0, true); - blas.MatMul( - *a_inv_grad, mat_dim_a0, *a_inv, mat_dim_b0, T(1), &tmp_out, T(0)); - - auto mat_dim_a1 = - phi::funcs::CreateMatrixDescriptor(a_inv->dims(), 0, true); - auto mat_dim_b1 = - phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false); - blas.MatMul(*a_inv, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), a_grad, T(0)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 69f37c374cffd5bdf4672eea8e6220318eb6492e..ad93a7c6072e730cdcdca728af895fbd988a7560 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1042,6 +1042,15 @@ intermediate : saved_mean, saved_variance backward : instance_norm_grad +- api : inverse + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : InverseInferMeta + kernel : + func : inverse + backward : inverse_grad + # is_empty - api : is_empty args : (Tensor x) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index aa83bb54a03152d39385a20e7f0e7cd5556936fc..61eeec6c848bb69bf674699711ce3bb3f7461688 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -967,6 +967,15 @@ optional : scale backward : instance_norm_double_grad +- backward_api : inverse_grad + forward : inverse(Tensor x) -> Tensor(out) + args : (Tensor out, Tensor out_grad) + output : Tensor(x_grad) + infer_meta: + func : InverseGradInferMeta + kernel : + func : inverse_grad + - backward_api : kldiv_loss_grad forward : kldiv_loss(Tensor x, Tensor label, str reduction) -> Tensor(out) args : (Tensor x, Tensor label, Tensor out_grad, str reduction) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 82eefdea596c9bd4f5a7271a5061975b771ca022..3480af8db88d3ab116e1c64ea13c05e9eecd6be8 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -403,6 +403,14 @@ void InstanceNormDoubleGradInferMeta(const MetaTensor& x, } } +void InverseGradInferMeta(const MetaTensor& out, + const MetaTensor& dout, + MetaTensor* dx) { + if (dx) { + dx->set_dims(dout.dims()); + } +} + void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) { auto xshape_dims = xshape.dims(); auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size()); diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 042de0b2dd64f3fed557397ced8799fffc57b8f4..88825faa95f7c323de1057192f4ca788cd15fec4 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -183,6 +183,10 @@ void InstanceNormDoubleGradInferMeta(const MetaTensor& x, MetaTensor* dscale, MetaTensor* ddy); +void InverseGradInferMeta(const MetaTensor& out, + const MetaTensor& dout, + MetaTensor* dx); + void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx); void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c7699c34cc5466130f7c058b74fab7eb3a126ca5..9b7dd1f45f4722dc0172f15efd855ca8b53e846b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1025,6 +1025,43 @@ void InferMetaFromVecValue(const MetaTensor& x, } } +void InverseInferMeta(const MetaTensor& x, MetaTensor* out) { + auto input_dims = x.dims(); + int64_t input_rank = input_dims.size(); + PADDLE_ENFORCE_GE( + input_rank, + 2, + errors::InvalidArgument( + "The dimension of Input(Input) is expected to be no less than 2. " + "But received: Input(Input)'s dimension = %d, shape = [%s].", + input_rank, + input_dims)); + for (int64_t i = 0; i < input_rank; ++i) { + PADDLE_ENFORCE_EQ( + (input_dims[i] == -1) || (input_dims[i] > 0), + true, + errors::InvalidArgument( + "Each dimension of input tensor is expected to be -1 or a " + "positive number, but received %d. Input's shape is [%s].", + input_dims[i], + input_dims)); + } + if (input_dims[input_rank - 2] > 0 && input_dims[input_rank - 1] > 0) { + PADDLE_ENFORCE_EQ(input_dims[input_rank - 2], + input_dims[input_rank - 1], + errors::InvalidArgument( + "The last two dimensions are expected to be equal. " + "But received: %d and %d; " + "Input(Input)'s shape = [%s].", + input_dims[input_rank - 2], + input_dims[input_rank - 1], + input_dims)); + } + + out->set_dims(input_dims); + out->share_lod(x); +} + void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_dims(phi::make_ddim({1})); out->set_dtype(DataType::BOOL); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index ea7364e643960c87ffd8f14e4c5e1058ee1db85c..805fa3a56d442e3e534a9e87b8e929c54f1fd291 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -146,6 +146,8 @@ void InferMetaFromVecValue(const MetaTensor& x, const std::vector& shape, MetaTensor* out); +void InverseInferMeta(const MetaTensor& x, MetaTensor* out); + void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out); void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/inverse_grad_kernel.cc b/paddle/phi/kernels/cpu/inverse_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..97c10e69c8eabea99a92cf7d3210f398efa33419 --- /dev/null +++ b/paddle/phi/kernels/cpu/inverse_grad_kernel.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + inverse_grad, CPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/inverse_kernel.cc b/paddle/phi/kernels/cpu/inverse_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b21718eca3f2cfb1bf68399c48af8d309e95a62 --- /dev/null +++ b/paddle/phi/kernels/cpu/inverse_kernel.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/inverse_kernel_impl.h" + +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + inverse, CPU, ALL_LAYOUT, phi::InverseKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/inverse_grad_kernel.cu b/paddle/phi/kernels/gpu/inverse_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2fdc02934fedc6851575a26bfcf151088972e307 --- /dev/null +++ b/paddle/phi/kernels/gpu/inverse_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/inverse_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + inverse_grad, GPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/inverse_kernel.cu b/paddle/phi/kernels/gpu/inverse_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..4c011337c6f8f1df865af162122bc57e1ada44b9 --- /dev/null +++ b/paddle/phi/kernels/gpu/inverse_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/inverse_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/inverse_kernel_impl.h" + +PD_REGISTER_KERNEL( + inverse, GPU, ALL_LAYOUT, phi::InverseKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h b/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..26e2898bf73ff1564d482f38ffa29f186b011c32 --- /dev/null +++ b/paddle/phi/kernels/impl/inverse_grad_kernel_impl.h @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/inverse_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/matrix_inverse.h" + +namespace phi { + +template +void InverseGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* in_grad) { + if (in_grad) { + dev_ctx.template Alloc(in_grad); + + auto blas = phi::funcs::GetBlas(dev_ctx); + + DenseTensor tmp_out; + tmp_out.Resize(out.dims()); + dev_ctx.template Alloc(&tmp_out); + + auto mat_dim_a0 = + phi::funcs::CreateMatrixDescriptor(out_grad.dims(), 0, false); + auto mat_dim_b0 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + blas.MatMul(out_grad, mat_dim_a0, out, mat_dim_b0, T(1), &tmp_out, T(0)); + + auto mat_dim_a1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + auto mat_dim_b1 = + phi::funcs::CreateMatrixDescriptor(tmp_out.dims(), 0, false); + blas.MatMul(out, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), in_grad, T(0)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/inverse_kernel_impl.h b/paddle/phi/kernels/impl/inverse_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..351c3933a12af9a369807a3aa3781dabe4081b66 --- /dev/null +++ b/paddle/phi/kernels/impl/inverse_kernel_impl.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/inverse_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/matrix_inverse.h" + +namespace phi { + +template +void InverseKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + dev_ctx.template Alloc(out); + + phi::funcs::MatrixInverseFunctor mat_inv; + mat_inv(dev_ctx, x, out); +} + +} // namespace phi diff --git a/paddle/phi/kernels/inverse_grad_kernel.h b/paddle/phi/kernels/inverse_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3fccb87897ab009b2075f1f3f4e47c499d5d6d9e --- /dev/null +++ b/paddle/phi/kernels/inverse_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void InverseGradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& out_grad, + DenseTensor* in_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/inverse_kernel.h b/paddle/phi/kernels/inverse_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d8ebf39c362db3b59344411b1146f9d23e39993a --- /dev/null +++ b/paddle/phi/kernels/inverse_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/matrix_inverse.h" + +namespace phi { + +template +void InverseKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/inverse_sig.cc b/paddle/phi/ops/compat/inverse_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ec56d5759a963ebddf6f2267a433b6acacaa7d2 --- /dev/null +++ b/paddle/phi/ops/compat/inverse_sig.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { +KernelSignature InverseGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "inverse_grad", {"Output", "Output@GRAD"}, {}, {"Input@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(inverse_grad, phi::InverseGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_inverse_op.py b/python/paddle/fluid/tests/unittests/test_inverse_op.py index b868fef15ec329a51acc8819880ec073144773be..d39707f10426695910131afb2637fdb2a901bb81 100644 --- a/python/paddle/fluid/tests/unittests/test_inverse_op.py +++ b/python/paddle/fluid/tests/unittests/test_inverse_op.py @@ -25,6 +25,7 @@ class TestInverseOp(OpTest): def config(self): self.matrix_shape = [10, 10] self.dtype = "float64" + self.python_api = paddle.tensor.math.inverse def setUp(self): self.op_type = "inverse" @@ -38,10 +39,10 @@ class TestInverseOp(OpTest): self.outputs = {'Output': inverse} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_grad(self): - self.check_grad(['Input'], 'Output') + self.check_grad(['Input'], 'Output', check_eager=True) class TestInverseOpBatched(TestInverseOp): @@ -49,6 +50,7 @@ class TestInverseOpBatched(TestInverseOp): def config(self): self.matrix_shape = [8, 4, 4] self.dtype = "float64" + self.python_api = paddle.tensor.math.inverse class TestInverseOpLarge(TestInverseOp): @@ -56,9 +58,13 @@ class TestInverseOpLarge(TestInverseOp): def config(self): self.matrix_shape = [32, 32] self.dtype = "float64" + self.python_api = paddle.tensor.math.inverse def test_grad(self): - self.check_grad(['Input'], 'Output', max_relative_error=1e-6) + self.check_grad(['Input'], + 'Output', + max_relative_error=1e-6, + check_eager=True) class TestInverseOpFP32(TestInverseOp): @@ -66,9 +72,13 @@ class TestInverseOpFP32(TestInverseOp): def config(self): self.matrix_shape = [10, 10] self.dtype = "float32" + self.python_api = paddle.tensor.math.inverse def test_grad(self): - self.check_grad(['Input'], 'Output', max_relative_error=1e-2) + self.check_grad(['Input'], + 'Output', + max_relative_error=1e-2, + check_eager=True) class TestInverseOpBatchedFP32(TestInverseOpFP32): @@ -76,6 +86,7 @@ class TestInverseOpBatchedFP32(TestInverseOpFP32): def config(self): self.matrix_shape = [8, 4, 4] self.dtype = "float32" + self.python_api = paddle.tensor.math.inverse class TestInverseOpLargeFP32(TestInverseOpFP32): @@ -83,6 +94,7 @@ class TestInverseOpLargeFP32(TestInverseOpFP32): def config(self): self.matrix_shape = [32, 32] self.dtype = "float32" + self.python_api = paddle.tensor.math.inverse class TestInverseAPI(unittest.TestCase): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 4fc725bf91303c4b10d06d196d77b3100ea7b63b..4b9fbfb1fee46fe86007dc1e88a209aa8555a406 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1932,7 +1932,9 @@ def inverse(x, name=None): print(inv) # [[0.5, 0], [0, 0.5]] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_inverse(x) + elif paddle.in_dynamic_mode(): return _C_ops.inverse(x) def _check_input(x):