未验证 提交 a787b278 编写于 作者: W Weilong Wu 提交者: GitHub

Optimized the solve op code:renamed var and removed template func (#36981) (#37011)

    Renamed the variable and function
    Removed the original template function
    Removed the tests_properties in CMakeLists.txt
上级 76cab751
...@@ -157,36 +157,32 @@ static void to_unsqueeze(const framework::ExecutionContext& context, ...@@ -157,36 +157,32 @@ static void to_unsqueeze(const framework::ExecutionContext& context,
out->Resize(out_dims); out->Resize(out_dims);
} }
template <typename Container> // Prepared for the broadcast operation
Container infer_size_impl(std::vector<int64_t> a, std::vector<int64_t> b) { static std::vector<int64_t> get_broadcast_batch_portion(
size_t dimsA = a.size(); std::vector<int64_t> x, std::vector<int64_t> y) {
size_t dimsB = b.size(); size_t size_x = x.size();
size_t ndim = dimsA > dimsB ? dimsA : dimsB; size_t size_y = y.size();
Container expandedSizes(ndim); size_t size = std::max(size_x, size_y);
std::vector<int64_t> batchPortion(size);
for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
ptrdiff_t offset = ndim - 1 - i; ptrdiff_t i = (ptrdiff_t)size - 1;
ptrdiff_t dimA = dimsA - 1 - offset; for (; i >= 0; --i) {
ptrdiff_t dimB = dimsB - 1 - offset; ptrdiff_t offset = size - i - 1;
int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; ptrdiff_t dim_x = size_x - offset - 1;
int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; ptrdiff_t dim_y = size_y - offset - 1;
int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(sizeA == sizeB || sizeA == 1 || sizeB == 1), true, (x_size == y_size || x_size == 1 || y_size == 1), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The size of tensor a (%d) must match the size of tensor b " "The size of tensor x (%d) must match the size of tensor y "
"(%d) at non-singleton dimension %d.", "(%d) at non-singleton dimension %d.",
sizeA, sizeB, i)); x_size, y_size, i));
expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; batchPortion[i] = x_size != 1 ? x_size : y_size;
} }
return expandedSizes; return batchPortion;
}
// infer size for broadcast operation
static std::vector<int64_t> infer_size(std::vector<int64_t> a,
std::vector<int64_t> b) {
return infer_size_impl<std::vector<int64_t>>(a, b);
} }
// necessary check before expand operation // necessary check before expand operation
...@@ -219,38 +215,34 @@ static void expand_check(const Tensor& arg1, ...@@ -219,38 +215,34 @@ static void expand_check(const Tensor& arg1,
shape_size, MAX_RANK_SUPPORTED)); shape_size, MAX_RANK_SUPPORTED));
} }
// broadcast the batch dimensions of arg1 and arg2. // broadcast the batch dimensions of tensor x and tensor y.
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { get_broadcast_dims(const Tensor& x, const Tensor& y) {
std::vector<int64_t> arg1_dims_vec = std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x.dims());
paddle::framework::vectorize(arg1.dims()); std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y.dims());
std::vector<int64_t> arg2_dims_vec =
paddle::framework::vectorize(arg2.dims());
std::vector<int64_t>::const_iterator f1 = arg1_dims_vec.begin(); std::vector<int64_t>::const_iterator f1 = x_dims_vec.begin();
std::vector<int64_t>::const_iterator l1 = arg1_dims_vec.end() - 2; std::vector<int64_t>::const_iterator l1 = x_dims_vec.end() - 2;
std::vector<int64_t> arg1_dims_vec_cut(f1, l1); std::vector<int64_t> x_dims_vec_cut(f1, l1);
std::vector<int64_t>::const_iterator f2 = arg2_dims_vec.begin(); std::vector<int64_t>::const_iterator f2 = y_dims_vec.begin();
std::vector<int64_t>::const_iterator l2 = arg2_dims_vec.end() - 2; std::vector<int64_t>::const_iterator l2 = y_dims_vec.end() - 2;
std::vector<int64_t> arg2_dims_vec_cut(f2, l2); std::vector<int64_t> y_dims_vec_cut(f2, l2);
std::vector<int64_t> expand_batch_portion = std::vector<int64_t> expand_batch_portion =
infer_size(arg1_dims_vec_cut, arg2_dims_vec_cut); get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> arg1_expand_size({expand_batch_portion}); std::vector<int64_t> x_expand_size({expand_batch_portion});
arg1_expand_size.insert( x_expand_size.insert(x_expand_size.end(),
arg1_expand_size.end(), {x_dims_vec[static_cast<int>(x_dims_vec.size()) - 2],
{arg1_dims_vec[static_cast<int>(arg1_dims_vec.size()) - 2], x_dims_vec[static_cast<int>(x_dims_vec.size()) - 1]});
arg1_dims_vec[static_cast<int>(arg1_dims_vec.size()) - 1]});
std::vector<int64_t> arg2_expand_size({expand_batch_portion}); std::vector<int64_t> y_expand_size({expand_batch_portion});
arg2_expand_size.insert( y_expand_size.insert(y_expand_size.end(),
arg2_expand_size.end(), {y_dims_vec[static_cast<int>(y_dims_vec.size()) - 2],
{arg2_dims_vec[static_cast<int>(arg2_dims_vec.size()) - 2], y_dims_vec[static_cast<int>(y_dims_vec.size()) - 1]});
arg2_dims_vec[static_cast<int>(arg2_dims_vec.size()) - 1]});
return std::make_tuple(arg1_expand_size, arg2_expand_size); return std::make_tuple(x_expand_size, y_expand_size);
} }
template <int Rank, typename T, typename DeviceContext> template <int Rank, typename T, typename DeviceContext>
...@@ -362,7 +354,7 @@ static void linalg_solve(const framework::ExecutionContext& context, ...@@ -362,7 +354,7 @@ static void linalg_solve(const framework::ExecutionContext& context,
std::vector<int64_t> x_broadcast_dims; std::vector<int64_t> x_broadcast_dims;
std::vector<int64_t> y_broadcast_dims; std::vector<int64_t> y_broadcast_dims;
std::tie(x_broadcast_dims, y_broadcast_dims) = std::tie(x_broadcast_dims, y_broadcast_dims) =
_broadcast_batch_dims(tmp_x, tmp_y); get_broadcast_dims(tmp_x, tmp_y);
expand_check(tmp_x, x_broadcast_dims); expand_check(tmp_x, x_broadcast_dims);
expand_check(tmp_y, y_broadcast_dims); expand_check(tmp_y, y_broadcast_dims);
...@@ -566,7 +558,7 @@ class SolveGradKernel : public framework::OpKernel<T> { ...@@ -566,7 +558,7 @@ class SolveGradKernel : public framework::OpKernel<T> {
std::vector<int64_t> x_broadcast_dims; std::vector<int64_t> x_broadcast_dims;
std::vector<int64_t> y_broadcast_dims; std::vector<int64_t> y_broadcast_dims;
std::tie(x_broadcast_dims, y_broadcast_dims) = std::tie(x_broadcast_dims, y_broadcast_dims) =
_broadcast_batch_dims(tmp_x, tmp_y); get_broadcast_dims(tmp_x, tmp_y);
// tmp_dx // tmp_dx
Tensor tmp_dx; Tensor tmp_dx;
......
...@@ -991,7 +991,6 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) ...@@ -991,7 +991,6 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120)
set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册