Unverified commit bea0c9f5, authored by Weilong Wu, committed by GitHub

Optimized the solve op code: renamed vars and removed template func (#36981)

Parent 3705b12c
@@ -157,70 +157,62 @@ static void to_unsqueeze(const framework::ExecutionContext& context,
   out->Resize(out_dims);
 }
 
-template <typename Container>
-Container infer_size_impl(std::vector<int64_t> a, std::vector<int64_t> b) {
-  size_t dimsA = a.size();
-  size_t dimsB = b.size();
-  size_t ndim = dimsA > dimsB ? dimsA : dimsB;
-  Container expandedSizes(ndim);
+// Prepared for the broadcast operation
+static std::vector<int64_t> get_broadcast_batch_portion(
+    std::vector<int64_t> x, std::vector<int64_t> y) {
+  size_t size_x = x.size();
+  size_t size_y = y.size();
+  size_t size = std::max(size_x, size_y);
+  std::vector<int64_t> batchPortion(size);
 
-  for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
-    ptrdiff_t offset = ndim - 1 - i;
-    ptrdiff_t dimA = dimsA - 1 - offset;
-    ptrdiff_t dimB = dimsB - 1 - offset;
-    int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
-    int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;
+  ptrdiff_t i = (ptrdiff_t)size - 1;
+  for (; i >= 0; --i) {
+    ptrdiff_t offset = size - i - 1;
+    ptrdiff_t dim_x = size_x - offset - 1;
+    ptrdiff_t dim_y = size_y - offset - 1;
+    int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
+    int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
 
     PADDLE_ENFORCE_EQ(
-        (sizeA == sizeB || sizeA == 1 || sizeB == 1), true,
+        (x_size == y_size || x_size == 1 || y_size == 1), true,
         platform::errors::PreconditionNotMet(
-            "The size of tensor a (%d) must match the size of tensor b "
+            "The size of tensor x (%d) must match the size of tensor y "
             "(%d) at non-singleton dimension %d.",
-            sizeA, sizeB, i));
+            x_size, y_size, i));
 
-    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
+    batchPortion[i] = x_size != 1 ? x_size : y_size;
   }
-  return expandedSizes;
+  return batchPortion;
 }
 
-// infer size for broadcast operation
-static std::vector<int64_t> infer_size(std::vector<int64_t> a,
-                                       std::vector<int64_t> b) {
-  return infer_size_impl<std::vector<int64_t>>(a, b);
-}
-
-// broadcast the batch dimensions of arg1 and arg2.
+// broadcast the batch dimensions of tensor x and tensor y.
 static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
-_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
-  std::vector<int64_t> arg1_dims_vec =
-      paddle::framework::vectorize(arg1.dims());
-  std::vector<int64_t> arg2_dims_vec =
-      paddle::framework::vectorize(arg2.dims());
+get_broadcast_dims(const Tensor& x, const Tensor& y) {
+  std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x.dims());
+  std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y.dims());
 
-  std::vector<int64_t>::const_iterator f1 = arg1_dims_vec.begin();
-  std::vector<int64_t>::const_iterator l1 = arg1_dims_vec.end() - 2;
-  std::vector<int64_t> arg1_dims_vec_cut(f1, l1);
+  std::vector<int64_t>::const_iterator f1 = x_dims_vec.begin();
+  std::vector<int64_t>::const_iterator l1 = x_dims_vec.end() - 2;
+  std::vector<int64_t> x_dims_vec_cut(f1, l1);
 
-  std::vector<int64_t>::const_iterator f2 = arg2_dims_vec.begin();
-  std::vector<int64_t>::const_iterator l2 = arg2_dims_vec.end() - 2;
-  std::vector<int64_t> arg2_dims_vec_cut(f2, l2);
+  std::vector<int64_t>::const_iterator f2 = y_dims_vec.begin();
+  std::vector<int64_t>::const_iterator l2 = y_dims_vec.end() - 2;
+  std::vector<int64_t> y_dims_vec_cut(f2, l2);
 
   std::vector<int64_t> expand_batch_portion =
-      infer_size(arg1_dims_vec_cut, arg2_dims_vec_cut);
+      get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
 
-  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
-  arg1_expand_size.insert(
-      arg1_expand_size.end(),
-      {arg1_dims_vec[static_cast<int>(arg1_dims_vec.size()) - 2],
-       arg1_dims_vec[static_cast<int>(arg1_dims_vec.size()) - 1]});
+  std::vector<int64_t> x_expand_size({expand_batch_portion});
+  x_expand_size.insert(x_expand_size.end(),
+                       {x_dims_vec[static_cast<int>(x_dims_vec.size()) - 2],
+                        x_dims_vec[static_cast<int>(x_dims_vec.size()) - 1]});
 
-  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
-  arg2_expand_size.insert(
-      arg2_expand_size.end(),
-      {arg2_dims_vec[static_cast<int>(arg2_dims_vec.size()) - 2],
-       arg2_dims_vec[static_cast<int>(arg2_dims_vec.size()) - 1]});
+  std::vector<int64_t> y_expand_size({expand_batch_portion});
+  y_expand_size.insert(y_expand_size.end(),
+                       {y_dims_vec[static_cast<int>(y_dims_vec.size()) - 2],
+                        y_dims_vec[static_cast<int>(y_dims_vec.size()) - 1]});
 
-  return std::make_tuple(arg1_expand_size, arg2_expand_size);
+  return std::make_tuple(x_expand_size, y_expand_size);
 }
 
 template <int Rank, typename T, typename DeviceContext>
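The two helpers introduced above implement the standard right-aligned broadcast rule for batch dimensions. Below is a minimal standalone sketch of that rule, for illustration only: it assumes nothing from Paddle, renames the helper to broadcast_batch_portion to avoid clashing with the code under review, and replaces PADDLE_ENFORCE_EQ with a plain assert.

// Standalone sketch of the right-aligned broadcast rule used by
// get_broadcast_batch_portion (illustrative; not part of this commit).
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> broadcast_batch_portion(
    const std::vector<int64_t>& x, const std::vector<int64_t>& y) {
  size_t size_x = x.size();
  size_t size_y = y.size();
  size_t size = std::max(size_x, size_y);
  std::vector<int64_t> out(size);
  // Walk dimensions right to left; a missing dimension counts as 1,
  // and a size-1 dimension broadcasts against the other operand.
  for (ptrdiff_t i = (ptrdiff_t)size - 1; i >= 0; --i) {
    ptrdiff_t offset = size - i - 1;
    ptrdiff_t dim_x = size_x - offset - 1;
    ptrdiff_t dim_y = size_y - offset - 1;
    int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1;
    int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1;
    assert(x_size == y_size || x_size == 1 || y_size == 1);
    out[i] = x_size != 1 ? x_size : y_size;
  }
  return out;
}

int main() {
  // Batch dims {2, 1, 3} and {5, 3} broadcast to {2, 5, 3}.
  for (int64_t d : broadcast_batch_portion({2, 1, 3}, {5, 3})) {
    std::cout << d << ' ';  // prints: 2 5 3
  }
  std::cout << '\n';
}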
@@ -364,7 +356,7 @@ static void linalg_solve(const framework::ExecutionContext& context,
   std::vector<int64_t> x_broadcast_dims;
   std::vector<int64_t> y_broadcast_dims;
   std::tie(x_broadcast_dims, y_broadcast_dims) =
-      _broadcast_batch_dims(tmp_x, tmp_y);
+      get_broadcast_dims(tmp_x, tmp_y);
 
   Tensor tmp_x_bc;
   TensorExpand<T, DeviceContext>(dev_ctx, tmp_x, &tmp_x_bc, x_broadcast_dims);
@@ -510,7 +502,7 @@ class SolveGradKernel : public framework::OpKernel<T> {
   std::vector<int64_t> x_broadcast_dims;
   std::vector<int64_t> y_broadcast_dims;
   std::tie(x_broadcast_dims, y_broadcast_dims) =
-      _broadcast_batch_dims(tmp_x, tmp_y);
+      get_broadcast_dims(tmp_x, tmp_y);
 
   // tmp_dx
   Tensor tmp_dx;
...
@@ -63,7 +63,7 @@ class TriangularSolveOp : public framework::OperatorWithKernel {
                                            y_dims_vec.end() - 2);
 
   std::vector<int64_t> expand_batch_portion =
-      infer_size(x_dims_vec_cut, y_dims_vec_cut);
+      get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
 
   std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
   y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2],
...
@@ -36,7 +36,7 @@ static void triangular_solve(const DeviceContext& context, const Tensor& x,
   // Tensor broadcast use eigen
   std::vector<int64_t> x_bst_dims_vec;
   std::vector<int64_t> y_bst_dims_vec;
-  std::tie(x_bst_dims_vec, y_bst_dims_vec) = _broadcast_batch_dims(x, y);
+  std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y);
 
   Tensor x_bst(x.type());
   TensorExpand<T, DeviceContext>(context, x, &x_bst, x_bst_dims_vec);
@@ -141,7 +141,7 @@ class TriangularSolveGradKernel : public framework::OpKernel<T> {
   std::vector<int64_t> x_bst_dims_vec;
   std::vector<int64_t> y_bst_dims_vec;
-  std::tie(x_bst_dims_vec, y_bst_dims_vec) = _broadcast_batch_dims(*x, *y);
+  std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(*x, *y);
 
   Tensor dy_bst(y->type());
   if (dy) {
...
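To make these call sites concrete, here is a hypothetical shape trace through get_broadcast_dims (the shapes are assumptions chosen only for illustration): the trailing two matrix dimensions of each tensor are kept as-is, and only the leading batch dimensions are broadcast.

// Assumed example shapes, for illustration only:
//   x.dims() = [2, 1, 4, 4]  -> batch part [2, 1], matrix part [4, 4]
//   y.dims() = [3, 4, 1]     -> batch part [3],    matrix part [4, 1]
// get_broadcast_batch_portion([2, 1], [3]) yields [2, 3], so:
//   x_bst_dims_vec = [2, 3, 4, 4]
//   y_bst_dims_vec = [2, 3, 4, 1]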
@@ -1019,7 +1019,6 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120)
 set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
 set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
 set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120)
 if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
   set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
   set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
...