Unverified commit 83abec60, authored by xiongkun, committed by GitHub

[ Make FLAGS_einsum_opt as default ] Einsum memory optimization (#43397)

* change logic for optimize

* modify

* optimize the backward speed of EinsumOp

* add cache optimizer for einsum op

* EinsumOp: fix new dygraph mode error

* fix bug

* change Cache->InnerCache

* fix code

* fix

* add nan inf utils for einsum op

* add as_extra

* memory optimizer for einsum

* update code
Parent commit: 2106f668
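For context, a minimal sketch (not part of the commit; a hypothetical usage example with the equation and shapes borrowed from the updated unit test below) of the einsum call whose forward/backward memory footprint this change targets:

    import paddle
    import numpy as np

    # Two operands matching the test case: equation "mij,jk->ki", shapes (10, 10, 20) and (20, 6).
    x = paddle.to_tensor(np.random.random((10, 10, 20)), dtype='float64', stop_gradient=False)
    y = paddle.to_tensor(np.random.random((20, 6)), dtype='float64', stop_gradient=False)

    out = paddle.einsum('mij,jk->ki', x, y)  # forward also records the intermediate InnerCache/XShape outputs
    out.sum().backward()                     # backward can reuse the cached tensors instead of the original operands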
@@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
                             const TupleOfTensorAndVector& tensors) {
   CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
   CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
+  CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
 }

 }  // namespace egr
@@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
 using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
 using TupleOfSixTensors =
     std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
-using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
+using TupleOfTensorAndVector =
+    std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>;

 void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
......
@@ -41,6 +41,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
         .AsExtra()
         .AsIntermediate();
+    AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.")
+        .AsDuplicable()
+        .AsExtra()
+        .AsIntermediate();
     AddAttr<std::string>("equation",
                          "(string) A einsum equation. such as `ij,jk->ik`"
                          "There must have `->` and the number of operands in "
@@ -59,8 +63,8 @@ class EinsumGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     auto x_name = "Operands";
     auto x_grad_name = framework::GradVarName(x_name);
-    ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
-    ctx->ShareAllLoD(x_name, x_grad_name);
+    ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands"));
+    ctx->ShareAllLoD("Operands", x_grad_name);
   }

 protected:
@@ -79,8 +83,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
   void Apply(GradOpPtr<T> retv) const override {
     retv->SetType("einsum_grad");
-    retv->SetInput("Operands", this->Input("Operands"));
-    retv->SetInput("InnerCache", this->Output("InnerCache"));
+    if (this->HasOutput("InnerCache")) {
+      retv->SetInput("InnerCache", this->Output("InnerCache"));
+    }
+    if (this->HasOutput("XShape")) {
+      // add if for compatibility.
+      retv->SetInput("Operands", this->Output("XShape"));  // for memory save.
+    } else {
+      retv->SetInput("Operands", this->Input("Operands"));
+    }
     retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     retv->SetAttrMap(this->Attrs());
     retv->SetOutput(framework::GradVarName("Operands"),
......
@@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x,
 void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                      const std::string& equation,
                      MetaTensor* out,
-                     std::vector<MetaTensor*> inner_cache) {
+                     std::vector<MetaTensor*> inner_cache,
+                     std::vector<MetaTensor*> xshape) {
   // collect the following informations to prepare einsum.
   LabelMap labelshape(0);
   LabelMap labeltype(LabelType::Reduction);
@@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
   VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
   out->set_dims(make_ddim(output_dims));
   out->set_dtype(inputs[0]->dtype());
+  for (size_t i = 0; i < xshape.size(); ++i) {
+    if (xshape[i] != nullptr) {
+      xshape[i]->set_dims(inputs[i]->dims());
+      xshape[i]->set_dtype(inputs[i]->dtype());
+    }
+  }
 }

 void ExpandInferMeta(const MetaTensor& x,
......
@@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x,
 void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                      const std::string& equation,
                      MetaTensor* out,
-                     std::vector<MetaTensor*> inner_cache);
+                     std::vector<MetaTensor*> inner_cache,
+                     std::vector<MetaTensor*> xshape);

 void ExpandInferMeta(const MetaTensor& x,
                      const IntArray& shape,
......
@@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx,
                      const std::vector<const DenseTensor*>& inputs,
                      const std::string& equation,
                      DenseTensor* out,
-                     std::vector<DenseTensor*> cache);
+                     std::vector<DenseTensor*> inner_cache,
+                     std::vector<DenseTensor*> xshape);

 }  // namespace phi
@@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx,
       cache[0].ShareBufferWith(*(inner_cache[0]));
       cache[1].ShareBufferWith(*(inner_cache[1]));
     }
-
     EinsumKernelImpl<T, Context>(dev_ctx,
                                  all_labels,
                                  operands_for_A,
......
@@ -459,7 +459,7 @@ DenseTensor PerformContraction(
     }
     // reduction
     DenseTensor trans_t;
-    if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
+    if (use_cache && cache[operand_idx] != nullptr &&
         cache[operand_idx]->IsInitialized()) {
       trans_t.ShareBufferWith(*(cache[operand_idx]));
       VLOG(5) << "Cache Used!";
@@ -468,7 +468,7 @@
           dev_ctx, t, perm, all_labels, ellipsis, label2type);
       trans_t = PerformTranspose<T, Context>(
           dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
-      if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
+      if (cache[operand_idx] != nullptr)
         cache[operand_idx]->ShareBufferWith(trans_t);
     }
     auto mul_dims = GetShapeByType<int>(all_labels,
@@ -599,6 +599,11 @@ void EinsumKernelImpl(const Context& dev_ctx,
                              out);
     // Reshape Procedure
   } else if (inputs.size() == 1) {
+    if (cache[0] != nullptr) {  // For compatibility, may be cache is nullptr if
+                                // loading the program from v2.3.0
+      (*cache[0]) = *(inputs[0]);  // ShareBuffer for backward, because backward
+                                   // we can only see cached tensor.
+    }
     auto reduce_A = PerformReduction<T, Context>(dev_ctx,
                                                  *inputs[0],
                                                  label2perms[0],
@@ -627,7 +632,8 @@ void EinsumKernelRaw(const Context& dev_ctx,
                      const std::vector<const DenseTensor*>& inputs,
                      const std::string& equation,
                      DenseTensor* out,
-                     std::vector<DenseTensor*> cache) {
+                     std::vector<DenseTensor*> cache,
+                     std::vector<DenseTensor*> xshape) {
   std::vector<char> tmp;
   // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
   // may have nullptr and the cache.size() is not equal to inputs.size(). refer
......
@@ -18,7 +18,7 @@ namespace phi {
 KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
+      "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
 }

 KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
@@ -39,7 +39,9 @@ class TestEinsumBinary(OpTest):
             'Out':
             out,
             "InnerCache": [('cache_' + str(i), np.array([1.0]))
-                           for i in range(len(self.operands))]
+                           for i in range(len(self.operands))],
+            "XShape": [('xshape_' + str(i), np.array([1.0]))
+                       for i in range(len(self.operands))],
         }

     def init_input(self):
@@ -48,14 +50,13 @@ class TestEinsumBinary(OpTest):
             self.inputs.append(np.random.random(s).astype(t))

     def set_mandatory(self):
-        self.disable = False
         self.shapes = [(10, 10, 20), (20, 6)]
         self.types = [np.float64, np.float64]
         self.equation = "mij,jk->ki"

     def test_check_output(self):
         if not self.disable:
-            self.check_output(no_check_set=["InnerCache"])
+            self.check_output(no_check_set=["InnerCache", "XShape"])

     def test_grad(self):
         if not self.disable:
......
@@ -807,9 +807,9 @@ def gen_einsum_op(equation, *operands):
     if _in_legacy_dygraph():
         # dygraph
-        return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
+        return _C_ops.einsum(operands, len(operands), len(operands), 'equation',
+                             equation)[0]

-    # static graph
     for inp in operands:
         check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
     check_type(equation, 'equation', str, 'einsum')
@@ -821,11 +821,16 @@ def gen_einsum_op(equation, *operands):
         helper.create_variable_for_type_inference(dtype=operands[0].dtype)
         for i in range(len(operands))
     ]
+    xshape = [
+        helper.create_variable_for_type_inference(dtype=operands[0].dtype)
+        for i in range(len(operands))
+    ]
     helper.append_op(type='einsum',
                      inputs={'Operands': operands},
                      outputs={
                          'Out': out,
-                         "InnerCache": caches
+                         "InnerCache": caches,
+                         "XShape": xshape
                      },
                      attrs=attrs)
     return out
......
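As a rough illustration (a hedged sketch, assuming a Paddle build with this patch; variable names are illustrative), the static-graph path now wires three outputs into the appended op while the public API still returns only `Out`; `InnerCache` and `XShape` remain internal outputs consumed by `einsum_grad`:

    import paddle

    paddle.enable_static()
    x = paddle.static.data(name='x', shape=[10, 10, 20], dtype='float64')
    y = paddle.static.data(name='y', shape=[20, 6], dtype='float64')
    out = paddle.einsum('mij,jk->ki', x, y)  # appends one einsum op with Out, InnerCache and XShape outputs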
@@ -603,7 +603,7 @@
 - api : einsum
   args : (Tensor[] x, str equation)
-  output : Tensor, Tensor[]{x.size()}
+  output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
   infer_meta :
     func : EinsumInferMeta
     param : [x, equation]
......
+#- backward_api : einsum_grad
+  #forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
+  #args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
+  #output : Tensor[](x_grad){x.size()}
+  #infer_meta :
+    #func : UnchangedMultiInferMeta
+    #param : [x]
+  #kernel :
+    #func : einsum_grad
+
 - backward_api : abs_double_grad
   forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
   args : (Tensor x, Tensor grad_x_grad)
@@ -616,12 +627,12 @@
   skip_transform : out_w, out_w_grad

 - backward_api : einsum_grad
-  forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
-  args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
+  forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
+  args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
   output : Tensor[](x_grad){x.size()}
   infer_meta :
     func : UnchangedMultiInferMeta
-    param : [x]
+    param : [x_shape]
   kernel :
     func : einsum_grad
......