/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/operators/squeeze_op.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif

#define MAX_RANK_SUPPORTED 6

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using framework::To32BitIndex;

constexpr int kMULMKLDNNINT8 = 1;

struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor() {}

  template <typename U>
  HOSTDEVICE inline U operator()(const U& x) const {
    return x;
  }
};

template <typename DeviceContext, typename T>
void ReduceSumForSolveGrad(const Tensor* input, Tensor* output,
                           const std::vector<int>& reduce_dims, bool keep_dim,
                           const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
  auto stream = ctx.cuda_device_context().stream();
  TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
                                                static_cast<T>(0), cub::Sum(),
                                                IdentityFunctor(), stream);
#else
  ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
      input, output, reduce_dims, keep_dim, false, ctx)
      .template apply<T>();
#endif
}

// check whether the right-hand side `other` is the vector case or not
static inline bool is_vector_rhs(const Tensor& input, const Tensor& other) {
  auto x_dim = input.dims();
  auto y_dim = other.dims();
  auto x_dim_size = x_dim.size();
  auto y_dim_size = y_dim.size();
  std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x_dim);
  std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y_dim);

  std::vector<int64_t>::const_iterator f = x_dims_vec.begin();
  std::vector<int64_t>::const_iterator l = x_dims_vec.end() - 1;
  std::vector<int64_t> x_dims_vec_cut(f, l);  // input.shape[:-1]

  std::vector<int64_t> expected_batched_rhs_shape(x_dims_vec_cut);
  bool vector_case =
      y_dim_size == 1 || (x_dim_size - 1 == y_dim_size &&
                          y_dims_vec == (expected_batched_rhs_shape));

  return vector_case;
}
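// For example (illustrative shapes): with input X of shape [2, 3, 3], a
// right-hand side Y of shape [2, 3] equals X.shape[:-1] and is treated as the
// vector case, while Y of shape [2, 3, 4] is treated as a batched matrix
// right-hand side.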
// unsqueeze operation helper
static framework::DDim GetOutputShapeUnsqueeze(
    const std::vector<int> unsqz_dims, const framework::DDim& in_dims) {
  int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
  int cur_output_size = in_dims.size();
  std::vector<int64_t> output_shape(output_size, 0);

  // Validity Check: rank range.
  PADDLE_ENFORCE_LE(output_size, 6,
                    platform::errors::InvalidArgument(
                        "The output tensor's rank should be no more than 6."));

  for (int axis : unsqz_dims) {
    int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
    // Validity Check: the axis bound
    PADDLE_ENFORCE_GE(cur, 0,
                      platform::errors::InvalidArgument(
                          "The insert dimension value should "
                          "not be less than 0"));
    PADDLE_ENFORCE_LE(cur, cur_output_size,
                      platform::errors::InvalidArgument(
                          "The insert dimension value should not be larger "
                          "than the dimension size of input tensor"));
    // Move old axis, and insert new axis
    for (int i = cur_output_size; i >= cur; --i) {
      if (output_shape[i] == 1) {
        // Move axis
        output_shape[i + 1] = 1;
        output_shape[i] = 0;
      }
    }
    output_shape[cur] = 1;
    // Add the output size.
    cur_output_size++;
  }

  // Make output shape
  for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
    if (output_shape[out_idx] == 0) {
      output_shape[out_idx] = in_dims[in_idx++];
    }
  }

  return framework::make_ddim(output_shape);
}

// operation like squeeze(-1)
static void to_squeeze(const framework::ExecutionContext& context,
                       const framework::Tensor& in, framework::Tensor* out) {
  auto x_dims = in.dims();
  std::vector<int> sqz_dims = {-1};
  auto out_dims = GetOutputShape(sqz_dims, x_dims, true);
  out->mutable_data(context.GetPlace(), in.type());
  framework::TensorCopy(
      in, context.GetPlace(),
      context.template device_context<platform::DeviceContext>(), out);
  out->Resize(out_dims);
}

// vector_case, need to operate like unsqueeze(-1)
static void to_unsqueeze(const framework::ExecutionContext& context,
                         const framework::Tensor& in, framework::Tensor* out) {
  auto x_dims = in.dims();
  std::vector<int> unsqz_dims = {-1};
  framework::DDim out_dims = out->dims();
  out_dims = GetOutputShapeUnsqueeze(unsqz_dims, x_dims);
  framework::TensorCopy(
      in, context.GetPlace(),
      context.template device_context<platform::DeviceContext>(), out);
  out->Resize(out_dims);
}

// Prepared for the broadcast operation
static std::vector<int64_t> get_broadcast_batch_portion(
    std::vector<int64_t> x, std::vector<int64_t> y) {
  size_t size_x = x.size();
  size_t size_y = y.size();
  size_t size = std::max(size_x, size_y);
  std::vector<int64_t> batchPortion(size);

  ptrdiff_t i = (ptrdiff_t)size - 1;
  for (; i >= 0; --i) {
    ptrdiff_t offset = size - i - 1;
    ptrdiff_t dim_x = size_x - offset - 1;
    ptrdiff_t dim_y = size_y - offset - 1;
    int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
    int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;

    PADDLE_ENFORCE_EQ(
        (x_size == y_size || x_size == 1 || y_size == 1), true,
        platform::errors::PreconditionNotMet(
            "The size of tensor x (%d) must match the size of tensor y "
            "(%d) at non-singleton dimension %d.",
            x_size, y_size, i));

    batchPortion[i] = x_size != 1 ? x_size : y_size;
  }
  return batchPortion;
}
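// For example (illustrative shapes): batch portions [2, 1] and [5] broadcast
// to [2, 5]; mismatched non-singleton sizes such as [2] and [3] trigger the
// PreconditionNotMet error above.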
// necessary check before expand operation
static void expand_check(const Tensor& arg1,
                         std::vector<int64_t> expand_shape) {
  auto rank = arg1.dims().size();
  PADDLE_ENFORCE_GE(
      rank, 1,
      platform::errors::InvalidArgument(
          "The rank of the input 'X' for expand must be positive, "
          "but the value received is %d.",
          rank));
  PADDLE_ENFORCE_LE(
      rank, MAX_RANK_SUPPORTED,
      platform::errors::InvalidArgument(
          "The rank of the input 'X' for expand must be less than "
          "or equal to %d, but the value received is %d.",
          MAX_RANK_SUPPORTED, rank));
  auto shape_size = static_cast<int>(expand_shape.size());
  PADDLE_ENFORCE_GE(
      shape_size, rank,
      platform::errors::InvalidArgument(
          "The number (%d) of elements of 'shape' for expand must be "
          "greater than or equal to the rank (%d) of the input 'X'.",
          shape_size, rank));
  PADDLE_ENFORCE_LE(
      shape_size, MAX_RANK_SUPPORTED,
      platform::errors::InvalidArgument(
          "The number (%d) of elements of 'shape' for expand must be "
          "less than or equal to %d.",
          shape_size, MAX_RANK_SUPPORTED));
}

// broadcast the batch dimensions of tensor x and tensor y.
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_broadcast_dims(const Tensor& x, const Tensor& y) {
  std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x.dims());
  std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y.dims());

  std::vector<int64_t>::const_iterator f1 = x_dims_vec.begin();
  std::vector<int64_t>::const_iterator l1 = x_dims_vec.end() - 2;
  std::vector<int64_t> x_dims_vec_cut(f1, l1);

  std::vector<int64_t>::const_iterator f2 = y_dims_vec.begin();
  std::vector<int64_t>::const_iterator l2 = y_dims_vec.end() - 2;
  std::vector<int64_t> y_dims_vec_cut(f2, l2);

  std::vector<int64_t> expand_batch_portion =
      get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);

  std::vector<int64_t> x_expand_size({expand_batch_portion});
  x_expand_size.insert(x_expand_size.end(),
                       {x_dims_vec[static_cast<int>(x_dims_vec.size()) - 2],
                        x_dims_vec[static_cast<int>(x_dims_vec.size()) - 1]});

  std::vector<int64_t> y_expand_size({expand_batch_portion});
  y_expand_size.insert(y_expand_size.end(),
                       {y_dims_vec[static_cast<int>(y_dims_vec.size()) - 2],
                        y_dims_vec[static_cast<int>(y_dims_vec.size()) - 1]});

  return std::make_tuple(x_expand_size, y_expand_size);
}

template <int Rank, typename T, typename DeviceContext>
void tensor_expand(const framework::ExecutionContext& context,
                   const Tensor& arg1, Tensor* out0,
                   std::vector<int64_t> expand_size) {
  auto in_dims = arg1.dims();
  auto expand_shape = expand_size;
  auto vec_in_dims = framework::vectorize<int>(in_dims);
  auto diff = expand_shape.size() - vec_in_dims.size();
  vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
  std::vector<int> repeat_times(vec_in_dims.size());
  for (size_t i = 0; i < vec_in_dims.size(); ++i) {
    PADDLE_ENFORCE_NE(expand_shape[i], 0,
                      platform::errors::InvalidArgument(
                          "The expanded size cannot be zero."));
    if (i < diff) {
      PADDLE_ENFORCE_GT(
          expand_shape[i], 0,
          platform::errors::InvalidArgument(
              "The expanded size (%d) for non-existing dimensions must be "
              "positive for expand operation.",
              expand_shape[i]));
      repeat_times[i] = expand_shape[i];
    } else if (expand_shape[i] > 0) {
      if (vec_in_dims[i] != 1) {
        PADDLE_ENFORCE_EQ(
            vec_in_dims[i], expand_shape[i],
            platform::errors::InvalidArgument(
                "The value (%d) of the non-singleton dimension does not match"
                " the corresponding value (%d) in shape for expand operation.",
                vec_in_dims[i], expand_shape[i]));
        repeat_times[i] = 1;
      } else {
        repeat_times[i] = expand_shape[i];
      }
    } else {
      PADDLE_ENFORCE_EQ(
          expand_shape[i], -1,
          platform::errors::InvalidArgument(
              "When the value in shape is negative for expand_v2 op, "
              "only -1 is supported, but the value received is %d.",
              expand_shape[i]));
      repeat_times[i] = 1;
    }
  }

  Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
  for (size_t i = 0; i < repeat_times.size(); ++i) {
    bcast_dims[i] = repeat_times[i];
  }

  framework::DDim new_in_dims = framework::make_ddim(vec_in_dims);
  framework::DDim out_dims(new_in_dims);
  for (size_t i = 0; i < repeat_times.size(); ++i) {
    out_dims[i] *= repeat_times[i];
  }

  out0->Resize(out_dims);
  auto x = EigenTensor<T, Rank>::From(arg1, new_in_dims);
  out0->mutable_data<T>(context.GetPlace());
  auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
  auto& place =
      *context.template device_context<DeviceContext>().eigen_device();
  // use 32-bit index to speed up
  bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
  if (use_32bit_index) {
    EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
        place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
  } else {
    EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
                                                                 bcast_dims);
  }
}
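// linalg_solve computes Out = X^-1 * Y with NumPy-style batch broadcasting.
// A sketch of the flow (shapes are illustrative): for X of shape [2, 3, 3]
// and a vector Y of shape [3], Y is unsqueezed to [3, 1], both operands are
// broadcast to [2, 3, 3] and [2, 3, 1], the batched solve runs, and the
// result is squeezed back to [2, 3].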
template <typename DeviceContext, typename T>
static void linalg_solve(const framework::ExecutionContext& context,
                         const framework::Tensor* x,
                         const framework::Tensor* y, framework::Tensor* out) {
  out->mutable_data<T>(context.GetPlace());

  auto& dev_ctx = context.template device_context<DeviceContext>();
  math::MatrixSolveFunctor<DeviceContext, T> mat_solve;

  // input y can be a vector or a matrix,
  // but it needs to be unsqueezed if y is a vector
  bool is_vector = false;
  is_vector = is_vector_rhs(*x, *y);

  Tensor tmp_y;
  if (is_vector) {
    tmp_y.mutable_data(context.GetPlace(), y->type());
    to_unsqueeze(context, *y, &tmp_y);
  } else {
    tmp_y.Resize(y->dims());
    tmp_y.mutable_data(context.GetPlace(), y->type());
    framework::TensorCopy(
        *y, context.GetPlace(),
        context.template device_context<platform::DeviceContext>(), &tmp_y);
  }

  Tensor tmp_x;
  tmp_x.Resize(x->dims());
  tmp_x.mutable_data(context.GetPlace(), x->type());
  framework::TensorCopy(
      *x, context.GetPlace(),
      context.template device_context<platform::DeviceContext>(), &tmp_x);

  std::vector<int64_t> x_broadcast_dims;
  std::vector<int64_t> y_broadcast_dims;
  std::tie(x_broadcast_dims, y_broadcast_dims) =
      get_broadcast_dims(tmp_x, tmp_y);

  expand_check(tmp_x, x_broadcast_dims);
  expand_check(tmp_y, y_broadcast_dims);

  Tensor tmp_x_bc;
  Tensor tmp_y_bc;
  auto tmp_x_rank = tmp_x.dims().size();
  auto tmp_y_rank = tmp_y.dims().size();

  auto rank_0 =
      std::max(tmp_x_rank, static_cast<int>(x_broadcast_dims.size()));
  switch (rank_0) {
    case 1:
      tensor_expand<1, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
                                         x_broadcast_dims);
      break;
    case 2:
      tensor_expand<2, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
                                         x_broadcast_dims);
      break;
    case 3:
      tensor_expand<3, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
                                         x_broadcast_dims);
      break;
    case 4:
      tensor_expand<4, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
                                         x_broadcast_dims);
      break;
    case 5:
      tensor_expand<5, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
                                         x_broadcast_dims);
      break;
    case 6:
      tensor_expand<6, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
                                         x_broadcast_dims);
      break;
  }

  auto rank_1 =
      std::max(tmp_y_rank, static_cast<int>(y_broadcast_dims.size()));
  switch (rank_1) {
    case 1:
      tensor_expand<1, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
                                         y_broadcast_dims);
      break;
    case 2:
      tensor_expand<2, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
                                         y_broadcast_dims);
      break;
    case 3:
      tensor_expand<3, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
                                         y_broadcast_dims);
      break;
    case 4:
      tensor_expand<4, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
                                         y_broadcast_dims);
      break;
    case 5:
      tensor_expand<5, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
                                         y_broadcast_dims);
      break;
    case 6:
      tensor_expand<6, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
                                         y_broadcast_dims);
      break;
  }

  auto x_dim = x->dims();
  auto y_dim = y->dims();
  auto x_dim_size = x_dim.size();
  auto y_dim_size = y_dim.size();

  if (is_vector) {                 // vector case
    out->Resize(tmp_y_bc.dims());  // out.unsqueeze(-1)
    mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out);

    Tensor out_tmp;
    out_tmp.Resize(out->dims());
    out_tmp = *out;
    to_squeeze(context, out_tmp, out);  // out.squeeze(-1)
  } else {
    PADDLE_ENFORCE_EQ(
        x_dim[x_dim_size - 1], y_dim[y_dim_size - 2],
        platform::errors::InvalidArgument(
            "For matrix X with dimension greater than 2 and any matrix Y, "
            "the width of X1 must be equal to the height of Y1. But received "
            "X's shape = [%s], X1's shape = [%s], X1's width = %s; Y's shape "
            "= [%s], Y1's shape = [%s], Y1's height = %s.",
            x_dim, x_dim, x_dim[x_dim_size - 1], y_dim, y_dim,
            y_dim[y_dim_size - 2]));
    mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out);
  }
}
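// The helpers below build the permutation and shape for transposing the last
// two axes of a tensor, e.g. (illustrative) rank 3 gives axis {0, 2, 1}, so
// dims [2, 3, 4] become [2, 4, 3]; they are used to form X^T in the gradient
// kernel.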
// for TransposeNormal
static std::vector<int> getNewAxis(const int b_rank) {
  std::vector<int> axis_1 = {0};
  std::vector<int> axis_2 = {1, 0};
  std::vector<int> axis_3 = {0, 2, 1};
  std::vector<int> axis_4 = {0, 1, 3, 2};
  std::vector<int> axis_5 = {0, 1, 2, 4, 3};
  std::vector<int> axis_6 = {0, 1, 2, 3, 5, 4};
  std::vector<int> axis_7 = {0, 1, 2, 3, 4, 6, 5};
  std::vector<int> axis_8 = {0, 1, 2, 3, 4, 5, 7, 6};
  std::vector<int> axis_9 = {0, 1, 2, 3, 4, 5, 6, 8, 7};
  switch (b_rank) {
    case 1:
      return axis_1;
    case 2:
      return axis_2;
    case 3:
      return axis_3;
    case 4:
      return axis_4;
    case 5:
      return axis_5;
    case 6:
      return axis_6;
    case 7:
      return axis_7;
    case 8:
      return axis_8;
    default:
      return axis_9;
  }
}

// for Resize
static std::vector<int64_t> getNewDimsVec(const DDim& b_dims) {
  std::vector<int64_t> b_dims_vec = paddle::framework::vectorize(b_dims);
  int size = b_dims_vec.size();
  if (size >= 2) {
    // swap the last 2 elements in b_dims_vec
    int64_t temp = b_dims_vec[size - 1];
    b_dims_vec[size - 1] = b_dims_vec[size - 2];
    b_dims_vec[size - 2] = temp;
    return b_dims_vec;
  }
  PADDLE_ENFORCE_NE(
      b_dims_vec.empty(), true,
      platform::errors::PreconditionNotMet(
          "The size of tensor b must not be %d after getting new dims", 0));
  // if b_dims_vec.size() == 1, just return the original vec
  return b_dims_vec;
}

template <typename DeviceContext, typename T>
class SolveKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* x = context.Input<framework::Tensor>("X");
    const auto* y = context.Input<framework::Tensor>("Y");
    Tensor* out = context.Output<framework::Tensor>("Out");

    linalg_solve<DeviceContext, T>(context, x, y, out);
  }
};

template <typename DeviceContext, typename T>
class SolveGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    // reuse the linalg.solve forward output
    auto* out = ctx.Input<framework::Tensor>("Out");

    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));

    bool is_vector = false;
    is_vector = is_vector_rhs(*input, *y);

    Tensor tmp_y;
    if (is_vector) {
      tmp_y.mutable_data(ctx.GetPlace(), y->type());
      to_unsqueeze(ctx, *y, &tmp_y);
    } else {
      tmp_y.Resize(y->dims());
      tmp_y.mutable_data(ctx.GetPlace(), y->type());
      framework::TensorCopy(
          *y, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), &tmp_y);
    }

    Tensor tmp_x;
    tmp_x.Resize(input->dims());
    tmp_x.mutable_data(ctx.GetPlace(), input->type());
    framework::TensorCopy(
        *input, ctx.GetPlace(),
        ctx.template device_context<platform::DeviceContext>(), &tmp_x);

    std::vector<int64_t> x_broadcast_dims;
    std::vector<int64_t> y_broadcast_dims;
    std::tie(x_broadcast_dims, y_broadcast_dims) =
        get_broadcast_dims(tmp_x, tmp_y);

    // tmp_dx
    Tensor tmp_dx;
    tmp_dx.Resize(framework::make_ddim(x_broadcast_dims));
    tmp_dx.mutable_data<T>(ctx.GetPlace());

    // tmp_dy
    Tensor tmp_dy;
    tmp_dy.Resize(framework::make_ddim(y_broadcast_dims));
    tmp_dy.mutable_data<T>(ctx.GetPlace());

    // tmp_input holds X^T, built with the transpose helpers above
    Tensor tmp_input(input->type());
    const auto& new_dims_vec = getNewDimsVec(input->dims());
    tmp_input.Resize(framework::make_ddim(new_dims_vec));
    tmp_input.mutable_data<T>(ctx.GetPlace());
    math::TransposeNormal<DeviceContext, T> trans;
    std::vector<int> new_axis = getNewAxis(input->dims().size());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    trans(dev_ctx, *input, &tmp_input, new_axis);
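    // Gradient sketch (standard result for Out = X^-1 * Y, matching the code
    // below): dY = (X^T)^-1 * dOut, obtained by reusing linalg_solve on the
    // transposed input, and dX = -dY * Out^T, obtained via MatMul with
    // alpha = -1. Both are computed at the broadcast shapes first.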
    if (dy) {
      dy->mutable_data<T>(ctx.GetPlace());
      // reuse linalg_solve forward logics to get tmp_dy
      linalg_solve<DeviceContext, T>(ctx, &tmp_input, dout, &tmp_dy);
    }

    if (dx) {
      dx->mutable_data<T>(ctx.GetPlace());
      // to get dx
      auto blas = math::GetBlas<DeviceContext, T>(ctx);
      if (input->dims().size() == 2 && y->dims().size() == 2) {
        auto mat_dim_a1 = math::CreateMatrixDescriptor(tmp_dy.dims(), 0, false);
        auto mat_dim_b1 = math::CreateMatrixDescriptor(out->dims(), 0, true);
        blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0));
      } else if (is_vector_rhs(*input, *y)) {
        Tensor tmp_dy_;
        tmp_dy_.mutable_data(ctx.GetPlace(), y->type());
        to_unsqueeze(ctx, tmp_dy, &tmp_dy_);

        Tensor tmp_out_;
        tmp_out_.mutable_data(ctx.GetPlace(), out->type());
        to_unsqueeze(ctx, *out, &tmp_out_);

        auto mat_dim_a1 =
            math::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false);
        auto mat_dim_b1 =
            math::CreateMatrixDescriptor(tmp_out_.dims(), 0, true);
        blas.MatMul(tmp_dy_, mat_dim_a1, tmp_out_, mat_dim_b1, T(-1), &tmp_dx,
                    T(0));
      } else {
        auto mat_dim_a1 = math::CreateMatrixDescriptor(tmp_dy.dims(), 0, false);
        auto mat_dim_b1 = math::CreateMatrixDescriptor(out->dims(), 0, true);
        blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0));
      }
    }

    if (y->dims() != tmp_dy.dims()) {
      Tensor dy_help;
      dy_help.Resize(tmp_dy.dims());
      dy_help.mutable_data(ctx.GetPlace(), tmp_dy.type());
      framework::TensorCopy(
          tmp_dy, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), &dy_help);

      // get dims
      std::vector<int64_t> x_dims = vectorize(input->dims());
      std::vector<int64_t> y_dims = vectorize(y->dims());
      std::vector<int64_t> dout_dims = vectorize(dout->dims());

      if (is_vector_rhs(*input, *y)) {
        dout_dims.push_back(1);
      }

      int y_ndim = y_dims.size();
      int ndim = dout_dims.size();

      const std::vector<int64_t> dy_help_dims = vectorize(dy_help.dims());
      std::vector<int64_t> dy_broadcast_dims(ndim);

      std::fill(dy_broadcast_dims.data(),
                dy_broadcast_dims.data() + ndim - y_ndim, 1);
      std::copy(y_dims.data(), y_dims.data() + y_ndim,
                dy_broadcast_dims.data() + ndim - y_ndim);

      std::vector<int> dy_reduce_dims;
      for (int idx = 0; idx <= ndim - 3; idx++) {
        if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
          dy_reduce_dims.push_back(idx);
        }
      }

      // reduce sum to get grad by ReduceSum
      if (dy) {
        if (dy_reduce_dims.empty()) {
          *dy = std::move(dy_help);
        } else {
          bool keep_dim = true;
          if (dy_help.dims().size() != dy->dims().size()) {
            keep_dim = false;
          }
          ReduceSumForSolveGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
                                                  keep_dim, ctx);
        }
        dy->Resize(y->dims());
      }
    } else {
      framework::TensorCopy(
          tmp_dy, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dy);
    }
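    // dx was likewise computed at the broadcast shape; reduce it back to the
    // shape of X by summing over the broadcast batch dimensions, mirroring
    // the handling of dy above.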
    if (input->dims() != tmp_dx.dims()) {
      Tensor dx_help;
      dx_help.Resize(tmp_dx.dims());
      dx_help.mutable_data(ctx.GetPlace(), tmp_dx.type());
      framework::TensorCopy(
          tmp_dx, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), &dx_help);

      // get dims
      std::vector<int64_t> x_dims = vectorize(input->dims());
      std::vector<int64_t> y_dims = vectorize(y->dims());

      int x_ndim = x_dims.size();
      int ndim = x_broadcast_dims.size();

      const std::vector<int64_t> dx_help_dims = vectorize(dx_help.dims());
      std::vector<int64_t> dx_broadcast_dims(ndim);

      std::fill(dx_broadcast_dims.data(),
                dx_broadcast_dims.data() + ndim - x_ndim, 1);
      std::copy(x_dims.data(), x_dims.data() + x_ndim,
                dx_broadcast_dims.data() + ndim - x_ndim);

      std::vector<int> dx_reduce_dims;
      for (int idx = 0; idx <= ndim - 3; idx++) {
        if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
          dx_reduce_dims.push_back(idx);
        }
      }

      // reduce sum to get grad by ReduceSum
      if (dx) {
        dx->mutable_data<T>(ctx.GetPlace());

        if (dx_reduce_dims.empty()) {
          *dx = std::move(dx_help);
        } else {
          bool keep_dim = true;
          if (dx_help.dims().size() != dx->dims().size()) {
            keep_dim = false;
          }
          ReduceSumForSolveGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
                                                  keep_dim, ctx);
        }
        dx->Resize(input->dims());
      }
    } else {
      framework::TensorCopy(
          tmp_dx, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), dx);
    }
  }
};

}  // namespace operators
}  // namespace paddle