Unverified commit 5dfb87d9, authored by Weilong Wu, committed by GitHub

[Phi] Migrate infermeta and add yaml for solve op (#44379)

* migrate solve kernel to phi

* remove useless header file, fix a bug in grad_kernel_impl

* add header file in need

* add yaml for solve op

* fix solve_sig.cc ArgumentMapping and update test cases

* disable legacy dygraph check in op_test

* rm solve_op.cc / solve_sig.cc and migrate yaml config

* Update op_test.py

disable legacy dygraph check when check_eager is True
Parent: 6fb2958e
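For context, here is a minimal usage sketch of the op being migrated; the inputs are chosen so the result matches the [2., 3.] output shown in the docstring excerpt at the end of this diff (assumes a build where paddle.linalg.solve is available):

import paddle

# Solve x @ out = y for out, i.e. out = inverse(x) @ y.
x = paddle.to_tensor([[3.0, 1.0], [1.0, 2.0]], dtype="float64")
y = paddle.to_tensor([9.0, 8.0], dtype="float64")
out = paddle.linalg.solve(x, y)
print(out)  # expected: [2., 3.]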
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class SolveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Solve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Solve");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Solve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
std::vector<int64_t> x_dims_vec = phi::vectorize(ctx->GetInputDim("X"));
std::vector<int64_t> y_dims_vec = phi::vectorize(ctx->GetInputDim("Y"));
auto x_dims_n = x_dims_vec.size();
auto y_dims_n = y_dims_vec.size();
PADDLE_ENFORCE_GT(x_dims_n,
1,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of SolveOp "
"should be larger than 1. But received X's "
"dimensions = %d, X's shape = [%s]",
x_dims_n,
x_dims));
PADDLE_ENFORCE_GE(y_dims_n,
1,
platform::errors::InvalidArgument(
"The input tensor Y's dimensions of SolveOp "
"should be larger than or equal 1. But received Y's "
"dimensions = %d, Y's shape = [%s]",
y_dims_n,
y_dims));
PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1],
platform::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1]));
bool x_broadcasted = false, y_broadcasted = false;
bool trans_x = false, trans_y = false;
if (x_dims_n == 1) {
x_dims_vec.insert(x_dims_vec.begin(), 1);
x_dims_n = 2;
x_broadcasted = true;
}
if (y_dims_n == 1) {
y_dims_vec.push_back(1);
y_dims_n = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = x_dims_vec[x_dims_n - 1];
} else {
M = x_dims_vec[x_dims_n - 2];
}
if (trans_y) {
N = y_dims_vec[y_dims_n - 2];
} else {
N = y_dims_vec[y_dims_n - 1];
}
std::vector<int64_t> new_dims;
if (x_dims_n >= y_dims_n) {
new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
} else {
new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto out_dims = phi::make_ddim(new_dims);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout,
library,
customized_type_value);
}
};
class SolveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of solve op.");
AddInput("Y", "(Tensor), The second input tensor of solve op.");
AddOutput("Out", "(Tensor), The output tensor of solve op.");
AddComment(R"DOC(
Solve Operator.
This operator is used to compute the solution of a square system of
linear equations with a unique solution for input $X$ and $Y$.
The equation is:
$$Out = X^{-1} * Y$$
)DOC");
}
};
class SolveOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class SolveGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "solve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "solve");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"solve");
// reuse the linalg.solve forward output
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "solve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class SolveOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("solve_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
// reuse the linalg.solve forward output
retv->SetInput("Out", this->Output("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(solve,
ops::SolveOp,
ops::SolveOpMaker,
ops::SolveOpInferVarType,
ops::SolveOpGradMaker<paddle::framework::OpDesc>,
ops::SolveOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(solve_grad, ops::SolveGradOp);
......@@ -116,6 +116,16 @@
func : poisson
backward : poisson_grad
- api : solve
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : SolveInferMeta
kernel :
func : solve
data_type : x
backward : solve_grad
- api : trace
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor
......
......@@ -89,6 +89,12 @@
outputs :
out : Out
- api : solve
inputs :
{x : X, y : Y}
outputs :
out : Out
- api : trace
inputs :
x : Input
......
......@@ -125,6 +125,16 @@
kernel :
func : poisson_grad
- backward_api : solve_grad
forward : solve (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : solve_grad
- backward_api : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
......
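For reference, the gradient that solve_grad computes is the standard adjoint of a linear solve (stated here for clarity; this is textbook math, not a quotation from the kernel source). With $Out = X^{-1} Y$ and loss $L$:

$$\frac{\partial L}{\partial Y} = X^{-T}\,\frac{\partial L}{\partial Out}, \qquad \frac{\partial L}{\partial X} = -\,X^{-T}\,\frac{\partial L}{\partial Out}\,Out^{T} = -\,\frac{\partial L}{\partial Y}\,Out^{T}$$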
......@@ -2082,6 +2082,93 @@ void ValueCompareInferMeta(const MetaTensor& x,
out->set_dtype(DataType::BOOL);
}
void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
std::vector<int64_t> y_dims_vec = phi::vectorize(y.dims());
auto x_dims_n = x_dims_vec.size();
auto y_dims_n = y_dims_vec.size();
PADDLE_ENFORCE_GT(
x_dims_n,
1,
phi::errors::InvalidArgument("The input tensor X's dimensions of SolveOp "
"should be larger than 1. But received X's "
"dimensions = %d, X's shape = [%s]",
x_dims_n,
x_dims));
PADDLE_ENFORCE_GE(y_dims_n,
1,
phi::errors::InvalidArgument(
"The input tensor Y's dimensions of SolveOp "
"should be larger than or equal 1. But received Y's "
"dimensions = %d, Y's shape = [%s]",
y_dims_n,
y_dims));
PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1],
phi::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1]));
bool x_broadcasted = false, y_broadcasted = false;
bool trans_x = false, trans_y = false;
if (x_dims_n == 1) {
x_dims_vec.insert(x_dims_vec.begin(), 1);
x_dims_n = 2;
x_broadcasted = true;
}
if (y_dims_n == 1) {
y_dims_vec.push_back(1);
y_dims_n = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = x_dims_vec[x_dims_n - 1];
} else {
M = x_dims_vec[x_dims_n - 2];
}
if (trans_y) {
N = y_dims_vec[y_dims_n - 2];
} else {
N = y_dims_vec[y_dims_n - 1];
}
std::vector<int64_t> new_dims;
if (x_dims_n >= y_dims_n) {
new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
} else {
new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto out_dims = phi::make_ddim(new_dims);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
......
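To make the broadcasting rule in SolveInferMeta above concrete, here is a small Python sketch; solve_out_shape is a hypothetical helper that mirrors the C++ shape logic and is not part of this commit:

def solve_out_shape(x_shape, y_shape):
    # Mirror of SolveInferMeta's output-shape rule (illustrative only).
    x, y = list(x_shape), list(y_shape)
    x_bc = y_bc = False
    if len(x) == 1:          # 1-D x is treated as a 1 x n row
        x, x_bc = [1] + x, True
    if len(y) == 1:          # 1-D y is treated as an n x 1 column
        y, y_bc = y + [1], True
    m, n = x[-2], y[-1]
    batch = (x if len(x) >= len(y) else y)[:-2]  # batch dims come from the higher-rank input
    out = list(batch)
    if not x_bc:
        out.append(m)
    if not y_bc:
        out.append(n)
    if x_bc and y_bc:
        out.append(1)
    return out

print(solve_out_shape([15, 15], [15, 10]))  # [15, 10]  (2-D normal case)
print(solve_out_shape([2, 6, 6], [6]))      # [2, 6]    (batched x, vector y)
print(solve_out_shape([3, 3], [3]))         # [3]       (2-D x, vector y)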
......@@ -305,4 +305,6 @@ void ValueCompareInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
} // namespace phi
......@@ -73,8 +73,8 @@ template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy) {
bool is_vector = false;
......
......@@ -22,8 +22,8 @@ template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);
......
......@@ -18,6 +18,15 @@ limitations under the License. */
namespace phi {
/**
* @brief This kernel is used to compute the solution of a square system of
* linear equations with a unique solution for input x and y.
* $$Out = X^{-1} * Y$$
* @param ctx device context
* @param x the input tensor of solve
* @param y the input tensor of solve
* @param out the output tensor of solve
*/
template <typename T, typename Context>
void SolveKernel(const Context& dev_ctx,
const DenseTensor& x,
......
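A worked instance of the formula in the comment above, using the same values as the usage sketch at the top of this diff (for illustration only):

$$X = \begin{bmatrix}3 & 1\\ 1 & 2\end{bmatrix},\quad Y = \begin{bmatrix}9\\ 8\end{bmatrix},\quad X^{-1} = \frac{1}{5}\begin{bmatrix}2 & -1\\ -1 & 3\end{bmatrix},\quad Out = X^{-1}Y = \begin{bmatrix}2\\ 3\end{bmatrix}$$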
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SolveGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"solve_grad", {"X", "Y", "Out@GRAD", "Out"}, {}, {"X@GRAD", "Y@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(solve_grad, phi::SolveGradOpArgumentMapping);
......@@ -1368,6 +1368,10 @@ class OpTest(unittest.TestCase):
inplace_atol=None,
check_eager=False):
# disable legacy dygraph check when check_eager is True
if check_eager == True:
check_dygraph = False
def find_imperative_actual(target_name, dygraph_outs, place):
for name in dygraph_outs:
if name == target_name:
......@@ -1692,7 +1696,8 @@ class OpTest(unittest.TestCase):
inplace_atol=inplace_atol)
if check_eager:
return outs, dygraph_outs, eager_dygraph_outs, fetch_list
assert check_dygraph == False
return outs, eager_dygraph_outs, fetch_list
elif check_dygraph:
return outs, dygraph_outs, fetch_list
else:
......@@ -1767,6 +1772,11 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
inplace_atol=None,
check_eager=False):
# disable legacy dygraph check when check_eager is True
if check_eager == True:
check_dygraph = False
self.__class__.op_type = self.op_type
if self.is_mkldnn_op():
self.__class__.use_mkldnn = True
......@@ -1784,8 +1794,8 @@ class OpTest(unittest.TestCase):
inplace_atol,
check_eager=check_eager)
if check_eager:
assert check_dygraph == True
outs, dygraph_outs, eager_dygraph_outs, fetch_list = res
assert check_dygraph == False
outs, eager_dygraph_outs, fetch_list = res
elif check_dygraph:
outs, dygraph_outs, fetch_list = res
else:
......@@ -1859,6 +1869,11 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
check_dygraph=True,
check_eager=False):
# disable legacy dygraph check when check_eager is True
if check_eager == True:
check_dygraph = False
self._check_grad_helper()
places = self._get_places()
for place in places:
......@@ -1887,6 +1902,11 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
numeric_place=None,
check_eager=False):
# disable legacy dygraph check when check_eager is True
if check_eager == True:
check_dygraph = False
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......
......@@ -24,12 +24,14 @@ sys.path.append("..")
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
# 2D normal case
class TestSolveOp(OpTest):
def config(self):
self.python_api = paddle.linalg.solve
self.input_x_matrix_shape = [15, 15]
self.input_y_matrix_shape = [15, 10]
self.dtype = "float64"
......@@ -49,16 +51,17 @@ class TestSolveOp(OpTest):
}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
# x broadcast + 3D batch case
class TestSolveOpBatched_case0(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -70,16 +73,20 @@ class TestSolveOpBatched_case0(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-1)
self.check_grad(['X', 'Y'],
'Out',
max_relative_error=1e-1,
check_eager=True)
# 3D batch + y vector case
class TestSolveOpBatched_case1(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -91,16 +98,20 @@ class TestSolveOpBatched_case1(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04)
self.check_grad(['X', 'Y'],
'Out',
max_relative_error=0.04,
check_eager=True)
# 3D batch + y broadcast case
class TestSolveOpBatched_case2(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -112,16 +123,20 @@ class TestSolveOpBatched_case2(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02)
self.check_grad(['X', 'Y'],
'Out',
max_relative_error=0.02,
check_eager=True)
# x broadcast + 3D batch case
class TestSolveOpBatched_case3(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -133,16 +148,20 @@ class TestSolveOpBatched_case3(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02)
self.check_grad(['X', 'Y'],
'Out',
max_relative_error=0.02,
check_eager=True)
# 3D normal batch case
class TestSolveOpBatched_case4(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -154,16 +173,17 @@ class TestSolveOpBatched_case4(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
# 4D normal batch case
class TestSolveOpBatched_case5(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -175,16 +195,17 @@ class TestSolveOpBatched_case5(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
# 4D batch + y broadcast case
class TestSolveOpBatched_case6(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -196,16 +217,17 @@ class TestSolveOpBatched_case6(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
# 5D normal batch case
class TestSolveOpBatched_case7(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -217,16 +239,20 @@ class TestSolveOpBatched_case7(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04)
self.check_grad(['X', 'Y'],
'Out',
max_relative_error=0.04,
check_eager=True)
# 5D batch + y broadcast case
class TestSolveOpBatched_case8(OpTest):
def setUp(self):
self.python_api = paddle.linalg.solve
self.op_type = "solve"
self.dtype = "float64"
np.random.seed(2021)
......@@ -238,15 +264,18 @@ class TestSolveOpBatched_case8(OpTest):
self.outputs = {'Out': result}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04)
self.check_grad(['X', 'Y'],
'Out',
max_relative_error=0.04,
check_eager=True)
class TestSolveOpError(unittest.TestCase):
def test_errors(self):
def func_errors(self):
with program_guard(Program(), Program()):
# The input type of solve_op must be Variable.
x1 = fluid.create_lod_tensor(np.array([[-1]]), [[1]],
......@@ -282,6 +311,11 @@ class TestSolveOpError(unittest.TestCase):
y7 = fluid.data(name="y7", shape=[2, 4, 3], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.solve, x7, y7)
def test_dygraph(self):
with _test_eager_guard():
self.func_errors()
self.func_errors()
# 2D + vector case, FP64
class TestSolveOpAPI_1(unittest.TestCase):
......@@ -323,7 +357,7 @@ class TestSolveOpAPI_1(unittest.TestCase):
for place in self.place:
self.check_static_result(place=place)
def test_dygraph(self):
def func_dygraph(self):
def run(place):
paddle.disable_static(place)
......@@ -344,6 +378,11 @@ class TestSolveOpAPI_1(unittest.TestCase):
for place in self.place:
run(place)
def test_dygraph(self):
with _test_eager_guard():
self.func_dygraph()
self.func_dygraph()
# 2D normal case, FP64
class TestSolveOpAPI_2(unittest.TestCase):
......@@ -386,14 +425,13 @@ class TestSolveOpAPI_2(unittest.TestCase):
for place in self.place:
self.check_static_result(place=place)
def test_dygraph(self):
def func_dygraph(self):
def run(place):
paddle.disable_static(place)
np.random.seed(2021)
input_x_np = np.random.random([10, 10]).astype(self.dtype)
input_y_np = np.random.random([10, 4]).astype(self.dtype)
tensor_input_x = paddle.to_tensor(input_x_np)
tensor_input_y = paddle.to_tensor(input_y_np)
......@@ -407,6 +445,11 @@ class TestSolveOpAPI_2(unittest.TestCase):
for place in self.place:
run(place)
def test_dygraph(self):
with _test_eager_guard():
self.func_dygraph()
self.func_dygraph()
# 2D normal case, FP32
class TestSolveOpAPI_3(unittest.TestCase):
......@@ -450,7 +493,7 @@ class TestSolveOpAPI_3(unittest.TestCase):
for place in self.place:
self.check_static_result(place=place)
def test_dygraph(self):
def func_dygraph(self):
def run(place):
paddle.disable_static(place)
......@@ -472,6 +515,11 @@ class TestSolveOpAPI_3(unittest.TestCase):
for place in self.place:
run(place)
def test_dygraph(self):
with _test_eager_guard():
self.func_dygraph()
self.func_dygraph()
# 3D + y broadcast case, FP64
class TestSolveOpAPI_4(unittest.TestCase):
......@@ -513,7 +561,7 @@ class TestSolveOpAPI_4(unittest.TestCase):
for place in self.place:
self.check_static_result(place=place)
def test_dygraph(self):
def func_dygraph(self):
def run(place):
paddle.disable_static(place)
......@@ -534,6 +582,11 @@ class TestSolveOpAPI_4(unittest.TestCase):
for place in self.place:
run(place)
def test_dygraph(self):
with _test_eager_guard():
self.func_dygraph()
self.func_dygraph()
class TestSolveOpSingularAPI(unittest.TestCase):
# Singular matrix is not invertible
......@@ -573,14 +626,13 @@ class TestSolveOpSingularAPI(unittest.TestCase):
paddle.enable_static()
self.check_static_result(place=place)
def test_dygraph(self):
def func_dygraph(self):
for place in self.places:
with fluid.dygraph.guard(place):
input_x_np = np.ones([4, 4]).astype(self.dtype)
input_y_np = np.ones([4, 4]).astype(self.dtype)
input_x = fluid.dygraph.to_variable(input_x_np)
input_y = fluid.dygraph.to_variable(input_y_np)
try:
result = paddle.linalg.solve(input_x, input_y)
except RuntimeError as ex:
......@@ -590,6 +642,11 @@ class TestSolveOpSingularAPI(unittest.TestCase):
print("The mat is singular")
pass
def test_dygraph(self):
with _test_eager_guard():
self.func_dygraph()
self.func_dygraph()
if __name__ == "__main__":
unittest.main()
......@@ -2853,7 +2853,10 @@ def solve(x, y, name=None):
print(out)
# [2., 3.])
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_solve(x, y)
if _in_legacy_dygraph():
return _C_ops.solve(x, y)
inputs = {"X": [x], "Y": [y]}
......
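A minimal sketch of what the dispatch above means in practice, using the same _test_eager_guard switch as the updated tests (assumes a build from around this change, where both dygraph modes still exist):

import paddle
from paddle.fluid.framework import _test_eager_guard

x = paddle.rand([4, 4], dtype="float64")
y = paddle.rand([4, 2], dtype="float64")

# Inside the eager guard, paddle.linalg.solve dispatches to
# _C_ops.final_state_solve, i.e. the yaml-generated phi path added here.
with _test_eager_guard():
    out_eager = paddle.linalg.solve(x, y)

# Outside the guard (legacy dygraph at the time of this change),
# the same call falls back to _C_ops.solve.
out_legacy = paddle.linalg.solve(x, y)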