未验证 提交 668e235c 编写于 作者: X xiongkun 提交者: GitHub

change einsum_v2 as default and add new flags: FLAG_einsum_opt=1|0 (#43010)

上级 905d857c
......@@ -848,3 +848,16 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
* Example:
*/
PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
/**
 * Performance related FLAG
 * Name: einsum_opt
 * Since Version: 2.3.0
 * Value Range: bool, default=false
 * Example:
 * Note: If True, EinsumOp will be optimized by reusing its inner cache
 * (intermediate transposed/reduced tensors are kept between forward and
 * backward), which speeds up the backward pass at the expense of more GPU
 * memory.
 */
PADDLE_DEFINE_EXPORTED_bool(
    einsum_opt, false,
    "EinsumOp backward will be speedup at the expense of more gpu memory.");
......@@ -20,6 +20,8 @@
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/utils/string/string_helper.h"
DECLARE_bool(einsum_opt);
namespace phi {
// check the validation of the Einsum equation.
......@@ -456,7 +458,7 @@ DenseTensor PerformContraction(
}
// reduction
DenseTensor trans_t;
if (use_cache && cache[operand_idx] != nullptr &&
if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
......@@ -465,7 +467,7 @@ DenseTensor PerformContraction(
dev_ctx, t, perm, all_labels, ellipsis, label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (cache[operand_idx] != nullptr)
if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
......
......@@ -18,6 +18,9 @@ import unittest
import paddle
from paddle.fluid import core
import os
os.environ['FLAGS_new_einsum'] = "0"
class TestErrors(unittest.TestCase):
def setUp(self):
......
......@@ -983,7 +983,7 @@ def einsum(equation, *operands):
# [0.51476848, 0.23367381, 0.39229113]]])
"""
import os
if int(os.environ.get('FLAGS_new_einsum', "0")):
if int(os.environ.get('FLAGS_new_einsum', "1")):
return einsum_v2(equation, *operands)
nop = len(operands)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册