From 5dfb87d9a7ac2a20e7dd4a31416e641df9182878 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Tue, 19 Jul 2022 13:36:54 +0800
Subject: [PATCH] [Phi] Migrate infermeta and add yaml for solve op (#44379)

* migrate solve kernel to phi
* re useless header file, fix a bug in grad_kernel_impl
* add header file in need
* add yaml for solve op
* fix solve_sig.cc ArgumentMapping and update tests case
* disable legacy dygraph check in op_test
* rm solve_op.cc / solve_sig.cc and migrate yaml config
* Update op_test.py

disable legacy dygraph check when check_eager is True
---
 paddle/fluid/operators/solve_op.cc            | 222 ------------------
 paddle/phi/api/yaml/api.yaml                  |  10 +
 paddle/phi/api/yaml/api_compat.yaml           |   6 +
 paddle/phi/api/yaml/backward.yaml             |  10 +
 paddle/phi/infermeta/binary.cc                |  87 +++++++
 paddle/phi/infermeta/binary.h                 |   2 +
 .../phi/kernels/impl/solve_grad_kernel_impl.h |   2 +-
 paddle/phi/kernels/solve_grad_kernel.h        |   2 +-
 paddle/phi/kernels/solve_kernel.h             |   9 +
 paddle/phi/ops/compat/solve_sig.cc            |  26 --
 .../paddle/fluid/tests/unittests/op_test.py   |  26 +-
 .../fluid/tests/unittests/test_solve_op.py    | 113 ++++++---
 python/paddle/tensor/linalg.py                |   5 +-
 13 files changed, 238 insertions(+), 282 deletions(-)
 delete mode 100644 paddle/fluid/operators/solve_op.cc
 delete mode 100644 paddle/phi/ops/compat/solve_sig.cc

diff --git a/paddle/fluid/operators/solve_op.cc b/paddle/fluid/operators/solve_op.cc
deleted file mode 100644
index daa020e4a0..0000000000
--- a/paddle/fluid/operators/solve_op.cc
+++ /dev/null
@@ -1,222 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include
-#include
-#include
-#include
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/core/ddim.h"
-
-namespace paddle {
-namespace operators {
-
-using framework::OpKernelType;
-using framework::Tensor;
-
-class SolveOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Solve");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Solve");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Solve");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    std::vector<int64_t> x_dims_vec = phi::vectorize(ctx->GetInputDim("X"));
-    std::vector<int64_t> y_dims_vec = phi::vectorize(ctx->GetInputDim("Y"));
-
-    auto x_dims_n = x_dims_vec.size();
-    auto y_dims_n = y_dims_vec.size();
-
-    PADDLE_ENFORCE_GT(x_dims_n,
-                      1,
-                      platform::errors::InvalidArgument(
-                          "The input tensor X's dimensions of SolveOp "
-                          "should be larger than 1. But received X's "
-                          "dimensions = %d, X's shape = [%s]",
-                          x_dims_n,
-                          x_dims));
-
-    PADDLE_ENFORCE_GE(y_dims_n,
-                      1,
-                      platform::errors::InvalidArgument(
-                          "The input tensor Y's dimensions of SolveOp "
-                          "should be larger than or equal 1. But received Y's "
-                          "dimensions = %d, Y's shape = [%s]",
-                          y_dims_n,
-                          y_dims));
-
-    PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
-                      x_dims[x_dims_n - 1],
-                      platform::errors::InvalidArgument(
-                          "The inner-most 2 dimensions of Input(X) all should "
-                          "be square matrices "
-                          "But received X's shape[-2] = %d and shape[-1] = %d.",
-                          x_dims[x_dims_n - 2],
-                          x_dims[x_dims_n - 1]));
-
-    bool x_broadcasted = false, y_broadcasted = false;
-    bool trans_x = false, trans_y = false;
-    if (x_dims_n == 1) {
-      x_dims_vec.insert(x_dims_vec.begin(), 1);
-      x_dims_n = 2;
-      x_broadcasted = true;
-    }
-
-    if (y_dims_n == 1) {
-      y_dims_vec.push_back(1);
-      y_dims_n = 2;
-      y_broadcasted = true;
-    }
-
-    size_t M, N;
-    if (trans_x) {
-      M = x_dims_vec[x_dims_n - 1];
-    } else {
-      M = x_dims_vec[x_dims_n - 2];
-    }
-    if (trans_y) {
-      N = y_dims_vec[y_dims_n - 2];
-    } else {
-      N = y_dims_vec[y_dims_n - 1];
-    }
-
-    std::vector<int64_t> new_dims;
-    if (x_dims_n >= y_dims_n) {
-      new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
-    } else {
-      new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
-    }
-    if (!x_broadcasted) {
-      new_dims.push_back(M);
-    }
-    if (!y_broadcasted) {
-      new_dims.push_back(N);
-    }
-    if (x_broadcasted && y_broadcasted) {
-      new_dims.push_back(1);
-    }
-
-    auto out_dims = phi::make_ddim(new_dims);
-    ctx->SetOutputDim("Out", out_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(input_data_type,
-                                   ctx.GetPlace(),
-                                   layout,
-                                   library,
-                                   customized_type_value);
-  }
-};
-
-class SolveOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "(Tensor), The first input tensor of solve op.");
-    AddInput("Y", "(Tensor), The second input tensor of solve op.");
-    AddOutput("Out", "(Tensor), The output tensor of solve op.");
-    AddComment(R"DOC(
-          Solve Operator.
-          This operator is used to computes the solution of a square system of
-          linear equations with a unique solution for input $X$ and $Y$.
-
-          The equation is:
-          $$Out = X^-1 * Y$$
-)DOC");
-  }
-};
-
-class SolveOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
-      const override {
-    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
-    return m;
-  }
-};
-
-class SolveGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "solve");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "solve");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@GRAD",
-                   "solve");
-    // reuse the linalg.solve forward output
-    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "solve");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto y_grad_name = framework::GradVarName("Y");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->SetOutputDim(y_grad_name, y_dims);
-    }
-  }
-};
-
-template <typename T>
-class SolveOpGradMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> retv) const override {
-    retv->SetType("solve_grad");
-    retv->SetInput("X", this->Input("X"));
-    retv->SetInput("Y", this->Input("Y"));
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    // reuse the linalg.solve forward output
-    retv->SetInput("Out", this->Output("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-    retv->SetAttrMap(this->Attrs());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OPERATOR(solve,
-                  ops::SolveOp,
-                  ops::SolveOpMaker,
-                  ops::SolveOpInferVarType,
-                  ops::SolveOpGradMaker<paddle::framework::OpDesc>,
-                  ops::SolveOpGradMaker<paddle::imperative::OpBase>);
-
-REGISTER_OPERATOR(solve_grad, ops::SolveGradOp);
diff --git a/paddle/phi/api/yaml/api.yaml b/paddle/phi/api/yaml/api.yaml
index b5703aa57f..b81cb5be42 100644
--- a/paddle/phi/api/yaml/api.yaml
+++ b/paddle/phi/api/yaml/api.yaml
@@ -116,6 +116,16 @@
     func : poisson
   backward : poisson_grad
 
+- api : solve
+  args : (Tensor x, Tensor y)
+  output : Tensor
+  infer_meta :
+    func : SolveInferMeta
+  kernel :
+    func : solve
+    data_type : x
+  backward : solve_grad
+
 - api : trace
   args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
   output : Tensor
diff --git a/paddle/phi/api/yaml/api_compat.yaml b/paddle/phi/api/yaml/api_compat.yaml
index a68de3a0f1..fe1ec03206 100644
--- a/paddle/phi/api/yaml/api_compat.yaml
+++ b/paddle/phi/api/yaml/api_compat.yaml
@@ -89,6 +89,12 @@
   outputs :
     out : Out
 
+- api : solve
+  inputs :
+    {x : X, y : Y}
+  outputs :
+    out : Out
+
 - api : trace
   inputs :
     x : Input
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 17409f8ae7..009875be18 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -125,6 +125,16 @@
   kernel :
     func : poisson_grad
 
+- backward_api : solve_grad
+  forward : solve (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, y]
+  kernel :
+    func : solve_grad
+
 - backward_api : trace_grad
  forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 460b0a9e1b..1bbcd52e8b 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -2082,6 +2082,93 @@ void ValueCompareInferMeta(const MetaTensor& x,
   out->set_dtype(DataType::BOOL);
 }
 
+void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+
+  std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
+  std::vector<int64_t> y_dims_vec = phi::vectorize(y.dims());
+
+  auto x_dims_n = x_dims_vec.size();
+  auto y_dims_n = y_dims_vec.size();
+
+  PADDLE_ENFORCE_GT(
+      x_dims_n,
+      1,
+      phi::errors::InvalidArgument("The input tensor X's dimensions of SolveOp "
+                                   "should be larger than 1. But received X's "
+                                   "dimensions = %d, X's shape = [%s]",
+                                   x_dims_n,
+                                   x_dims));
+
+  PADDLE_ENFORCE_GE(y_dims_n,
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The input tensor Y's dimensions of SolveOp "
+                        "should be larger than or equal to 1. But received Y's "
+                        "dimensions = %d, Y's shape = [%s]",
+                        y_dims_n,
+                        y_dims));
+
+  PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
+                    x_dims[x_dims_n - 1],
+                    phi::errors::InvalidArgument(
+                        "The inner-most 2 dimensions of Input(X) should all "
+                        "be square matrices. "
+                        "But received X's shape[-2] = %d and shape[-1] = %d.",
+                        x_dims[x_dims_n - 2],
+                        x_dims[x_dims_n - 1]));
+
+  bool x_broadcasted = false, y_broadcasted = false;
+  bool trans_x = false, trans_y = false;
+  if (x_dims_n == 1) {
+    x_dims_vec.insert(x_dims_vec.begin(), 1);
+    x_dims_n = 2;
+    x_broadcasted = true;
+  }
+
+  if (y_dims_n == 1) {
+    y_dims_vec.push_back(1);
+    y_dims_n = 2;
+    y_broadcasted = true;
+  }
+
+  size_t M, N;
+  if (trans_x) {
+    M = x_dims_vec[x_dims_n - 1];
+  } else {
+    M = x_dims_vec[x_dims_n - 2];
+  }
+  if (trans_y) {
+    N = y_dims_vec[y_dims_n - 2];
+  } else {
+    N = y_dims_vec[y_dims_n - 1];
+  }
+
+  std::vector<int64_t> new_dims;
+  if (x_dims_n >= y_dims_n) {
+    new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
+  } else {
+    new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
+  }
+  if (!x_broadcasted) {
+    new_dims.push_back(M);
+  }
+  if (!y_broadcasted) {
+    new_dims.push_back(N);
+  }
+  if (x_broadcasted && y_broadcasted) {
+    new_dims.push_back(1);
+  }
+
+  auto out_dims = phi::make_ddim(new_dims);
+
+  out->set_dims(out_dims);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+  out->share_lod(x);
+}
+
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 12922ed536..70dafe24fb 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -305,4 +305,6 @@ void ValueCompareInferMeta(const MetaTensor& x,
                            MetaTensor* out,
                            MetaConfig config = MetaConfig());
 
+void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/solve_grad_kernel_impl.h
index 55ee023cb5..214db79383 100644
--- a/paddle/phi/kernels/impl/solve_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/solve_grad_kernel_impl.h
@@ -73,8 +73,8 @@ template <typename T, typename Context>
 void SolveGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
-                     const DenseTensor& dout,
                      const DenseTensor& out,
+                     const DenseTensor& dout,
                      DenseTensor* dx,
                      DenseTensor* dy) {
   bool is_vector = false;
diff --git a/paddle/phi/kernels/solve_grad_kernel.h b/paddle/phi/kernels/solve_grad_kernel.h
index 31bdb9932b..d2f1b6aef7 100644
--- a/paddle/phi/kernels/solve_grad_kernel.h
+++ b/paddle/phi/kernels/solve_grad_kernel.h
@@ -22,8 +22,8 @@
 template <typename T, typename Context>
 void SolveGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
-                     const DenseTensor& dout,
                      const DenseTensor& out,
+                     const DenseTensor& dout,
                      DenseTensor* dx,
                      DenseTensor* dy);
diff --git a/paddle/phi/kernels/solve_kernel.h b/paddle/phi/kernels/solve_kernel.h
index 28dddb0f64..72f2feb475 100644
--- a/paddle/phi/kernels/solve_kernel.h
+++ b/paddle/phi/kernels/solve_kernel.h
@@ -18,6 +18,15 @@ limitations under the License. */
 
 namespace phi {
 
+/**
+ * @brief This kernel is used to compute the solution of a square system of
+ *        linear equations with a unique solution for input x and y.
+ *        $$Out = X^{-1} * Y$$
+ * @param dev_ctx device context
+ * @param x the first input tensor of solve
+ * @param y the second input tensor of solve
+ * @param out the output tensor of solve
+ */
 template <typename T, typename Context>
 void SolveKernel(const Context& dev_ctx,
                  const DenseTensor& x,
diff --git a/paddle/phi/ops/compat/solve_sig.cc b/paddle/phi/ops/compat/solve_sig.cc
deleted file mode 100644
index 9771adee8e..0000000000
--- a/paddle/phi/ops/compat/solve_sig.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature SolveGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "solve_grad", {"X", "Y", "Out@GRAD", "Out"}, {}, {"X@GRAD", "Y@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(solve_grad, phi::SolveGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 860a72193e..b0274431d4 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1368,6 +1368,10 @@ class OpTest(unittest.TestCase):
                                 inplace_atol=None,
                                 check_eager=False):
 
+        # disable legacy dygraph check when check_eager is True
+        if check_eager == True:
+            check_dygraph = False
+
         def find_imperative_actual(target_name, dygraph_outs, place):
             for name in dygraph_outs:
                 if name == target_name:
@@ -1692,7 +1696,8 @@
                                     inplace_atol=inplace_atol)
 
         if check_eager:
-            return outs, dygraph_outs, eager_dygraph_outs, fetch_list
+            assert check_dygraph == False
+            return outs, eager_dygraph_outs, fetch_list
         elif check_dygraph:
             return outs, dygraph_outs, fetch_list
         else:
@@ -1767,6 +1772,11 @@
                      check_dygraph=True,
                      inplace_atol=None,
                      check_eager=False):
+
+        # disable legacy dygraph check when check_eager is True
+        if check_eager == True:
+            check_dygraph = False
+
         self.__class__.op_type = self.op_type
         if self.is_mkldnn_op():
             self.__class__.use_mkldnn = True
@@ -1784,8 +1794,8 @@
                                               inplace_atol,
                                               check_eager=check_eager)
             if check_eager:
-                assert check_dygraph == True
-                outs, dygraph_outs, eager_dygraph_outs, fetch_list = res
+                assert check_dygraph == False
+                outs, eager_dygraph_outs, fetch_list = res
             elif check_dygraph:
                 outs, dygraph_outs, fetch_list = res
             else:
@@ -1859,6 +1869,11 @@
                    user_defined_grad_outputs=None,
                    check_dygraph=True,
                    check_eager=False):
+
+        # disable legacy dygraph check when check_eager is True
+        if check_eager == True:
+            check_dygraph = False
+
        self._check_grad_helper()
        places = self._get_places()
        for place in places:
@@ -1887,6 +1902,11 @@
                              check_dygraph=True,
                              numeric_place=None,
                              check_eager=False):
+
+        # disable legacy dygraph check when check_eager is True
+        if check_eager == True:
+            check_dygraph = False
+
        self.scope = core.Scope()
        op_inputs = self.inputs if hasattr(self, "inputs") else dict()
        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
diff --git a/python/paddle/fluid/tests/unittests/test_solve_op.py b/python/paddle/fluid/tests/unittests/test_solve_op.py
index 99c5eb21db..3162ecffba 100644
--- a/python/paddle/fluid/tests/unittests/test_solve_op.py
+++ b/python/paddle/fluid/tests/unittests/test_solve_op.py
@@ -24,12 +24,14 @@ sys.path.append("..")
 from op_test import OpTest
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard
 
 
 # 2D normal case
 class TestSolveOp(OpTest):
 
     def config(self):
+        self.python_api = paddle.linalg.solve
         self.input_x_matrix_shape = [15, 15]
         self.input_y_matrix_shape = [15, 10]
         self.dtype = "float64"
@@ -49,16 +51,17 @@
         }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
 
 
 # x broadcast + 3D batch case
 class TestSolveOpBatched_case0(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -70,16 +73,20 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-1)
+        self.check_grad(['X', 'Y'],
+                        'Out',
+                        max_relative_error=1e-1,
+                        check_eager=True)
 
 
 # 3D batch + y vector case
 class TestSolveOpBatched_case1(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -91,16 +98,20 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04)
+        self.check_grad(['X', 'Y'],
+                        'Out',
+                        max_relative_error=0.04,
+                        check_eager=True)
 
 
 # 3D batch + y broadcast case
 class TestSolveOpBatched_case2(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -112,16 +123,20 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02)
+        self.check_grad(['X', 'Y'],
+                        'Out',
+                        max_relative_error=0.02,
+                        check_eager=True)
 
 
 # x broadcast + 3D batch case
 class TestSolveOpBatched_case3(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -133,16 +148,20 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02)
+        self.check_grad(['X', 'Y'],
+                        'Out',
+                        max_relative_error=0.02,
+                        check_eager=True)
 
 
 # 3D normal batch case
 class TestSolveOpBatched_case4(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -154,16 +173,17 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
 
 
 # 4D normal batch case
 class TestSolveOpBatched_case5(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -175,16 +195,17 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
 
 
 # 4D batch + y broadcast case
 class TestSolveOpBatched_case6(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -196,16 +217,17 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
 
 
 # 5D normal batch case
 class TestSolveOpBatched_case7(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -217,16 +239,20 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04)
+        self.check_grad(['X', 'Y'],
+                        'Out',
+                        max_relative_error=0.04,
+                        check_eager=True)
 
 
 # 5D batch + y broadcast case
 class TestSolveOpBatched_case8(OpTest):
 
     def setUp(self):
+        self.python_api = paddle.linalg.solve
         self.op_type = "solve"
         self.dtype = "float64"
         np.random.seed(2021)
@@ -238,15 +264,18 @@
         self.outputs = {'Out': result}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04)
+        self.check_grad(['X', 'Y'],
+                        'Out',
+                        max_relative_error=0.04,
+                        check_eager=True)
 
 
 class TestSolveOpError(unittest.TestCase):
 
-    def test_errors(self):
+    def func_errors(self):
         with program_guard(Program(), Program()):
             # The input type of solve_op must be Variable.
             x1 = fluid.create_lod_tensor(np.array([[-1]]), [[1]],
@@ -282,6 +311,11 @@
             y7 = fluid.data(name="y7", shape=[2, 4, 3], dtype="float64")
             self.assertRaises(ValueError, paddle.linalg.solve, x7, y7)
 
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.func_errors()
+        self.func_errors()
+
 
 # 2D + vector case, FP64
 class TestSolveOpAPI_1(unittest.TestCase):
@@ -323,7 +357,7 @@
         for place in self.place:
             self.check_static_result(place=place)
 
-    def test_dygraph(self):
+    def func_dygraph(self):
 
         def run(place):
             paddle.disable_static(place)
@@ -344,6 +378,11 @@
         for place in self.place:
             run(place)
 
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.func_dygraph()
+        self.func_dygraph()
+
 
 # 2D normal case, FP64
 class TestSolveOpAPI_2(unittest.TestCase):
@@ -386,14 +425,13 @@
         for place in self.place:
             self.check_static_result(place=place)
 
-    def test_dygraph(self):
+    def func_dygraph(self):
 
         def run(place):
             paddle.disable_static(place)
             np.random.seed(2021)
             input_x_np = np.random.random([10, 10]).astype(self.dtype)
             input_y_np = np.random.random([10, 4]).astype(self.dtype)
-
             tensor_input_x = paddle.to_tensor(input_x_np)
             tensor_input_y = paddle.to_tensor(input_y_np)
@@ -407,6 +445,11 @@
         for place in self.place:
             run(place)
 
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.func_dygraph()
+        self.func_dygraph()
+
 
 # 2D normal case, FP32
 class TestSolveOpAPI_3(unittest.TestCase):
@@ -450,7 +493,7 @@
         for place in self.place:
             self.check_static_result(place=place)
 
-    def test_dygraph(self):
+    def func_dygraph(self):
 
         def run(place):
             paddle.disable_static(place)
@@ -472,6 +515,11 @@
         for place in self.place:
             run(place)
 
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.func_dygraph()
+        self.func_dygraph()
+
 
 # 3D + y broadcast case, FP64
 class TestSolveOpAPI_4(unittest.TestCase):
@@ -513,7 +561,7 @@
         for place in self.place:
             self.check_static_result(place=place)
 
-    def test_dygraph(self):
+    def func_dygraph(self):
 
         def run(place):
             paddle.disable_static(place)
@@ -534,6 +582,11 @@
         for place in self.place:
             run(place)
 
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.func_dygraph()
+        self.func_dygraph()
+
 
 class TestSolveOpSingularAPI(unittest.TestCase):
     # Singular matrix is not invertible
@@ -573,14 +626,13 @@
             paddle.enable_static()
             self.check_static_result(place=place)
 
-    def test_dygraph(self):
+    def func_dygraph(self):
         for place in self.places:
             with fluid.dygraph.guard(place):
                 input_x_np = np.ones([4, 4]).astype(self.dtype)
                 input_y_np = np.ones([4, 4]).astype(self.dtype)
                 input_x = fluid.dygraph.to_variable(input_x_np)
                 input_y = fluid.dygraph.to_variable(input_y_np)
-
                 try:
                     result = paddle.linalg.solve(input_x, input_y)
                 except RuntimeError as ex:
@@ -590,6 +642,11 @@
                     print("The mat is singular")
                     pass
 
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.func_dygraph()
+        self.func_dygraph()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 1bc85a076a..6893685582 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -2853,7 +2853,10 @@ def solve(x, y, name=None):
             print(out)
             # [2., 3.])
     """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_solve(x, y)
+
+    if _in_legacy_dygraph():
        return _C_ops.solve(x, y)
 
     inputs = {"X": [x], "Y": [y]}
-- 
GitLab
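
Editor's note: the sketch below is not part of the patch. It is a minimal dygraph usage example, assuming a Paddle build that includes this change, and it checks the documented identity Out = X^{-1} * Y against NumPy.

# A minimal sketch, not part of the commit; assumes a Paddle build
# that contains this change.
import numpy as np
import paddle

np.random.seed(2021)
x_np = np.random.random([3, 3]).astype("float64")  # square coefficient matrix X
y_np = np.random.random([3, 4]).astype("float64")  # right-hand side Y

# In eager mode, paddle.linalg.solve now dispatches to the new phi kernel
# through _C_ops.final_state_solve (see the linalg.py hunk above).
out = paddle.linalg.solve(paddle.to_tensor(x_np), paddle.to_tensor(y_np))

# The kernel computes Out = X^{-1} * Y, so the result must match NumPy's solver.
np.testing.assert_allclose(out.numpy(), np.linalg.solve(x_np, y_np), rtol=1e-10)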
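
The batch/broadcast rule that SolveInferMeta encodes is compact enough to restate in plain Python. The helper below is an illustration only (the function name and the asserts are ours, not Paddle's); it mirrors the promotion of 1-D operands and the choice of batch dimensions in the C++ above.

# Illustration only: a pure-Python restatement of SolveInferMeta's shape rule.
def solve_out_shape(x_shape, y_shape):
    x, y = list(x_shape), list(y_shape)
    x_bc = y_bc = False
    if len(x) == 1:  # 1-D x is promoted to a row: [1, n]
        x = [1] + x
        x_bc = True
    if len(y) == 1:  # 1-D y is promoted to a column: [n, 1]
        y = y + [1]
        y_bc = True
    m, n = x[-2], y[-1]  # trans_x / trans_y are always false for solve
    # batch dims come from whichever operand has more of them
    batch = x[:-2] if len(x) >= len(y) else y[:-2]
    out = list(batch)
    if not x_bc:
        out.append(m)
    if not y_bc:
        out.append(n)
    if x_bc and y_bc:
        out.append(1)
    return out

assert solve_out_shape([15, 15], [15, 10]) == [15, 10]  # 2-D normal case
assert solve_out_shape([2, 6, 6], [6]) == [2, 6]        # batched x, vector y
assert solve_out_shape([3], [3]) == [1]                 # both operands 1-D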