Unverified · Commit 668e235c authored by xiongkun, committed by GitHub

change einsum_v2 as default and add new flags: FLAG_einsum_opt=1|0 (#43010)

Parent 905d857c
@@ -848,3 +848,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
  * Example:
  */
 PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
+
+/**
+ * Performance related FLAG
+ * Name: einsum_opt
+ * Since Version: 2.3.0
+ * Value Range: bool, default=false
+ * Example:
+ * Note: If True, EinsumOp will be optimized by inner cache reuse, which
+ * uses more GPU memory.
+ */
+PADDLE_DEFINE_EXPORTED_bool(
+    einsum_opt, false,
+    "EinsumOp backward will be sped up at the expense of more GPU memory.");
@@ -20,6 +20,8 @@
 #include "paddle/phi/kernels/transpose_kernel.h"
 #include "paddle/utils/string/string_helper.h"
+
+DECLARE_bool(einsum_opt);
 namespace phi {
 // check the validation of the Einsum equation.
@@ -456,7 +458,7 @@ DenseTensor PerformContraction(
       }
       // reduction
       DenseTensor trans_t;
-      if (use_cache && cache[operand_idx] != nullptr &&
+      if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
           cache[operand_idx]->IsInitialized()) {
         trans_t.ShareBufferWith(*(cache[operand_idx]));
         VLOG(5) << "Cache Used!";
@@ -465,7 +467,7 @@
           dev_ctx, t, perm, all_labels, ellipsis, label2type);
       trans_t = PerformTranspose<T, Context>(
           dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
-      if (cache[operand_idx] != nullptr)
+      if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
         cache[operand_idx]->ShareBufferWith(trans_t);
     }
     auto mul_dims = GetShapeByType<int>(all_labels,
......
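Both guarded sites implement a simple memoization: the first contraction stores the transposed operand in cache[operand_idx], and later passes share that buffer via ShareBufferWith instead of recomputing the transpose. A rough Python sketch of the same gating, with hypothetical names (expensive_transpose stands in for PerformTranspose), just to illustrate the speed-for-memory trade the flag controls:

```python
# Illustrative sketch only; the names here are hypothetical, not Paddle APIs.
cache = {}

def expensive_transpose(tensor):
    # Stand-in for PerformTranspose: any costly but recomputable transform.
    return [list(row) for row in zip(*tensor)]

def transposed_operand(idx, tensor, einsum_opt_enabled):
    if einsum_opt_enabled and idx in cache:
        return cache[idx]            # cache hit: skip the recomputation
    result = expensive_transpose(tensor)
    if einsum_opt_enabled:
        cache[idx] = result          # cache fill: extra memory stays alive
    return result
```

With the flag off, every call recomputes the transpose and no extra buffer is retained, which is exactly the default-false behavior the flag definition documents.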
@@ -18,6 +18,9 @@ import unittest
 import paddle
 from paddle.fluid import core
+import os
+
+os.environ['FLAGS_new_einsum'] = "0"
 class TestErrors(unittest.TestCase):
     def setUp(self):
......
@@ -983,7 +983,7 @@ def einsum(equation, *operands):
             # [0.51476848, 0.23367381, 0.39229113]]])
     """
     import os
-    if int(os.environ.get('FLAGS_new_einsum', "0")):
+    if int(os.environ.get('FLAGS_new_einsum', "1")):
         return einsum_v2(equation, *operands)
     nop = len(operands)
......
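With the default flipped from "0" to "1", paddle.einsum now dispatches to einsum_v2 unless the caller opts out, which is what the updated test pins down. A small sketch of forcing the legacy path; since the dispatch reads the environment on every call, setting the variable before calling einsum is enough:

```python
# Sketch: opting back into the legacy einsum after this change.
import os
os.environ['FLAGS_new_einsum'] = '0'  # the new default "1" selects einsum_v2

import paddle
out = paddle.einsum('ij->ji', paddle.rand([2, 3]))  # served by the legacy path
```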