未验证 提交 a787b278 编写于 作者: W Weilong Wu 提交者: GitHub

Optimized the solve op code: renamed variables and removed a template function (#36981) (#37011)

    Renamed the variable and function
    Removed the original template function
    Removed the set_tests_properties entry for test_solve_op from CMakeLists.txt
上级 76cab751
......@@ -157,36 +157,32 @@ static void to_unsqueeze(const framework::ExecutionContext& context,
out->Resize(out_dims);
}
template <typename Container>
Container infer_size_impl(std::vector<int64_t> a, std::vector<int64_t> b) {
size_t dimsA = a.size();
size_t dimsB = b.size();
size_t ndim = dimsA > dimsB ? dimsA : dimsB;
Container expandedSizes(ndim);
for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
ptrdiff_t offset = ndim - 1 - i;
ptrdiff_t dimA = dimsA - 1 - offset;
ptrdiff_t dimB = dimsB - 1 - offset;
int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;
// Prepared for the broadcast operation
static std::vector<int64_t> get_broadcast_batch_portion(
std::vector<int64_t> x, std::vector<int64_t> y) {
size_t size_x = x.size();
size_t size_y = y.size();
size_t size = std::max(size_x, size_y);
std::vector<int64_t> batchPortion(size);
ptrdiff_t i = (ptrdiff_t)size - 1;
for (; i >= 0; --i) {
ptrdiff_t offset = size - i - 1;
ptrdiff_t dim_x = size_x - offset - 1;
ptrdiff_t dim_y = size_y - offset - 1;
int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
PADDLE_ENFORCE_EQ(
(sizeA == sizeB || sizeA == 1 || sizeB == 1), true,
(x_size == y_size || x_size == 1 || y_size == 1), true,
platform::errors::PreconditionNotMet(
"The size of tensor a (%d) must match the size of tensor b "
"The size of tensor x (%d) must match the size of tensor y "
"(%d) at non-singleton dimension %d.",
sizeA, sizeB, i));
x_size, y_size, i));
expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
batchPortion[i] = x_size != 1 ? x_size : y_size;
}
return expandedSizes;
}
// infer size for broadcast operation
static std::vector<int64_t> infer_size(std::vector<int64_t> a,
std::vector<int64_t> b) {
return infer_size_impl<std::vector<int64_t>>(a, b);
return batchPortion;
}
// necessary check before expand operation
......@@ -219,38 +215,34 @@ static void expand_check(const Tensor& arg1,
shape_size, MAX_RANK_SUPPORTED));
}
// broadcast the batch dimensions of arg1 and arg2.
// broadcast the batch dimensions of tensor x and tensor y.
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
std::vector<int64_t> arg1_dims_vec =
paddle::framework::vectorize(arg1.dims());
std::vector<int64_t> arg2_dims_vec =
paddle::framework::vectorize(arg2.dims());
get_broadcast_dims(const Tensor& x, const Tensor& y) {
std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x.dims());
std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y.dims());
std::vector<int64_t>::const_iterator f1 = arg1_dims_vec.begin();
std::vector<int64_t>::const_iterator l1 = arg1_dims_vec.end() - 2;
std::vector<int64_t> arg1_dims_vec_cut(f1, l1);
std::vector<int64_t>::const_iterator f1 = x_dims_vec.begin();
std::vector<int64_t>::const_iterator l1 = x_dims_vec.end() - 2;
std::vector<int64_t> x_dims_vec_cut(f1, l1);
std::vector<int64_t>::const_iterator f2 = arg2_dims_vec.begin();
std::vector<int64_t>::const_iterator l2 = arg2_dims_vec.end() - 2;
std::vector<int64_t> arg2_dims_vec_cut(f2, l2);
std::vector<int64_t>::const_iterator f2 = y_dims_vec.begin();
std::vector<int64_t>::const_iterator l2 = y_dims_vec.end() - 2;
std::vector<int64_t> y_dims_vec_cut(f2, l2);
std::vector<int64_t> expand_batch_portion =
infer_size(arg1_dims_vec_cut, arg2_dims_vec_cut);
get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> arg1_expand_size({expand_batch_portion});
arg1_expand_size.insert(
arg1_expand_size.end(),
{arg1_dims_vec[static_cast<int>(arg1_dims_vec.size()) - 2],
arg1_dims_vec[static_cast<int>(arg1_dims_vec.size()) - 1]});
std::vector<int64_t> x_expand_size({expand_batch_portion});
x_expand_size.insert(x_expand_size.end(),
{x_dims_vec[static_cast<int>(x_dims_vec.size()) - 2],
x_dims_vec[static_cast<int>(x_dims_vec.size()) - 1]});
std::vector<int64_t> arg2_expand_size({expand_batch_portion});
arg2_expand_size.insert(
arg2_expand_size.end(),
{arg2_dims_vec[static_cast<int>(arg2_dims_vec.size()) - 2],
arg2_dims_vec[static_cast<int>(arg2_dims_vec.size()) - 1]});
std::vector<int64_t> y_expand_size({expand_batch_portion});
y_expand_size.insert(y_expand_size.end(),
{y_dims_vec[static_cast<int>(y_dims_vec.size()) - 2],
y_dims_vec[static_cast<int>(y_dims_vec.size()) - 1]});
return std::make_tuple(arg1_expand_size, arg2_expand_size);
return std::make_tuple(x_expand_size, y_expand_size);
}
template <int Rank, typename T, typename DeviceContext>
......@@ -362,7 +354,7 @@ static void linalg_solve(const framework::ExecutionContext& context,
std::vector<int64_t> x_broadcast_dims;
std::vector<int64_t> y_broadcast_dims;
std::tie(x_broadcast_dims, y_broadcast_dims) =
_broadcast_batch_dims(tmp_x, tmp_y);
get_broadcast_dims(tmp_x, tmp_y);
expand_check(tmp_x, x_broadcast_dims);
expand_check(tmp_y, y_broadcast_dims);
......@@ -566,7 +558,7 @@ class SolveGradKernel : public framework::OpKernel<T> {
std::vector<int64_t> x_broadcast_dims;
std::vector<int64_t> y_broadcast_dims;
std::tie(x_broadcast_dims, y_broadcast_dims) =
_broadcast_batch_dims(tmp_x, tmp_y);
get_broadcast_dims(tmp_x, tmp_y);
// tmp_dx
Tensor tmp_dx;
......
......@@ -991,7 +991,6 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120)
set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册