diff --git a/paddle/fluid/operators/solve_op.cc b/paddle/fluid/operators/solve_op.cc index a7bf413e10519ab6dd11e3b81005c0aee4cb37b8..daa020e4a0d744c2c8d5811bce9888f5696fc8a0 100644 --- a/paddle/fluid/operators/solve_op.cc +++ b/paddle/fluid/operators/solve_op.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/solve_op.h" - #include #include #include #include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/ddim.h" namespace paddle { @@ -220,10 +220,3 @@ REGISTER_OPERATOR(solve, ops::SolveOpGradMaker); REGISTER_OPERATOR(solve_grad, ops::SolveGradOp); - -REGISTER_OP_CPU_KERNEL(solve, - ops::SolveKernel, - ops::SolveKernel); -REGISTER_OP_CPU_KERNEL(solve_grad, - ops::SolveGradKernel, - ops::SolveGradKernel); diff --git a/paddle/fluid/operators/solve_op.cu b/paddle/fluid/operators/solve_op.cu deleted file mode 100644 index a1e56fab5702b66f90ab2d7790a9abfe0786c996..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/solve_op.cu +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/solve_op.h" - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL(solve, - ops::SolveKernel, - ops::SolveKernel); - -REGISTER_OP_CUDA_KERNEL(solve_grad, - ops::SolveGradKernel, - ops::SolveGradKernel); diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h deleted file mode 100644 index 115223749431bfadfac37a09871b738b637689fa..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/solve_op.h +++ /dev/null @@ -1,661 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "Eigen/Core" -#include "Eigen/LU" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" -#include "paddle/fluid/operators/squeeze_op.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/matrix_solve.h" -#if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#endif - -#define MAX_RANK_SUPPORTED 6 - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using framework::To32BitIndex; - -constexpr int kMULMKLDNNINT8 = 1; - -template -void ReduceSumForSolve(const Tensor* input, - Tensor* output, - const std::vector& reduce_dims, - bool keep_dim, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - auto stream = ctx.cuda_device_context().stream(); - TensorReduceImpl>( - ctx.cuda_device_context(), - *input, - output, - kps::IdentityFunctor(), - reduce_dims, - stream); -#else - ReduceKernelFunctor( - input, output, reduce_dims, keep_dim, false, ctx) - .template apply(); -#endif -} - -// check the input other is vector_case or not -static inline bool is_vector_rhs(const Tensor& input, const Tensor& other) { - auto x_dim = input.dims(); - auto y_dim = other.dims(); - auto x_dim_size = x_dim.size(); - auto y_dim_size = y_dim.size(); - std::vector x_dims_vec = phi::vectorize(x_dim); - std::vector y_dims_vec = phi::vectorize(y_dim); - - std::vector::const_iterator f = x_dims_vec.begin(); - std::vector::const_iterator l = x_dims_vec.end() - 1; - std::vector x_dims_vec_cut(f, l); // input.shape[:-1] - - std::vector expected_batched_rhs_shape(x_dims_vec_cut); - bool vector_case = - y_dim_size == 1 || (x_dim_size - 1 == y_dim_size && - y_dims_vec == (expected_batched_rhs_shape)); - - return vector_case; -} - -// unsqueeze operation helper -static framework::DDim GetOutputShapeUnsqueeze( - const std::vector unsqz_dims, const framework::DDim& in_dims) { - int output_size = in_dims.size() + static_cast(unsqz_dims.size()); - int cur_output_size = in_dims.size(); - std::vector output_shape(output_size, 0); - - // Validity Check: rank range. - PADDLE_ENFORCE_LE(output_size, - 6, - platform::errors::InvalidArgument( - "The output " - "tensor's rank should be less than 6.")); - - for (int axis : unsqz_dims) { - int cur = axis < 0 ? axis + cur_output_size + 1 : axis; - // Vaildity Check: the axis bound - PADDLE_ENFORCE_GE( - cur, - 0, - platform::errors::InvalidArgument("The insert dimension value should " - "not be less than 0")); - PADDLE_ENFORCE_LE(cur, - cur_output_size, - platform::errors::InvalidArgument( - "The insert dimension value shoule not be larger " - "than the dimension size of input tensor")); - // Move old axis, and insert new axis - for (int i = cur_output_size; i >= cur; --i) { - if (output_shape[i] == 1) { - // Move axis - output_shape[i + 1] = 1; - output_shape[i] = 0; - } - } - output_shape[cur] = 1; - // Add the output size. - cur_output_size++; - } - - // Make output shape - for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { - if (output_shape[out_idx] == 0) { - output_shape[out_idx] = in_dims[in_idx++]; - } - } - - return phi::make_ddim(output_shape); -} - -// operation like squeeze(-1) -static void to_squeeze(const framework::ExecutionContext& context, - const framework::Tensor& in, - framework::Tensor* out) { - auto x_dims = in.dims(); - std::vector sqz_dims = {-1}; - auto out_dims = GetOutputShape(sqz_dims, x_dims, true); - out->mutable_data(context.GetPlace(), in.type()); - framework::TensorCopy( - in, - context.GetPlace(), - context.template device_context(), - out); - out->Resize(out_dims); -} - -// vector_case, need to operate like unsqueeze(-1) -static void to_unsqueeze(const framework::ExecutionContext& context, - const framework::Tensor& in, - framework::Tensor* out) { - auto x_dims = in.dims(); - std::vector unsqz_dims = {-1}; - framework::DDim out_dims = out->dims(); - out_dims = GetOutputShapeUnsqueeze(unsqz_dims, x_dims); - framework::TensorCopy( - in, - context.GetPlace(), - context.template device_context(), - out); - out->Resize(out_dims); -} - -// Prepared for the broadcast operation -static std::vector get_broadcast_batch_portion( - std::vector x, std::vector y) { - size_t size_x = x.size(); - size_t size_y = y.size(); - size_t size = std::max(size_x, size_y); - std::vector batchPortion(size); - - ptrdiff_t i = (ptrdiff_t)size - 1; - for (; i >= 0; --i) { - ptrdiff_t offset = size - i - 1; - ptrdiff_t dim_x = size_x - offset - 1; - ptrdiff_t dim_y = size_y - offset - 1; - int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; - int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; - - PADDLE_ENFORCE_EQ( - (x_size == y_size || x_size == 1 || y_size == 1), - true, - platform::errors::PreconditionNotMet( - "The size of tensor x (%d) must match the size of tensor y " - "(%d) at non-singleton dimension %d.", - x_size, - y_size, - i)); - - batchPortion[i] = x_size != 1 ? x_size : y_size; - } - return batchPortion; -} - -// broadcast the batch dimensions of tensor x and tensor y. -static inline std::tuple, std::vector> -get_broadcast_dims(const Tensor& x, const Tensor& y) { - std::vector x_dims_vec = phi::vectorize(x.dims()); - std::vector y_dims_vec = phi::vectorize(y.dims()); - - std::vector::const_iterator f1 = x_dims_vec.begin(); - std::vector::const_iterator l1 = x_dims_vec.end() - 2; - std::vector x_dims_vec_cut(f1, l1); - - std::vector::const_iterator f2 = y_dims_vec.begin(); - std::vector::const_iterator l2 = y_dims_vec.end() - 2; - std::vector y_dims_vec_cut(f2, l2); - - std::vector expand_batch_portion = - get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); - - std::vector x_expand_size({expand_batch_portion}); - x_expand_size.insert(x_expand_size.end(), - {x_dims_vec[static_cast(x_dims_vec.size()) - 2], - x_dims_vec[static_cast(x_dims_vec.size()) - 1]}); - - std::vector y_expand_size({expand_batch_portion}); - y_expand_size.insert(y_expand_size.end(), - {y_dims_vec[static_cast(y_dims_vec.size()) - 2], - y_dims_vec[static_cast(y_dims_vec.size()) - 1]}); - - return std::make_tuple(x_expand_size, y_expand_size); -} - -template -void expand_impl(const DeviceContext& context, - const Tensor& in, - Tensor* out, - const std::vector& expand_shape) { - auto vec_in_dims = phi::vectorize(in.dims()); - auto diff = expand_shape.size() - vec_in_dims.size(); - vec_in_dims.insert(vec_in_dims.begin(), diff, 1); - std::vector repeat_times(vec_in_dims.size()); - - for (size_t i = 0; i < vec_in_dims.size(); ++i) { - PADDLE_ENFORCE_NE( - expand_shape[i], - 0, - platform::errors::InvalidArgument("The expanded size cannot be zero.")); - if (i < diff) { - PADDLE_ENFORCE_GT( - expand_shape[i], - 0, - platform::errors::InvalidArgument( - "The expanded size (%d) for non-existing dimensions must be " - "positive for expand operation.", - expand_shape[i])); - repeat_times[i] = expand_shape[i]; - } else if (expand_shape[i] > 0) { - if (vec_in_dims[i] != 1) { - PADDLE_ENFORCE_EQ( - vec_in_dims[i], - expand_shape[i], - platform::errors::InvalidArgument( - "The value (%d) of the non-singleton dimension does not match" - " the corresponding value (%d) in shape for expand operation.", - vec_in_dims[i], - expand_shape[i])); - repeat_times[i] = 1; - } else { - repeat_times[i] = expand_shape[i]; - } - } else { - PADDLE_ENFORCE_EQ( - expand_shape[i], - -1, - platform::errors::InvalidArgument( - "When the value in shape is negative for expand_v2 op, " - "only -1 is supported, but the value received is %d.", - expand_shape[i])); - repeat_times[i] = 1; - } - } - - Eigen::DSizes bcast_dims; - for (size_t i = 0; i < repeat_times.size(); ++i) { - bcast_dims[i] = repeat_times[i]; - } - - framework::DDim new_in_dims = phi::make_ddim(vec_in_dims); - framework::DDim out_dims(new_in_dims); - for (size_t i = 0; i < repeat_times.size(); ++i) { - out_dims[i] *= repeat_times[i]; - } - - out->Resize(out_dims); - out->mutable_data(context.GetPlace()); - auto x = EigenTensor::From(in, new_in_dims); - auto y = EigenTensor::From(*out, out_dims); - auto& place = *context.eigen_device(); - // use 32-bit index to speed up - bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); - if (use_32bit_index) { - EigenBroadcast, T, Rank>::Eval( - place, To32BitIndex(y), To32BitIndex(x), bcast_dims); - } else { - EigenBroadcast, T, Rank>::Eval( - place, y, x, bcast_dims); - } -} - -template -void TensorExpand(const DeviceContext& context, - const Tensor& in, - Tensor* out, - const std::vector& expand_shape) { - // necessary check before expand operation - PADDLE_ENFORCE_GE(expand_shape.size(), - in.dims().size(), - platform::errors::InvalidArgument( - "The size of 'expand_shape' (%d) should >= the input " - "Tensor's rank (%d).", - expand_shape.size(), - in.dims().size())); - PADDLE_ENFORCE_LE(expand_shape.size(), - MAX_RANK_SUPPORTED, - platform::errors::InvalidArgument( - "The size of 'expand_shape' (%d) should be <= %d", - expand_shape.size(), - MAX_RANK_SUPPORTED)); - switch (expand_shape.size()) { - case 1: - expand_impl<1, T, DeviceContext>(context, in, out, expand_shape); - break; - case 2: - expand_impl<2, T, DeviceContext>(context, in, out, expand_shape); - break; - case 3: - expand_impl<3, T, DeviceContext>(context, in, out, expand_shape); - break; - case 4: - expand_impl<4, T, DeviceContext>(context, in, out, expand_shape); - break; - case 5: - expand_impl<5, T, DeviceContext>(context, in, out, expand_shape); - break; - case 6: - expand_impl<6, T, DeviceContext>(context, in, out, expand_shape); - break; - } -} - -template -static void linalg_solve(const framework::ExecutionContext& context, - const framework::Tensor* x, - const framework::Tensor* y, - framework::Tensor* out) { - out->mutable_data(context.GetPlace()); - - auto& dev_ctx = context.template device_context(); - phi::funcs::MatrixSolveFunctor mat_solve; - - // input y can be vector or matrix - // but need to be unsqueezed if y is a vector - bool is_vector = false; - is_vector = is_vector_rhs(*x, *y); - - Tensor tmp_y; - if (is_vector) { - tmp_y.mutable_data(context.GetPlace(), y->dtype()); - to_unsqueeze(context, *y, &tmp_y); - } else { - tmp_y.Resize(y->dims()); - tmp_y.mutable_data(context.GetPlace(), y->dtype()); - framework::TensorCopy( - *y, - context.GetPlace(), - context.template device_context(), - &tmp_y); - } - - Tensor tmp_x; - tmp_x.Resize(x->dims()); - tmp_x.mutable_data(context.GetPlace(), x->dtype()); - framework::TensorCopy( - *x, - context.GetPlace(), - context.template device_context(), - &tmp_x); - - std::vector x_broadcast_dims; - std::vector y_broadcast_dims; - std::tie(x_broadcast_dims, y_broadcast_dims) = - get_broadcast_dims(tmp_x, tmp_y); - - Tensor tmp_x_bc; - TensorExpand(dev_ctx, tmp_x, &tmp_x_bc, x_broadcast_dims); - - Tensor tmp_y_bc; - TensorExpand(dev_ctx, tmp_y, &tmp_y_bc, y_broadcast_dims); - - auto x_dim = x->dims(); - auto y_dim = y->dims(); - auto x_dim_size = x_dim.size(); - auto y_dim_size = y_dim.size(); - - if (is_vector) { // vector case - out->Resize(tmp_y_bc.dims()); // out.unsqueeze(-1) - mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out); - - Tensor out_tmp; - out_tmp.Resize(out->dims()); - out_tmp = *out; - to_squeeze(context, out_tmp, out); // out.squeeze(-1) - } else { - PADDLE_ENFORCE_EQ( - x_dim[x_dim_size - 1], - y_dim[y_dim_size - 2], - platform::errors::InvalidArgument( - "Matrix X1 with dimension greater than 2 and any matrix Y1," - "the matrix X1's width must be equal with matrix Y1's " - "height. But received X's shape = [%s], X1's shape = [%s], X1's " - "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = " - "%s.", - x_dim, - x_dim, - x_dim[x_dim_size - 1], - y_dim, - y_dim, - y_dim[y_dim_size - 2])); - mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out); - } -} - -template -class SolveKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* x = context.Input("X"); - const auto* y = context.Input("Y"); - Tensor* out = context.Output("Out"); - linalg_solve(context, x, y, out); - } -}; - -template -class SolveGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - - // reuse the linalg.solve forward output - auto* out = ctx.Input("Out"); - - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - bool is_vector = false; - is_vector = is_vector_rhs(*input, *y); - - Tensor tmp_y; - if (is_vector) { - tmp_y.mutable_data(ctx.GetPlace(), y->dtype()); - to_unsqueeze(ctx, *y, &tmp_y); - } else { - tmp_y.Resize(y->dims()); - tmp_y.mutable_data(ctx.GetPlace(), y->dtype()); - framework::TensorCopy( - *y, - ctx.GetPlace(), - ctx.template device_context(), - &tmp_y); - } - - Tensor tmp_x; - tmp_x.Resize(input->dims()); - tmp_x.mutable_data(ctx.GetPlace(), input->dtype()); - framework::TensorCopy( - *input, - ctx.GetPlace(), - ctx.template device_context(), - &tmp_x); - - std::vector x_broadcast_dims; - std::vector y_broadcast_dims; - std::tie(x_broadcast_dims, y_broadcast_dims) = - get_broadcast_dims(tmp_x, tmp_y); - - // tmp_dx - Tensor tmp_dx; - tmp_dx.Resize(phi::make_ddim(x_broadcast_dims)); - tmp_dx.mutable_data(ctx.GetPlace()); - - // tmp_dy - Tensor tmp_dy; - tmp_dy.Resize(phi::make_ddim(y_broadcast_dims)); - tmp_dy.mutable_data(ctx.GetPlace()); - - Tensor tmp_input(input->dtype()); - const auto& new_dims_vec = phi::funcs::getNewDimsVec(input->dims()); - tmp_input.Resize(phi::make_ddim(new_dims_vec)); - tmp_input.mutable_data(ctx.GetPlace()); - phi::funcs::TransposeNormal trans; - std::vector new_axis = phi::funcs::getNewAxis(input->dims().size()); - auto& dev_ctx = ctx.template device_context(); - trans(dev_ctx, *input, &tmp_input, new_axis); - - if (dy) { - dy->mutable_data(ctx.GetPlace()); - // reuse linalg_solve forward logics to get tmp_dy - linalg_solve(ctx, &tmp_input, dout, &tmp_dy); - } - - if (dx) { - dx->mutable_data(ctx.GetPlace()); - // to get dx - auto blas = phi::funcs::GetBlas(ctx); - if (input->dims().size() == 2 && y->dims().size() == 2) { - auto mat_dim_a1 = - phi::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); - auto mat_dim_b1 = - phi::funcs::CreateMatrixDescriptor(out->dims(), 0, true); - blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0)); - } else if (is_vector_rhs(*input, *y)) { - Tensor tmp_dy_; - tmp_dy_.mutable_data(ctx.GetPlace(), y->dtype()); - to_unsqueeze(ctx, tmp_dy, &tmp_dy_); - - Tensor tmp_out_; - tmp_out_.mutable_data(ctx.GetPlace(), out->dtype()); - to_unsqueeze(ctx, *out, &tmp_out_); - - auto mat_dim_a1 = - phi::funcs::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false); - auto mat_dim_b1 = - phi::funcs::CreateMatrixDescriptor(tmp_out_.dims(), 0, true); - blas.MatMul( - tmp_dy_, mat_dim_a1, tmp_out_, mat_dim_b1, T(-1), &tmp_dx, T(0)); - } else { - auto mat_dim_a1 = - phi::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); - auto mat_dim_b1 = - phi::funcs::CreateMatrixDescriptor(out->dims(), 0, true); - blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0)); - } - } - - if (y->dims() != tmp_dy.dims()) { - Tensor dy_help; - dy_help.Resize(tmp_dy.dims()); - dy_help.mutable_data(ctx.GetPlace(), tmp_dy.dtype()); - framework::TensorCopy( - tmp_dy, - ctx.GetPlace(), - ctx.template device_context(), - &dy_help); - - // get dims - std::vector x_dims = vectorize(input->dims()); - std::vector y_dims = vectorize(y->dims()); - std::vector dout_dims = vectorize(dout->dims()); - - if (is_vector_rhs(*input, *y)) { - dout_dims.push_back(1); - } - - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - const std::vector dy_help_dims = vectorize(dy_help.dims()); - std::vector dy_broadcast_dims(ndim); - - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, - 1); - std::copy(y_dims.data(), - y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - bool keep_dim = true; - if (dy_help.dims().size() != dy->dims().size()) { - keep_dim = false; - } - ReduceSumForSolve( - &dy_help, dy, dy_reduce_dims, keep_dim, ctx); - } - dy->Resize(y->dims()); - } - } else { - framework::TensorCopy( - tmp_dy, - ctx.GetPlace(), - ctx.template device_context(), - dy); - } - - if (input->dims() != tmp_dx.dims()) { - Tensor dx_help; - dx_help.Resize(tmp_dx.dims()); - dx_help.mutable_data(ctx.GetPlace(), tmp_dx.dtype()); - framework::TensorCopy( - tmp_dx, - ctx.GetPlace(), - ctx.template device_context(), - &dx_help); - - // get dims - std::vector x_dims = vectorize(input->dims()); - std::vector y_dims = vectorize(y->dims()); - - int x_ndim = x_dims.size(); - int ndim = x_broadcast_dims.size(); - - const std::vector dx_help_dims = vectorize(dx_help.dims()); - std::vector dx_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, - 1); - std::copy(x_dims.data(), - x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - - std::vector dx_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dx) { - dx->mutable_data(ctx.GetPlace()); - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - bool keep_dim = true; - if (dx_help.dims().size() != dx->dims().size()) { - keep_dim = false; - } - ReduceSumForSolve( - &dx_help, dx, dx_reduce_dims, keep_dim, ctx); - } - dx->Resize(input->dims()); - } - } else { - framework::TensorCopy( - tmp_dx, - ctx.GetPlace(), - ctx.template device_context(), - dx); - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 02f812f9b17c0250647fd40156836cd0220c792f..5958f0e71e76a96717c8d863649c2237cbbfdd90 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3211,15 +3211,18 @@ void UnsqueezeInferMeta(const MetaTensor& x, } out->set_dtype(x.dtype()); } - // set xshape dims. - std::vector xshape_dims(x_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < x_dims.size(); ++i) { - xshape_dims[i + 1] = x_dims[i]; + if (xshape) { + // set xshape dims. + std::vector xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < x_dims.size(); ++i) { + xshape_dims[i + 1] = x_dims[i]; + } + + xshape->set_dims(phi::make_ddim(xshape_dims)); + xshape->share_lod(x); + xshape->set_dtype(x.dtype()); } - xshape->set_dims(phi::make_ddim(xshape_dims)); - xshape->share_lod(x); - xshape->set_dtype(x.dtype()); } void UnStackInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 455d42b548606de57d6adceeda267500dea19033..05abcbd0d1964917099b33bf3aa35ce7034ec831 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -62,6 +62,7 @@ set(COMMON_KERNEL_DEPS pooling maxouting matrix_inverse + matrix_solve phi_dynload_warpctc sequence_padding sequence_scale) diff --git a/paddle/phi/kernels/cpu/solve_grad_kernel.cc b/paddle/phi/kernels/cpu/solve_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b11d49259fd6466dd5b67ed5d78f4cedd67469c --- /dev/null +++ b/paddle/phi/kernels/cpu/solve_grad_kernel.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/solve_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/solve_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + solve_grad, CPU, ALL_LAYOUT, phi::SolveGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/solve_kernel.cc b/paddle/phi/kernels/cpu/solve_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..bde049bcc3ec092b01e5b25309e359024d2d4e43 --- /dev/null +++ b/paddle/phi/kernels/cpu/solve_kernel.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/solve_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/solve_kernel_impl.h" + +PD_REGISTER_KERNEL(solve, CPU, ALL_LAYOUT, phi::SolveKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/solve_grad_kernel.cu b/paddle/phi/kernels/gpu/solve_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c13c3b6545c448bfb33d44d1c24864e5089db993 --- /dev/null +++ b/paddle/phi/kernels/gpu/solve_grad_kernel.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/solve_grad_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/solve_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + solve_grad, GPU, ALL_LAYOUT, phi::SolveGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/solve_kernel.cu b/paddle/phi/kernels/gpu/solve_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..59bc77ca0b975764f49c06b544d04191db76b646 --- /dev/null +++ b/paddle/phi/kernels/gpu/solve_kernel.cu @@ -0,0 +1,19 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/solve_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/solve_kernel_impl.h" + +PD_REGISTER_KERNEL(solve, GPU, ALL_LAYOUT, phi::SolveKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/solve_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..55ee023cb5caa617db498554796ad9b717075a7a --- /dev/null +++ b/paddle/phi/kernels/impl/solve_grad_kernel_impl.h @@ -0,0 +1,267 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/expand_as_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/matrix_solve.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/impl/solve_kernel_impl.h" +#include "paddle/phi/kernels/squeeze_kernel.h" +#include "paddle/phi/kernels/unsqueeze_kernel.h" + +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/phi/kernels/gpu/reduce.h" +#endif + +namespace phi { + +template +struct ReduceSumForSolvelGrad { + void operator()(const Context& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims, + bool keep_dims); +}; + +template +struct ReduceSumForSolvelGrad { + void operator()(const CPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims, + bool keep_dims) { + std::vector reduce_dims_tmp(reduce_dims.begin(), + reduce_dims.end()); + phi::ReduceKernelImpl( + dev_ctx, input, output, reduce_dims_tmp, keep_dims, false); + } +}; + +#if defined(__NVCC__) || defined(__HIPCC__) +template +struct ReduceSumForSolvelGrad { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims, + bool keep_dims) { + phi::funcs::ReduceKernel>( + dev_ctx, input, output, kps::IdentityFunctor(), reduce_dims); + } +}; +#endif + +template +void SolveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& out, + DenseTensor* dx, + DenseTensor* dy) { + bool is_vector = false; + is_vector = is_vector_rhs(x, y); + DenseTensor tmp_y; + if (is_vector) { + dev_ctx.Alloc(&tmp_y, y.dtype()); + phi::Unsqueeze(dev_ctx, y, {-1}, &tmp_y, nullptr); + } else { + tmp_y.Resize(y.dims()); + dev_ctx.Alloc(&tmp_y, y.dtype()); + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, &tmp_y); + } + DenseTensor tmp_x; + tmp_x.Resize(x.dims()); + dev_ctx.Alloc(&tmp_x, x.dtype()); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &tmp_x); + + std::vector x_broadcast_dims; + std::vector y_broadcast_dims; + std::tie(x_broadcast_dims, y_broadcast_dims) = + get_broadcast_dims(tmp_x, tmp_y); + // tmp_dx + DenseTensor tmp_dx; + tmp_dx.Resize(phi::make_ddim(x_broadcast_dims)); + dev_ctx.template Alloc(&tmp_dx); + + // tmp_dy + DenseTensor tmp_dy; + tmp_dy.Resize(phi::make_ddim(y_broadcast_dims)); + dev_ctx.template Alloc(&tmp_dy); + + DenseTensor tmp_input(x.dtype()); + const auto& new_dims_vec = phi::funcs::getNewDimsVec(x.dims()); + tmp_input.Resize(phi::make_ddim(new_dims_vec)); + dev_ctx.template Alloc(&tmp_input); + + phi::funcs::TransposeNormal trans; + std::vector new_axis = phi::funcs::getNewAxis(x.dims().size()); + trans(dev_ctx, x, &tmp_input, new_axis); + + if (dy) { + dev_ctx.template Alloc(dy); + linalg_solve(dev_ctx, tmp_input, dout, &tmp_dy); + } + + if (dx) { + dev_ctx.template Alloc(dx); + + // to get dx + auto blas = phi::funcs::GetBlas(dev_ctx); + if (x.dims().size() == 2 && y.dims().size() == 2) { + auto mat_dim_a1 = + phi::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); + auto mat_dim_b1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + blas.MatMul(tmp_dy, mat_dim_a1, out, mat_dim_b1, T(-1), &tmp_dx, T(0)); + + } else if (is_vector_rhs(x, y)) { + DenseTensor tmp_dy_; + dev_ctx.Alloc(&tmp_dy_, y.dtype()); + + phi::Unsqueeze(dev_ctx, + tmp_dy, + paddle::experimental::IntArray({-1}), + &tmp_dy_, + nullptr); + + DenseTensor tmp_out_; + dev_ctx.Alloc(&tmp_out_, out.dtype()); + + phi::Unsqueeze(dev_ctx, + out, + paddle::experimental::IntArray({-1}), + &tmp_out_, + nullptr); + + auto mat_dim_a1 = + phi::funcs::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false); + auto mat_dim_b1 = + phi::funcs::CreateMatrixDescriptor(tmp_out_.dims(), 0, true); + blas.MatMul( + tmp_dy_, mat_dim_a1, tmp_out_, mat_dim_b1, T(-1), &tmp_dx, T(0)); + + } else { + auto mat_dim_a1 = + phi::funcs::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); + auto mat_dim_b1 = phi::funcs::CreateMatrixDescriptor(out.dims(), 0, true); + blas.MatMul(tmp_dy, mat_dim_a1, out, mat_dim_b1, T(-1), &tmp_dx, T(0)); + } + } + if (y.dims() != tmp_dy.dims()) { + DenseTensor dy_help; + dy_help.Resize(tmp_dy.dims()); + dev_ctx.Alloc(&dy_help, tmp_dy.dtype()); + + phi::Copy(dev_ctx, tmp_dy, dev_ctx.GetPlace(), false, &dy_help); + + // get dims + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + if (is_vector_rhs(x, y)) { + dout_dims.push_back(1); + } + + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + const std::vector dy_help_dims = vectorize(dy_help.dims()); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + bool keep_dim = true; + if (dy_help.dims().size() != dy->dims().size()) { + keep_dim = false; + } + ReduceSumForSolvelGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims, keep_dim); + } + dy->Resize(y.dims()); + } + } else { + phi::Copy(dev_ctx, tmp_dy, dev_ctx.GetPlace(), false, dy); + } + + if (x.dims() != tmp_dx.dims()) { + DenseTensor dx_help; + dx_help.Resize(tmp_dx.dims()); + dev_ctx.Alloc(&dx_help, tmp_dx.dtype()); + phi::Copy(dev_ctx, tmp_dx, dev_ctx.GetPlace(), false, &dx_help); + // get dims + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + + int x_ndim = x_dims.size(); + int ndim = x_broadcast_dims.size(); + + const std::vector dx_help_dims = vectorize(dx_help.dims()); + std::vector dx_broadcast_dims(ndim); + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + + std::vector dx_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dx) { + dev_ctx.template Alloc(dx); + + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + bool keep_dim = true; + if (dx_help.dims().size() != dx->dims().size()) { + keep_dim = false; + } + ReduceSumForSolvelGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims, keep_dim); + } + dx->Resize(x.dims()); + } + } else { + phi::Copy(dev_ctx, tmp_dx, dev_ctx.GetPlace(), false, dx); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/solve_kernel_impl.h b/paddle/phi/kernels/impl/solve_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..09c9e74dd207a2d351445ab1cf1c1c270d2645d0 --- /dev/null +++ b/paddle/phi/kernels/impl/solve_kernel_impl.h @@ -0,0 +1,199 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/expand_as_kernel.h" +#include "paddle/phi/kernels/funcs/matrix_solve.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" +#include "paddle/phi/kernels/squeeze_kernel.h" +#include "paddle/phi/kernels/unsqueeze_kernel.h" + +namespace phi { + +using Tensor = DenseTensor; + +// check the input other is vector_case or not +static inline bool is_vector_rhs(const DenseTensor& input, + const DenseTensor& other) { + auto x_dim = input.dims(); + auto y_dim = other.dims(); + auto x_dim_size = x_dim.size(); + auto y_dim_size = y_dim.size(); + std::vector x_dims_vec = phi::vectorize(x_dim); + std::vector y_dims_vec = phi::vectorize(y_dim); + + std::vector::const_iterator f = x_dims_vec.begin(); + std::vector::const_iterator l = x_dims_vec.end() - 1; + std::vector x_dims_vec_cut(f, l); // input.shape[:-1] + + std::vector expected_batched_rhs_shape(x_dims_vec_cut); + bool vector_case = + y_dim_size == 1 || (x_dim_size - 1 == y_dim_size && + y_dims_vec == (expected_batched_rhs_shape)); + + return vector_case; +} + +// Prepared for the broadcast operation +static std::vector get_broadcast_batch_portion( + std::vector x, std::vector y) { + size_t size_x = x.size(); + size_t size_y = y.size(); + size_t size = std::max(size_x, size_y); + std::vector batchPortion(size); + ptrdiff_t i = (ptrdiff_t)size - 1; + for (; i >= 0; --i) { + ptrdiff_t offset = size - i - 1; + ptrdiff_t dim_x = size_x - offset - 1; + ptrdiff_t dim_y = size_y - offset - 1; + int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; + int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; + PADDLE_ENFORCE_EQ( + (x_size == y_size || x_size == 1 || y_size == 1), + true, + phi::errors::PreconditionNotMet( + "The size of tensor x (%d) must match the size of tensor y " + "(%d) at non-singleton dimension %d.", + x_size, + y_size, + i)); + + batchPortion[i] = x_size != 1 ? x_size : y_size; + } + return batchPortion; +} + +static inline std::vector convert_to_int_vec(std::vector a) { + std::vector ret; + for (size_t i = 0; i < a.size(); i++) { + ret.emplace_back(int(a[i])); + } + + return ret; +} + +// broadcast the batch dimensions of tensor x and tensor y. +static inline std::tuple, std::vector> +get_broadcast_dims(const Tensor& x, const Tensor& y) { + std::vector x_dims_vec = phi::vectorize(x.dims()); + std::vector y_dims_vec = phi::vectorize(y.dims()); + std::vector::const_iterator f1 = x_dims_vec.begin(); + std::vector::const_iterator l1 = x_dims_vec.end() - 2; + std::vector x_dims_vec_cut(f1, l1); + + std::vector::const_iterator f2 = y_dims_vec.begin(); + std::vector::const_iterator l2 = y_dims_vec.end() - 2; + std::vector y_dims_vec_cut(f2, l2); + + std::vector expand_batch_portion = + get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); + std::vector x_expand_size({expand_batch_portion}); + x_expand_size.insert(x_expand_size.end(), + {x_dims_vec[static_cast(x_dims_vec.size()) - 2], + x_dims_vec[static_cast(x_dims_vec.size()) - 1]}); + std::vector y_expand_size({expand_batch_portion}); + y_expand_size.insert(y_expand_size.end(), + {y_dims_vec[static_cast(y_dims_vec.size()) - 2], + y_dims_vec[static_cast(y_dims_vec.size()) - 1]}); + + return std::make_tuple(x_expand_size, y_expand_size); +} + +template +static void linalg_solve(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + dev_ctx.template Alloc(out); + phi::funcs::MatrixSolveFunctor mat_solve; + + // input y can be vector or matrix + // but need to be unsqueezed if y is a vector + bool is_vector = false; + is_vector = is_vector_rhs(x, y); + + Tensor tmp_y; + if (is_vector) { + dev_ctx.Alloc(&tmp_y, y.dtype()); + + phi::Unsqueeze(dev_ctx, y, {-1}, &tmp_y, nullptr); + } else { + tmp_y.Resize(y.dims()); + dev_ctx.Alloc(&tmp_y, y.dtype()); + + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, &tmp_y); + } + + Tensor tmp_x; + tmp_x.Resize(x.dims()); + dev_ctx.Alloc(&tmp_x, x.dtype()); + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &tmp_x); + + std::vector x_broadcast_dims; + std::vector y_broadcast_dims; + std::tie(x_broadcast_dims, y_broadcast_dims) = + get_broadcast_dims(tmp_x, tmp_y); + + Tensor tmp_x_bc; + + phi::ExpandAsKernel( + dev_ctx, tmp_x, nullptr, convert_to_int_vec(x_broadcast_dims), &tmp_x_bc); + + Tensor tmp_y_bc; + phi::ExpandAsKernel( + dev_ctx, tmp_y, nullptr, convert_to_int_vec(y_broadcast_dims), &tmp_y_bc); + + auto x_dim = x.dims(); + auto y_dim = y.dims(); + auto x_dim_size = x_dim.size(); + auto y_dim_size = y_dim.size(); + + if (is_vector) { // vector case + out->Resize(tmp_y_bc.dims()); // out.unsqueeze(-1) + mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out); + + Tensor out_tmp; + out_tmp.Resize(out->dims()); + out_tmp = *out; + + phi::SqueezeKernel(dev_ctx, out_tmp, {-1}, out, nullptr); + } else { + PADDLE_ENFORCE_EQ( + x_dim[x_dim_size - 1], + y_dim[y_dim_size - 2], + phi::errors::InvalidArgument( + "Matrix X1 with dimension greater than 2 and any matrix Y1," + "the matrix X1's width must be equal with matrix Y1's " + "height. But received X's shape = [%s], X1's shape = [%s], X1's " + "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = " + "%s.", + x_dim, + x_dim, + x_dim[x_dim_size - 1], + y_dim, + y_dim, + y_dim[y_dim_size - 2])); + mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out); + } +} + +template +void SolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + linalg_solve(dev_ctx, x, y, out); +} + +} // namespace phi diff --git a/paddle/phi/kernels/solve_grad_kernel.h b/paddle/phi/kernels/solve_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..31bdb9932becccf400bc324b16f9dc5d10079c5c --- /dev/null +++ b/paddle/phi/kernels/solve_grad_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SolveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& out, + DenseTensor* dx, + DenseTensor* dy); + +} // namespace phi diff --git a/paddle/phi/kernels/solve_kernel.h b/paddle/phi/kernels/solve_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..28dddb0f641bdf4108e334a620afaf6f3866553a --- /dev/null +++ b/paddle/phi/kernels/solve_kernel.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/unsqueeze_kernel.h b/paddle/phi/kernels/unsqueeze_kernel.h index 4622a9b0a859c913ee8c8899d1d12a0a7f988f73..62ba878c056cb6ccc55b4f2ecd98c2e45aa359bd 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.h +++ b/paddle/phi/kernels/unsqueeze_kernel.h @@ -17,6 +17,7 @@ #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { @@ -26,4 +27,16 @@ void UnsqueezeKernel(const Context& dev_ctx, const IntArray& axes, DenseTensor* out, DenseTensor* xshape); + +template +void Unsqueeze(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape) { + MetaTensor meta_out(out); + UnsqueezeInferMeta(x, axes, &meta_out, nullptr, MetaConfig()); + UnsqueezeKernel(dev_ctx, x, axes, out, nullptr); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/solve_sig.cc b/paddle/phi/ops/compat/solve_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..9771adee8e9836ff012a6b20faa87d79ed28de43 --- /dev/null +++ b/paddle/phi/ops/compat/solve_sig.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SolveGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "solve_grad", {"X", "Y", "Out@GRAD", "Out"}, {}, {"X@GRAD", "Y@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(solve_grad, phi::SolveGradOpArgumentMapping);