diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h index c46a1cc06688383637e5de4a4d3becef821cbe91..ec72269f697e87b5cb957682312d0a1fa7a8d506 100644 --- a/paddle/fluid/operators/solve_op.h +++ b/paddle/fluid/operators/solve_op.h @@ -157,70 +157,62 @@ static void to_unsqueeze(const framework::ExecutionContext& context, out->Resize(out_dims); } -template -Container infer_size_impl(std::vector a, std::vector b) { - size_t dimsA = a.size(); - size_t dimsB = b.size(); - size_t ndim = dimsA > dimsB ? dimsA : dimsB; - Container expandedSizes(ndim); - - for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) { - ptrdiff_t offset = ndim - 1 - i; - ptrdiff_t dimA = dimsA - 1 - offset; - ptrdiff_t dimB = dimsB - 1 - offset; - int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; - int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; +// Prepared for the broadcast operation +static std::vector get_broadcast_batch_portion( + std::vector x, std::vector y) { + size_t size_x = x.size(); + size_t size_y = y.size(); + size_t size = std::max(size_x, size_y); + std::vector batchPortion(size); + + ptrdiff_t i = (ptrdiff_t)size - 1; + for (; i >= 0; --i) { + ptrdiff_t offset = size - i - 1; + ptrdiff_t dim_x = size_x - offset - 1; + ptrdiff_t dim_y = size_y - offset - 1; + int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; + int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; PADDLE_ENFORCE_EQ( - (sizeA == sizeB || sizeA == 1 || sizeB == 1), true, + (x_size == y_size || x_size == 1 || y_size == 1), true, platform::errors::PreconditionNotMet( - "The size of tensor a (%d) must match the size of tensor b " + "The size of tensor x (%d) must match the size of tensor y " "(%d) at non-singleton dimension %d.", - sizeA, sizeB, i)); + x_size, y_size, i)); - expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; + batchPortion[i] = x_size != 1 ? x_size : y_size; } - return expandedSizes; + return batchPortion; } -// infer size for broadcast operation -static std::vector infer_size(std::vector a, - std::vector b) { - return infer_size_impl>(a, b); -} - -// broadcast the batch dimensions of arg1 and arg2. +// broadcast the batch dimensions of tensor x and tensor y. static inline std::tuple, std::vector> -_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { - std::vector arg1_dims_vec = - paddle::framework::vectorize(arg1.dims()); - std::vector arg2_dims_vec = - paddle::framework::vectorize(arg2.dims()); +get_broadcast_dims(const Tensor& x, const Tensor& y) { + std::vector x_dims_vec = paddle::framework::vectorize(x.dims()); + std::vector y_dims_vec = paddle::framework::vectorize(y.dims()); - std::vector::const_iterator f1 = arg1_dims_vec.begin(); - std::vector::const_iterator l1 = arg1_dims_vec.end() - 2; - std::vector arg1_dims_vec_cut(f1, l1); + std::vector::const_iterator f1 = x_dims_vec.begin(); + std::vector::const_iterator l1 = x_dims_vec.end() - 2; + std::vector x_dims_vec_cut(f1, l1); - std::vector::const_iterator f2 = arg2_dims_vec.begin(); - std::vector::const_iterator l2 = arg2_dims_vec.end() - 2; - std::vector arg2_dims_vec_cut(f2, l2); + std::vector::const_iterator f2 = y_dims_vec.begin(); + std::vector::const_iterator l2 = y_dims_vec.end() - 2; + std::vector y_dims_vec_cut(f2, l2); std::vector expand_batch_portion = - infer_size(arg1_dims_vec_cut, arg2_dims_vec_cut); + get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); - std::vector arg1_expand_size({expand_batch_portion}); - arg1_expand_size.insert( - arg1_expand_size.end(), - {arg1_dims_vec[static_cast(arg1_dims_vec.size()) - 2], - arg1_dims_vec[static_cast(arg1_dims_vec.size()) - 1]}); + std::vector x_expand_size({expand_batch_portion}); + x_expand_size.insert(x_expand_size.end(), + {x_dims_vec[static_cast(x_dims_vec.size()) - 2], + x_dims_vec[static_cast(x_dims_vec.size()) - 1]}); - std::vector arg2_expand_size({expand_batch_portion}); - arg2_expand_size.insert( - arg2_expand_size.end(), - {arg2_dims_vec[static_cast(arg2_dims_vec.size()) - 2], - arg2_dims_vec[static_cast(arg2_dims_vec.size()) - 1]}); + std::vector y_expand_size({expand_batch_portion}); + y_expand_size.insert(y_expand_size.end(), + {y_dims_vec[static_cast(y_dims_vec.size()) - 2], + y_dims_vec[static_cast(y_dims_vec.size()) - 1]}); - return std::make_tuple(arg1_expand_size, arg2_expand_size); + return std::make_tuple(x_expand_size, y_expand_size); } template @@ -364,7 +356,7 @@ static void linalg_solve(const framework::ExecutionContext& context, std::vector x_broadcast_dims; std::vector y_broadcast_dims; std::tie(x_broadcast_dims, y_broadcast_dims) = - _broadcast_batch_dims(tmp_x, tmp_y); + get_broadcast_dims(tmp_x, tmp_y); Tensor tmp_x_bc; TensorExpand(dev_ctx, tmp_x, &tmp_x_bc, x_broadcast_dims); @@ -510,7 +502,7 @@ class SolveGradKernel : public framework::OpKernel { std::vector x_broadcast_dims; std::vector y_broadcast_dims; std::tie(x_broadcast_dims, y_broadcast_dims) = - _broadcast_batch_dims(tmp_x, tmp_y); + get_broadcast_dims(tmp_x, tmp_y); // tmp_dx Tensor tmp_dx; diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc index 202757ec48d83d789439d63f18395ed52188575a..4b01669bf55b407896017e0f598fe6e73cf534ba 100644 --- a/paddle/fluid/operators/triangular_solve_op.cc +++ b/paddle/fluid/operators/triangular_solve_op.cc @@ -63,7 +63,7 @@ class TriangularSolveOp : public framework::OperatorWithKernel { y_dims_vec.end() - 2); std::vector expand_batch_portion = - infer_size(x_dims_vec_cut, y_dims_vec_cut); + get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); std::vector y_broadcast_dims({expand_batch_portion}); y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2], diff --git a/paddle/fluid/operators/triangular_solve_op.h b/paddle/fluid/operators/triangular_solve_op.h index 158ad72ddbfcdb21be7043a7ffb170e1ec83a89d..f64b016366e39b2260f4f8aebbb2e371ee2a8a7a 100644 --- a/paddle/fluid/operators/triangular_solve_op.h +++ b/paddle/fluid/operators/triangular_solve_op.h @@ -36,7 +36,7 @@ static void triangular_solve(const DeviceContext& context, const Tensor& x, // Tensor broadcast use eigen std::vector x_bst_dims_vec; std::vector y_bst_dims_vec; - std::tie(x_bst_dims_vec, y_bst_dims_vec) = _broadcast_batch_dims(x, y); + std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y); Tensor x_bst(x.type()); TensorExpand(context, x, &x_bst, x_bst_dims_vec); @@ -141,7 +141,7 @@ class TriangularSolveGradKernel : public framework::OpKernel { std::vector x_bst_dims_vec; std::vector y_bst_dims_vec; - std::tie(x_bst_dims_vec, y_bst_dims_vec) = _broadcast_batch_dims(*x, *y); + std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(*x, *y); Tensor dy_bst(y->type()); if (dy) { diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 9d53757188e739e0100cc9efb539fdfe3daf1292..e821140a0d1ec9f585c80b036d9ea6a465b23987 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1019,7 +1019,6 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)