Unverified commit 28fd30cd, authored by Chen Weihang, committed by GitHub

[Phi] Remove cholsky solve deps with svd helper (#40119)

* remove cholsky solve deps with svd helper

* fix shape infer bug
Parent 5435459a
......@@ -16,11 +16,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace paddle {
namespace operators { // namespace operators
......@@ -59,7 +59,9 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
framework::Tensor b_bst(bin.type());
TensorExpand<T, DeviceContext>(dev_ctx, bin, &b_bst, b_bst_dims_vec);
math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(ctx);
auto &phi_dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE &>(
dev_ctx);
// calculate u's conjugate for complex
framework::Tensor u_conj(u_bst.type());
......@@ -68,7 +70,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
u_bst.data<T>(), u_bst.numel(),
u_conj.mutable_data<T>(u_bst.dims(), dev_ctx.GetPlace()));
u_for_range(u_functor);
u_conj = helper.Transpose(u_conj);
u_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, u_conj);
// calculate b's conjugate for complex
framework::Tensor b_conj(b_bst.type());
......@@ -77,7 +79,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
b_bst.data<T>(), b_bst.numel(),
b_conj.mutable_data<T>(b_bst.dims(), dev_ctx.GetPlace()));
b_for_range(b_functor);
b_conj = helper.Transpose(b_conj);
b_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, b_conj);
auto ut_data = u_conj.mutable_data<T>(dev_ctx.GetPlace());
auto uindims = u_bst.dims();
......@@ -117,7 +119,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
out->data<T>(), out->numel(),
out->mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
out_for_range(out_functor);
*out = helper.Transpose(*out);
*out = phi::TransposeLast2Dim<T>(phi_dev_ctx, *out);
}
template <typename DeviceContext, typename T>
......@@ -145,7 +147,9 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
auto upper = ctx.Attr<bool>("upper");
const auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(ctx);
auto &phi_dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE &>(
dev_ctx);
std::vector<int64_t> u_bst_dims_vec;
std::vector<int64_t> b_bst_dims_vec;
......@@ -177,7 +181,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
out->data<T>(), out->numel(),
out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
out_for_range(out_functor);
out_conj = helper.Transpose(out_conj);
out_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, out_conj);
framework::Tensor commonterm(out->type());
auto outdims = out_conj.dims();
......@@ -200,7 +204,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
commonterm_conj.mutable_data<T>(commonterm.dims(),
dev_ctx.GetPlace()));
commonterm_for_range(commonterm_functor);
commonterm_conj = helper.Transpose(commonterm_conj);
commonterm_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, commonterm_conj);
phi::AddRawKernel<T>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
......
......@@ -962,6 +962,59 @@ void PixelShuffleInferMeta(const MetaTensor& x,
out->set_dims(output_dims);
}
void TransposeInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out) {
auto x_dims = x.dims();
size_t x_rank = x_dims.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(
x_rank,
axis_size,
errors::InvalidArgument("The input tensor's dimension "
"should be equal to the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank,
axis_size));
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE_GE(
axis[i],
0,
errors::InvalidArgument("The axis should be greater than or equal to 0."
"But received %d of axis[%d]",
axis[i],
i));
PADDLE_ENFORCE_EQ(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
true,
errors::InvalidArgument(
"Each element of Attribute axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received axis[%d] is %d, axis_size is %d, "
"count[axis[%d]] is %d",
i,
axis[i],
axis_size,
i,
count[axis[i]]));
}
phi::DDim out_dims(x_dims);
for (size_t i = 0; i < axis_size; ++i) {
out_dims[i] = x_dims[axis[i]];
}
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
......@@ -145,4 +145,8 @@ void PixelShuffleInferMeta(const MetaTensor& x,
const std::string& data_format,
MetaTensor* out);
// Infers the dims and dtype of `out` for a transpose of `x` with the given
// permutation `axis` (out_dims[i] = x_dims[axis[i]]); `axis` must be a
// duplicate-free permutation of [0, x.dims().size()).
void TransposeInferMeta(const MetaTensor& x,
                        const std::vector<int>& axis,
                        MetaTensor* out);
} // namespace phi
......@@ -15,7 +15,10 @@
#pragma once
#include <numeric>
#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace phi {
......@@ -25,4 +28,26 @@ void TransposeKernel(const Context& dev_ctx,
const std::vector<int>& axis,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Transpose(const Context& dev_ctx,
                      const DenseTensor& x,
                      const std::vector<int>& axis) {
  // Convenience wrapper around TransposeKernel: allocates the result,
  // resolves its meta (dims/dtype) via the infer-meta rule, then runs the
  // device kernel and returns the result by value.
  DenseTensor result = Empty<T, Context>(dev_ctx);
  MetaTensor result_meta(&result);
  TransposeInferMeta(x, axis, &result_meta);
  TransposeKernel<T, Context>(dev_ctx, x, axis, &result);
  return result;
}
template <typename T, typename Context>
DenseTensor TransposeLast2Dim(const Context& dev_ctx, const DenseTensor& x) {
  // Swaps the last two dimensions of `x` (a batched matrix transpose) by
  // building the permutation [0, 1, ..., rank-3, rank-1, rank-2] and
  // delegating to Transpose.
  const size_t rank = x.dims().size();
  // Identity permutation first; std::iota avoids the size_t -> int
  // narrowing of the original hand-written fill loop.
  std::vector<int> axis(rank);
  std::iota(axis.begin(), axis.end(), 0);
  // Guard: for rank < 2 the original indexed axis[rank - 1] / axis[rank - 2],
  // which wraps (rank is size_t) and reads out of bounds — UB. A tensor with
  // fewer than two dims has no "last two" to swap, so fall through to the
  // identity permutation instead.
  if (rank >= 2) {
    std::swap(axis[rank - 1], axis[rank - 2]);
  }
  return Transpose<T, Context>(dev_ctx, x, axis);
}
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register