From 638b69dcecc6320d3b81ecc734ae218f456ad505 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 20 Jun 2022 12:15:49 +0800 Subject: [PATCH] [Cherry pick] Einsum memory optimization PR #43397 (#43554) * cherry pick from #43397 * fix code --- paddle/fluid/eager/nan_inf_utils.cc | 1 + paddle/fluid/eager/nan_inf_utils.h | 3 ++- paddle/fluid/operators/einsum_op.cc | 19 ++++++++++++++---- paddle/phi/infermeta/unary.cc | 9 ++++++++- paddle/phi/infermeta/unary.h | 3 ++- paddle/phi/kernels/einsum_kernel.h | 3 ++- paddle/phi/kernels/impl/einsum_grad_impl.h | 1 - paddle/phi/kernels/impl/einsum_impl.h | 12 ++++++++--- paddle/phi/ops/compat/einsum_sig.cc | 2 +- .../fluid/tests/unittests/test_einsum_op.py | 7 ++++--- python/paddle/tensor/einsum.py | 12 ++++++++--- python/paddle/utils/code_gen/api.yaml | 2 +- python/paddle/utils/code_gen/backward.yaml | 20 +++++++++++++++---- 13 files changed, 70 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/eager/nan_inf_utils.cc b/paddle/fluid/eager/nan_inf_utils.cc index d1c5983a370..0ed1a198c91 100644 --- a/paddle/fluid/eager/nan_inf_utils.cc +++ b/paddle/fluid/eager/nan_inf_utils.cc @@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const TupleOfTensorAndVector& tensors) { CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); CheckTensorHasNanOrInf(api_name, std::get<1>(tensors)); + CheckTensorHasNanOrInf(api_name, std::get<2>(tensors)); } } // namespace egr diff --git a/paddle/fluid/eager/nan_inf_utils.h b/paddle/fluid/eager/nan_inf_utils.h index a411504fa49..815e3bd6cd1 100644 --- a/paddle/fluid/eager/nan_inf_utils.h +++ b/paddle/fluid/eager/nan_inf_utils.h @@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple; using TupleOfFiveTensors = std::tuple; using TupleOfSixTensors = std::tuple; -using TupleOfTensorAndVector = std::tuple>; +using TupleOfTensorAndVector = + std::tuple, std::vector>; void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor); diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 6da0045443c..c0566aaeb2f 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -40,6 +40,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { .AsExtra() .AsIntermediate(); + AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.") + .AsDuplicable() + .AsExtra() + .AsIntermediate(); AddAttr("equation", "(string) A einsum equation. such as `ij,jk->ik`" "There must have `->` and the number of operands in " @@ -58,8 +62,8 @@ class EinsumGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { auto x_name = "Operands"; auto x_grad_name = framework::GradVarName(x_name); - ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name)); - ctx->ShareAllLoD(x_name, x_grad_name); + ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands")); + ctx->ShareAllLoD("Operands", x_grad_name); } protected: @@ -78,8 +82,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr retv) const override { retv->SetType("einsum_grad"); - retv->SetInput("Operands", this->Input("Operands")); - retv->SetInput("InnerCache", this->Output("InnerCache")); + if (this->HasOutput("InnerCache")) { + retv->SetInput("InnerCache", this->Output("InnerCache")); + } + if (this->HasOutput("XShape")) { + // add if for compatibility. + retv->SetInput("Operands", this->Output("XShape")); // for memory save. + } else { + retv->SetInput("Operands", this->Input("Operands")); + } retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); retv->SetAttrMap(this->Attrs()); retv->SetOutput(framework::GradVarName("Operands"), diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 980b4219c51..43265cf5e6d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out, - std::vector inner_cache) { + std::vector inner_cache, + std::vector xshape) { // collect the following informations to prepare einsum. LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); @@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector& inputs, VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape); out->set_dims(make_ddim(output_dims)); out->set_dtype(inputs[0]->dtype()); + for (size_t i = 0; i < xshape.size(); ++i) { + if (xshape[i] != nullptr) { + xshape[i]->set_dims(inputs[i]->dims()); + xshape[i]->set_dtype(inputs[i]->dtype()); + } + } } void ExpandInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index e141acb2ea2..2c21930e03a 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out, - std::vector inner_cache); + std::vector inner_cache, + std::vector xshape); void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h index 87df2b1c64a..569cf7a55af 100644 --- a/paddle/phi/kernels/einsum_kernel.h +++ b/paddle/phi/kernels/einsum_kernel.h @@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx, const std::vector& inputs, const std::string& equation, DenseTensor* out, - std::vector cache); + std::vector inner_cache, + std::vector xshape); } // namespace phi diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index a72db326807..a04185a0c53 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx, cache[0].ShareBufferWith(*(inner_cache[0])); cache[1].ShareBufferWith(*(inner_cache[1])); } - EinsumKernelImpl(dev_ctx, all_labels, operands_for_A, diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index bfbd6e0c51c..fb0ea2132db 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -458,7 +458,7 @@ DenseTensor PerformContraction( } // reduction DenseTensor trans_t; - if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr && + if (use_cache && cache[operand_idx] != nullptr && cache[operand_idx]->IsInitialized()) { trans_t.ShareBufferWith(*(cache[operand_idx])); VLOG(5) << "Cache Used!"; @@ -467,7 +467,7 @@ DenseTensor PerformContraction( dev_ctx, t, perm, all_labels, ellipsis, label2type); trans_t = PerformTranspose( dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); - if (FLAGS_einsum_opt && cache[operand_idx] != nullptr) + if (cache[operand_idx] != nullptr) cache[operand_idx]->ShareBufferWith(trans_t); } auto mul_dims = GetShapeByType(all_labels, @@ -598,6 +598,11 @@ void EinsumKernelImpl(const Context& dev_ctx, out); // Reshape Procedure } else if (inputs.size() == 1) { + if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if + // loading the program from v2.3.0 + (*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward + // we can only see cached tensor. + } auto reduce_A = PerformReduction(dev_ctx, *inputs[0], label2perms[0], @@ -626,7 +631,8 @@ void EinsumKernelRaw(const Context& dev_ctx, const std::vector& inputs, const std::string& equation, DenseTensor* out, - std::vector cache) { + std::vector cache, + std::vector xshape) { std::vector tmp; // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output // may have nullptr and the cache.size() is not equal to inputs.size(). refer diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index 5e45bcf97ce..4fd31c1a2d8 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -18,7 +18,7 @@ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( - "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"}); + "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { diff --git a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py index 1a4ae54afef..4cec7d63738 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py @@ -37,7 +37,9 @@ class TestEinsumBinary(OpTest): self.outputs = { 'Out': out, "InnerCache": [('cache_' + str(i), np.array([1.0])) - for i in range(len(self.operands))] + for i in range(len(self.operands))], + "XShape": [('xshape_' + str(i), np.array([1.0])) + for i in range(len(self.operands))], } def init_input(self): @@ -46,14 +48,13 @@ class TestEinsumBinary(OpTest): self.inputs.append(np.random.random(s).astype(t)) def set_mandatory(self): - self.disable = False self.shapes = [(10, 10, 20), (20, 6)] self.types = [np.float64, np.float64] self.equation = "mij,jk->ki" def test_check_output(self): if not self.disable: - self.check_output(no_check_set=["InnerCache"]) + self.check_output(no_check_set=["InnerCache", "XShape"]) def test_grad(self): if not self.disable: diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 49cc426a00f..a5ad3c37a4f 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -802,9 +802,10 @@ def gen_einsum_op(equation, *operands): if _in_legacy_dygraph(): # dygraph - return _C_ops.einsum(operands, len(operands), 'equation', equation)[0] + return _C_ops.einsum(operands, + len(operands), len(operands), 'equation', + equation)[0] - # static graph for inp in operands: check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum') check_type(equation, 'equation', str, 'einsum') @@ -816,11 +817,16 @@ def gen_einsum_op(equation, *operands): helper.create_variable_for_type_inference(dtype=operands[0].dtype) for i in range(len(operands)) ] + xshape = [ + helper.create_variable_for_type_inference(dtype=operands[0].dtype) + for i in range(len(operands)) + ] helper.append_op( type='einsum', inputs={'Operands': operands}, outputs={'Out': out, - "InnerCache": caches}, + "InnerCache": caches, + "XShape": xshape}, attrs=attrs) return out diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 845f6b6ba2f..6f3351bb6c0 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -547,7 +547,7 @@ - api : einsum args : (Tensor[] x, str equation) - output : Tensor, Tensor[]{x.size()} + output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} infer_meta : func : EinsumInferMeta param : [x, equation] diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index e16c57d4d89..b8cb1340d26 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -1,3 +1,15 @@ +- backward_api : abs_double_grad + forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x) + args : (Tensor x, Tensor grad_x_grad) + output : Tensor(grad_out_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : abs_double_grad + data_transform: + skip_transform : grad_x_grad + - backward_api : abs_grad forward : abs (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -447,12 +459,12 @@ skip_transform : out_w, out_w_grad - backward_api : einsum_grad - forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache) - args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation) - output : Tensor[](x_grad){x.size()} + forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape) + args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation) + output : Tensor[](x_grad){x_shape.size()} infer_meta : func : UnchangedMultiInferMeta - param : [x] + param : [x_shape] kernel : func : einsum_grad -- GitLab