diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h index d55c2647c1f3ad59143e4e92dbf002fa860d324e..f70baf486db461d3951ac1b06c61627ed2aa5098 100644 --- a/paddle/fluid/operators/solve_op.h +++ b/paddle/fluid/operators/solve_op.h @@ -157,36 +157,32 @@ static void to_unsqueeze(const framework::ExecutionContext& context, out->Resize(out_dims); } -template -Container infer_size_impl(std::vector a, std::vector b) { - size_t dimsA = a.size(); - size_t dimsB = b.size(); - size_t ndim = dimsA > dimsB ? dimsA : dimsB; - Container expandedSizes(ndim); - - for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) { - ptrdiff_t offset = ndim - 1 - i; - ptrdiff_t dimA = dimsA - 1 - offset; - ptrdiff_t dimB = dimsB - 1 - offset; - int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; - int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; +// Prepared for the broadcast operation +static std::vector get_broadcast_batch_portion( + std::vector x, std::vector y) { + size_t size_x = x.size(); + size_t size_y = y.size(); + size_t size = std::max(size_x, size_y); + std::vector batchPortion(size); + + ptrdiff_t i = (ptrdiff_t)size - 1; + for (; i >= 0; --i) { + ptrdiff_t offset = size - i - 1; + ptrdiff_t dim_x = size_x - offset - 1; + ptrdiff_t dim_y = size_y - offset - 1; + int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; + int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; PADDLE_ENFORCE_EQ( - (sizeA == sizeB || sizeA == 1 || sizeB == 1), true, + (x_size == y_size || x_size == 1 || y_size == 1), true, platform::errors::PreconditionNotMet( - "The size of tensor a (%d) must match the size of tensor b " + "The size of tensor x (%d) must match the size of tensor y " "(%d) at non-singleton dimension %d.", - sizeA, sizeB, i)); + x_size, y_size, i)); - expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; + batchPortion[i] = x_size != 1 ? x_size : y_size; } - return expandedSizes; -} - -// infer size for broadcast operation -static std::vector infer_size(std::vector a, - std::vector b) { - return infer_size_impl>(a, b); + return batchPortion; } // necessary check before expand operation @@ -219,38 +215,34 @@ static void expand_check(const Tensor& arg1, shape_size, MAX_RANK_SUPPORTED)); } -// broadcast the batch dimensions of arg1 and arg2. +// broadcast the batch dimensions of tensor x and tensor y. static inline std::tuple, std::vector> -_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { - std::vector arg1_dims_vec = - paddle::framework::vectorize(arg1.dims()); - std::vector arg2_dims_vec = - paddle::framework::vectorize(arg2.dims()); +get_broadcast_dims(const Tensor& x, const Tensor& y) { + std::vector x_dims_vec = paddle::framework::vectorize(x.dims()); + std::vector y_dims_vec = paddle::framework::vectorize(y.dims()); - std::vector::const_iterator f1 = arg1_dims_vec.begin(); - std::vector::const_iterator l1 = arg1_dims_vec.end() - 2; - std::vector arg1_dims_vec_cut(f1, l1); + std::vector::const_iterator f1 = x_dims_vec.begin(); + std::vector::const_iterator l1 = x_dims_vec.end() - 2; + std::vector x_dims_vec_cut(f1, l1); - std::vector::const_iterator f2 = arg2_dims_vec.begin(); - std::vector::const_iterator l2 = arg2_dims_vec.end() - 2; - std::vector arg2_dims_vec_cut(f2, l2); + std::vector::const_iterator f2 = y_dims_vec.begin(); + std::vector::const_iterator l2 = y_dims_vec.end() - 2; + std::vector y_dims_vec_cut(f2, l2); std::vector expand_batch_portion = - infer_size(arg1_dims_vec_cut, arg2_dims_vec_cut); + get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); - std::vector arg1_expand_size({expand_batch_portion}); - arg1_expand_size.insert( - arg1_expand_size.end(), - {arg1_dims_vec[static_cast(arg1_dims_vec.size()) - 2], - arg1_dims_vec[static_cast(arg1_dims_vec.size()) - 1]}); + std::vector x_expand_size({expand_batch_portion}); + x_expand_size.insert(x_expand_size.end(), + {x_dims_vec[static_cast(x_dims_vec.size()) - 2], + x_dims_vec[static_cast(x_dims_vec.size()) - 1]}); - std::vector arg2_expand_size({expand_batch_portion}); - arg2_expand_size.insert( - arg2_expand_size.end(), - {arg2_dims_vec[static_cast(arg2_dims_vec.size()) - 2], - arg2_dims_vec[static_cast(arg2_dims_vec.size()) - 1]}); + std::vector y_expand_size({expand_batch_portion}); + y_expand_size.insert(y_expand_size.end(), + {y_dims_vec[static_cast(y_dims_vec.size()) - 2], + y_dims_vec[static_cast(y_dims_vec.size()) - 1]}); - return std::make_tuple(arg1_expand_size, arg2_expand_size); + return std::make_tuple(x_expand_size, y_expand_size); } template @@ -362,7 +354,7 @@ static void linalg_solve(const framework::ExecutionContext& context, std::vector x_broadcast_dims; std::vector y_broadcast_dims; std::tie(x_broadcast_dims, y_broadcast_dims) = - _broadcast_batch_dims(tmp_x, tmp_y); + get_broadcast_dims(tmp_x, tmp_y); expand_check(tmp_x, x_broadcast_dims); expand_check(tmp_y, y_broadcast_dims); @@ -566,7 +558,7 @@ class SolveGradKernel : public framework::OpKernel { std::vector x_broadcast_dims; std::vector y_broadcast_dims; std::tie(x_broadcast_dims, y_broadcast_dims) = - _broadcast_batch_dims(tmp_x, tmp_y); + get_broadcast_dims(tmp_x, tmp_y); // tmp_dx Tensor tmp_dx; diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 919ae418ab19b466f1a8950cde8d869dd902cd48..59d0df85bb14e9298823fef431d921d208b55e98 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -991,7 +991,6 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)