Unverified commit 25d3dce1 · Authored by: xiongkun · Committed by: GitHub

transfer the svd infer into phi infermeta (#44528)

* transfer the svd infer into phi infermeta

* remove the svd.h

* modify svd api

* fix svd error by insert optional
Parent 8d3672f0
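For context, a minimal usage sketch (not part of this commit) showing the output shapes computed by the SvdInferMeta added below: U keeps the input shape with the last dimension set to k (or m when full_matrices is true), VH sets the second-to-last dimension to k (or n), and S drops the last dimension, with the remaining trailing dimension set to k.

import paddle

# Illustrative only: shapes follow the rules implemented by SvdInferMeta below.
x = paddle.rand([6, 3, 4])                 # batch of 6 matrices, m = 3, n = 4, k = min(m, n) = 3
u, s, vh = paddle.linalg.svd(x, full_matrices=False)
print(u.shape, s.shape, vh.shape)          # [6, 3, 3] [6, 3] [6, 3, 4]
u, s, vh = paddle.linalg.svd(x, full_matrices=True)
print(u.shape, s.shape, vh.shape)          # [6, 3, 3] [6, 3] [6, 4, 4]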
......@@ -17,8 +17,10 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -26,55 +28,9 @@
namespace paddle {
namespace operators {
using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
}
static DDim VHDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
}
static DDim SDDim(const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return phi::make_ddim(x_vec);
}
class SvdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "svd");
OP_INOUT_CHECK(ctx->HasOutput("U"), "Output", "U", "svd");
OP_INOUT_CHECK(ctx->HasOutput("VH"), "Output", "VH", "svd");
OP_INOUT_CHECK(ctx->HasOutput("S"), "Output", "S", "svd");
auto in_dims = ctx->GetInputDim("X");
int x_rank = in_dims.size();
PADDLE_ENFORCE_GE(in_dims.size(),
2,
platform::errors::InvalidArgument(
"the rank of input must greater than 2"));
int m = in_dims[x_rank - 2];
int n = in_dims[x_rank - 1];
int k = std::min(m, n);
const bool full_uv = ctx->Attrs().Get<bool>("full_matrices");
ctx->SetOutputDim("U", !full_uv ? UDDim(in_dims, k) : UDDim(in_dims, m));
ctx->SetOutputDim("VH", !full_uv ? VHDDim(in_dims, k) : VHDDim(in_dims, n));
ctx->SetOutputDim("S", SDDim(in_dims, k));
ctx->ShareLoD("X", /*->*/ "U");
ctx->ShareLoD("X", /*->*/ "VH");
ctx->ShareLoD("X", /*->*/ "S");
}
};
class SvdOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -159,10 +115,15 @@ class SvdGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(svd,
SvdInferShapeFunctor,
PD_INFER_META(phi::SvdInferMeta));
REGISTER_OPERATOR(svd,
ops::SvdOp,
ops::SvdOpMaker,
ops::SvdGradMaker<paddle::framework::OpDesc>,
ops::SvdGradMaker<paddle::imperative::OpBase>);
ops::SvdGradMaker<paddle::imperative::OpBase>,
SvdInferShapeFunctor);
REGISTER_OPERATOR(svd_grad, ops::SvdGradOp);
......@@ -2140,6 +2140,15 @@
data_type : x
backward : sum_grad
- api : svd
args : (Tensor x, bool full_metrices)
output : Tensor(u), Tensor(s), Tensor(vh)
infer_meta :
func : SvdInferMeta
kernel :
func : svd
backward : svd_grad
# The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later
- api : swish
args : (Tensor x, float beta=1.0)
......
......@@ -2133,6 +2133,17 @@
output : Tensor(grad_grad_x_grad)
invoke : sum_grad(grad_grad_x, grad_grad_out_grad, dims, keep_dim, reduce_all, grad_grad_x_grad)
- backward_api : svd_grad
forward : svd (Tensor x, bool full) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : svd_grad
optional: u_grad, vh_grad, s_grad
- backward_api : swish_grad
forward : swish (Tensor x, float beta=1.0) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float bete=1.0)
......
......@@ -2715,6 +2715,53 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
MetaTensor* s,
MetaTensor* vh) {
auto UDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of U
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
};
auto VHDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of VH
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
};
auto SDDim = [](const DDim& x_dim, int k) {
// get x_dim and return the ddim of S
auto x_vec = vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return phi::make_ddim(x_vec);
};
auto in_dims = x.dims();
int x_rank = in_dims.size();
PADDLE_ENFORCE_GE(
in_dims.size(),
2,
phi::errors::InvalidArgument("the rank of input must greater than 2"));
int m = in_dims[x_rank - 2];
int n = in_dims[x_rank - 1];
int k = std::min(m, n);
u->set_dims(!full_matrices ? UDDim(in_dims, k) : UDDim(in_dims, m));
vh->set_dims(!full_matrices ? VHDDim(in_dims, k) : VHDDim(in_dims, n));
s->set_dims(SDDim(in_dims, k));
u->share_lod(x);
vh->share_lod(x);
s->share_lod(x);
u->set_dtype(x.dtype());
vh->set_dtype(x.dtype());
s->set_dtype(x.dtype());
}
void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
......
......@@ -387,6 +387,12 @@ void SumRawInferMeta(const MetaTensor& x,
DataType dtype,
MetaTensor* out);
void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
MetaTensor* s,
MetaTensor* vh);
void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
......
......@@ -71,9 +71,9 @@ void SvdGradKernel(const Context& dev_ctx,
const DenseTensor& u,
const DenseTensor& vh,
const DenseTensor& s,
const DenseTensor& u_grad,
const DenseTensor& vh_grad,
const DenseTensor& s_grad,
const paddle::optional<DenseTensor>& u_grad,
const paddle::optional<DenseTensor>& vh_grad,
const paddle::optional<DenseTensor>& s_grad,
bool full_matrices,
DenseTensor* x_grad) {
const auto& dX = *x_grad;
......@@ -87,15 +87,33 @@ void SvdGradKernel(const Context& dev_ctx,
dev_ctx, u, {u.dims().size() - 1}, {0}, {k}, {1}, {});
VH = SliceKernel<T, Context>(
dev_ctx, vh, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
dU = SliceKernel<T, Context>(
dev_ctx, u_grad, {u_grad.dims().size() - 1}, {0}, {k}, {1}, {});
dVH = SliceKernel<T, Context>(
dev_ctx, vh_grad, {vh.dims().size() - 2}, {0}, {k}, {1}, {});
if (u_grad.get_ptr() != nullptr) {
dU = SliceKernel<T, Context>(dev_ctx,
*(u_grad.get_ptr()),
{u.dims().size() - 1},
{0},
{k},
{1},
{});
}
if (vh_grad.get_ptr() != nullptr) {
dVH = SliceKernel<T, Context>(dev_ctx,
*(vh_grad.get_ptr()),
{vh.dims().size() - 2},
{0},
{k},
{1},
{});
}
} else {
U = u;
VH = vh;
dU = u_grad;
dVH = vh_grad;
if (u_grad.get_ptr() != nullptr) {
dU = *(u_grad.get_ptr());
}
if (vh_grad.get_ptr() != nullptr) {
dVH = *(vh_grad.get_ptr());
}
}
auto s_inverse = Pow<T, Context>(dev_ctx, s, -1);
auto s_square = Pow<T, Context>(dev_ctx, s, 2);
......@@ -106,19 +124,17 @@ void SvdGradKernel(const Context& dev_ctx,
F,
Diag<T, Context>(dev_ctx, Infinits<T, Context>(dev_ctx, {k}), 0, 0));
F = Pow<T, Context>(dev_ctx, F, -1);
DenseTensor sigma_term;
DenseTensor u_term;
DenseTensor v_term;
DenseTensor sigma_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
DenseTensor u_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
DenseTensor v_term = Fill<T, Context>(dev_ctx, {1}, 0.0);
// if (ctx.HasInput(framework::GradVarName("S")))
{
const DenseTensor& gS = s_grad;
if (s_grad.get_ptr() != nullptr) {
const DenseTensor& gS = *(s_grad.get_ptr());
sigma_term = Multiply<T, Context>(dev_ctx, Unsqueeze(gS, -2), U);
sigma_term = Matmul<T, Context>(dev_ctx, sigma_term, VH);
}
// if (ctx.HasInput(framework::GradVarName("U"))) {
{
if (u_grad.get_ptr() != nullptr) {
auto UTG = Matmul<T, Context>(dev_ctx, U, dU, true, false);
auto GTU = Matmul<T, Context>(dev_ctx, dU, U, true, false);
u_term = Multiply<T, Context>(
......@@ -141,10 +157,7 @@ void SvdGradKernel(const Context& dev_ctx,
}
u_term = Matmul<T, Context>(dev_ctx, u_term, VH);
}
// }
// if (ctx.HasInput(framework::GradVarName("VH"))) {
{
if (vh_grad.get_ptr() != nullptr) {
auto UTG = Matmul<T, Context>(dev_ctx, VH, dVH, false, true);
auto GTU = Matmul<T, Context>(dev_ctx, dVH, VH, false, true);
v_term = Multiply<T, Context>(
......
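The optional u_grad / vh_grad / s_grad handling above exists because in eager mode only the outputs that actually receive gradients are fed to the backward kernel. A minimal dygraph sketch of that situation (illustrative, assuming post-patch behaviour), where only S contributes a gradient:

import paddle

x = paddle.rand([3, 4])
x.stop_gradient = False
u, s, vh = paddle.linalg.svd(x, full_matrices=False)
s.sum().backward()        # only s_grad is set; u_grad and vh_grad reach the kernel as empty optionals
print(x.grad.shape)       # [3, 4]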
......@@ -20,13 +20,13 @@ namespace phi {
template <typename T, typename Context>
void SvdGradKernel(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& U,
const DenseTensor& VH,
const DenseTensor& S,
const DenseTensor& U_grad,
const DenseTensor& VH_grad,
const DenseTensor& S_grad,
const DenseTensor& x,
const DenseTensor& u,
const DenseTensor& vh,
const DenseTensor& s,
const paddle::optional<DenseTensor>& u_grad,
const paddle::optional<DenseTensor>& vh_grad,
const paddle::optional<DenseTensor>& s_grad,
bool full_matrices,
DenseTensor* X_grad);
} // namespace phi
......@@ -29,6 +29,7 @@ class TestSvdOp(OpTest):
def setUp(self):
paddle.enable_static()
self.python_api = paddle.linalg.svd
self.generate_input()
self.generate_output()
self.op_type = "svd"
......@@ -55,7 +56,7 @@ class TestSvdOp(OpTest):
self._output_data = np.linalg.svd(self._input_data)
def test_check_output(self):
self.check_output(no_check_set=['U', 'VH'])
self.check_output(no_check_set=['U', 'VH'], check_eager=True)
def test_svd_forward(self):
""" u matmul diag(s) matmul vt must become X
......@@ -75,13 +76,19 @@ class TestSvdOp(OpTest):
paddle.enable_static()
def check_S_grad(self):
self.check_grad(['X'], ['S'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['S'],
numeric_grad_delta=0.001,
check_eager=True)
def check_U_grad(self):
self.check_grad(['X'], ['U'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['U'],
numeric_grad_delta=0.001,
check_eager=True)
def check_V_grad(self):
self.check_grad(['X'], ['VH'], numeric_grad_delta=0.001)
self.check_grad(['X'], ['VH'],
numeric_grad_delta=0.001,
check_eager=True)
def test_check_grad(self):
"""
......
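The forward test above asserts that U matmul diag(S) matmul VH reconstructs X; an equivalent NumPy-only check, for reference:

import numpy as np

x = np.random.rand(3, 4).astype("float64")
u, s, vh = np.linalg.svd(x, full_matrices=False)
np.testing.assert_allclose(u @ np.diag(s) @ vh, x, atol=1e-10)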
......@@ -1857,8 +1857,9 @@ def svd(x, full_matrices=False, name=None):
# U * UH == I
# V * VH == I
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_svd(x, full_matrices)
if _in_legacy_dygraph():
return _C_ops.svd(x, 'full_matrices', full_matrices)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svd')
check_type(full_matrices, 'full_matrices', bool, 'svd')
......