diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 600a4cbcc3ed9969e021beaa2bda1ff23c89bb3b..2fcc573456d42fc1c32c479b5dd23594ef0c5dd8 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -848,3 +848,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); * Example: */ PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune."); + +/** + * Preformance related FLAG + * Name: einsum_opt + * Since Version: 2.3.0 + * Value Range: bool, default=false + * Example: + * Note: If True, EinsumOp will be optimimzed by innercache reuse, which + * uses more gpu memory. + */ +PADDLE_DEFINE_EXPORTED_bool( + einsum_opt, false, + "EinsumOp backward will be speedup at the expense of more gpu memory."); diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index 5e4480426c0ccee3f5848b6b1462d3644fca67bf..bfbd6e0c51cfc7b6592b78335810c75e449fb042 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -20,6 +20,8 @@ #include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/utils/string/string_helper.h" +DECLARE_bool(einsum_opt); + namespace phi { // check the validation of the Einsum equation. @@ -456,7 +458,7 @@ DenseTensor PerformContraction( } // reduction DenseTensor trans_t; - if (use_cache && cache[operand_idx] != nullptr && + if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr && cache[operand_idx]->IsInitialized()) { trans_t.ShareBufferWith(*(cache[operand_idx])); VLOG(5) << "Cache Used!"; @@ -465,7 +467,7 @@ DenseTensor PerformContraction( dev_ctx, t, perm, all_labels, ellipsis, label2type); trans_t = PerformTranspose( dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); - if (cache[operand_idx] != nullptr) + if (FLAGS_einsum_opt && cache[operand_idx] != nullptr) cache[operand_idx]->ShareBufferWith(trans_t); } auto mul_dims = GetShapeByType(all_labels, diff --git a/python/paddle/fluid/tests/unittests/test_einsum.py b/python/paddle/fluid/tests/unittests/test_einsum.py index 43b5ce96a390150db7e29588e4107271b240b23f..26aaf0f44f1d2ad6d1239bb6b827feb94b8864d3 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum.py +++ b/python/paddle/fluid/tests/unittests/test_einsum.py @@ -18,6 +18,9 @@ import unittest import paddle from paddle.fluid import core +import os +os.environ['FLAGS_new_einsum'] = "0" + class TestErrors(unittest.TestCase): def setUp(self): diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 4cdbebb0552298fea02773a3833a7fe7118deb8a..49cc426a00fd998c2ed24f94fb0002e4466065af 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -983,7 +983,7 @@ def einsum(equation, *operands): # [0.51476848, 0.23367381, 0.39229113]]]) """ import os - if int(os.environ.get('FLAGS_new_einsum', "0")): + if int(os.environ.get('FLAGS_new_einsum', "1")): return einsum_v2(equation, *operands) nop = len(operands)