diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h index f25fbbb0c698036951c4b9ae8e9ad2778786a1a2..74b961d4e55e8a6ca231285e44bed3e3401461dc 100644 --- a/paddle/fluid/operators/cholesky_solve_op.h +++ b/paddle/fluid/operators/cholesky_solve_op.h @@ -16,11 +16,11 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/solve_op.h" -#include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/triangular_solve_op.h" #include "paddle/fluid/platform/complex.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/math_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" namespace paddle { namespace operators { // namespace operators @@ -59,7 +59,9 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx, framework::Tensor b_bst(bin.type()); TensorExpand(dev_ctx, bin, &b_bst, b_bst_dims_vec); - math::DeviceIndependenceTensorOperations helper(ctx); + auto &phi_dev_ctx = static_cast< + const typename framework::ConvertToPhiContext::TYPE &>( + dev_ctx); // calculate u's conjugate for complex framework::Tensor u_conj(u_bst.type()); @@ -68,7 +70,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx, u_bst.data(), u_bst.numel(), u_conj.mutable_data(u_bst.dims(), dev_ctx.GetPlace())); u_for_range(u_functor); - u_conj = helper.Transpose(u_conj); + u_conj = phi::TransposeLast2Dim(phi_dev_ctx, u_conj); // calculate b's conjugate for complex framework::Tensor b_conj(b_bst.type()); @@ -77,7 +79,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx, b_bst.data(), b_bst.numel(), b_conj.mutable_data(b_bst.dims(), dev_ctx.GetPlace())); b_for_range(b_functor); - b_conj = helper.Transpose(b_conj); + b_conj = phi::TransposeLast2Dim(phi_dev_ctx, b_conj); auto ut_data = u_conj.mutable_data(dev_ctx.GetPlace()); auto uindims = u_bst.dims(); @@ -117,7 +119,7 @@ void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx, out->data(), out->numel(), out->mutable_data(out->dims(), dev_ctx.GetPlace())); out_for_range(out_functor); - *out = helper.Transpose(*out); + *out = phi::TransposeLast2Dim(phi_dev_ctx, *out); } template @@ -145,7 +147,9 @@ class CholeskySolveGradKernel : public framework::OpKernel { auto upper = ctx.Attr("upper"); const auto &dev_ctx = ctx.template device_context(); - math::DeviceIndependenceTensorOperations helper(ctx); + auto &phi_dev_ctx = static_cast< + const typename framework::ConvertToPhiContext::TYPE &>( + dev_ctx); std::vector u_bst_dims_vec; std::vector b_bst_dims_vec; @@ -177,7 +181,7 @@ class CholeskySolveGradKernel : public framework::OpKernel { out->data(), out->numel(), out_conj.mutable_data(out->dims(), dev_ctx.GetPlace())); out_for_range(out_functor); - out_conj = helper.Transpose(out_conj); + out_conj = phi::TransposeLast2Dim(phi_dev_ctx, out_conj); framework::Tensor commonterm(out->type()); auto outdims = out_conj.dims(); @@ -200,7 +204,7 @@ class CholeskySolveGradKernel : public framework::OpKernel { commonterm_conj.mutable_data(commonterm.dims(), dev_ctx.GetPlace())); commonterm_for_range(commonterm_functor); - commonterm_conj = helper.Transpose(commonterm_conj); + commonterm_conj = phi::TransposeLast2Dim(phi_dev_ctx, commonterm_conj); phi::AddRawKernel( static_castset_dims(output_dims); } +void TransposeInferMeta(const MetaTensor& x, + const std::vector& axis, + MetaTensor* out) { + auto x_dims = x.dims(); + size_t x_rank = x_dims.size(); + size_t axis_size = axis.size(); + + PADDLE_ENFORCE_EQ( + x_rank, + axis_size, + errors::InvalidArgument("The input tensor's dimension " + "should be equal to the axis's size. " + "But received input tensor's dimension is %d, " + "axis's size is %d", + x_rank, + axis_size)); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + PADDLE_ENFORCE_GE( + axis[i], + 0, + errors::InvalidArgument("The axis should be greater than or equal to 0." + "But received %d of axis[%d]", + axis[i], + i)); + + PADDLE_ENFORCE_EQ( + axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, + true, + errors::InvalidArgument( + "Each element of Attribute axis should " + "be a unique value range from 0 to (dims - 1), " + "where the dims is the axis's size, " + "unique value means this axis value can appear only once. " + "But received axis[%d] is %d, axis_size is %d, " + "count[axis[%d]] is %d", + i, + axis[i], + axis_size, + i, + count[axis[i]])); + } + + phi::DDim out_dims(x_dims); + for (size_t i = 0; i < axis_size; ++i) { + out_dims[i] = x_dims[axis[i]]; + } + + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 3c0628981af7c92ff60e8199131b682e4f0f557e..97ec6f7fa582ca388032aec79766c479096f27d2 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -145,4 +145,8 @@ void PixelShuffleInferMeta(const MetaTensor& x, const std::string& data_format, MetaTensor* out); +void TransposeInferMeta(const MetaTensor& x, + const std::vector& axis, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/transpose_kernel.h b/paddle/phi/kernels/transpose_kernel.h index 303b4a9a8f05d440dc0a2878574cc951ef5ec1a7..3d89b324bab5b08490457183b7aa31fd4704744b 100644 --- a/paddle/phi/kernels/transpose_kernel.h +++ b/paddle/phi/kernels/transpose_kernel.h @@ -15,7 +15,10 @@ #pragma once #include + #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/empty_kernel.h" namespace phi { @@ -25,4 +28,26 @@ void TransposeKernel(const Context& dev_ctx, const std::vector& axis, DenseTensor* out); +template +DenseTensor Transpose(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis) { + auto dense_out = Empty(dev_ctx); + MetaTensor meta_out(&dense_out); + TransposeInferMeta(x, axis, &meta_out); + TransposeKernel(dev_ctx, x, axis, &dense_out); + return dense_out; +} + +template +DenseTensor TransposeLast2Dim(const Context& dev_ctx, const DenseTensor& x) { + size_t rank = x.dims().size(); + std::vector axis(rank); + for (size_t i = 0; i < rank; ++i) { + axis[i] = i; + } + std::swap(axis[rank - 1], axis[rank - 2]); + return Transpose(dev_ctx, x, axis); +} + } // namespace phi