未验证 提交 638b69dc 编写于 作者: X xiongkun 提交者: GitHub

[Cherry pick] Einsum memory optimization PR #43397 (#43554)

* cherry pick from #43397

* fix code
上级 68d5c12b
......@@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
}
} // namespace egr
......@@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
using TupleOfTensorAndVector =
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>;
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
......
......@@ -40,6 +40,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsExtra()
.AsIntermediate();
AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.")
.AsDuplicable()
.AsExtra()
.AsIntermediate();
AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in "
......@@ -58,8 +62,8 @@ class EinsumGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
auto x_name = "Operands";
auto x_grad_name = framework::GradVarName(x_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
ctx->ShareAllLoD(x_name, x_grad_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands"));
ctx->ShareAllLoD("Operands", x_grad_name);
}
protected:
......@@ -78,8 +82,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands"));
if (this->HasOutput("InnerCache")) {
retv->SetInput("InnerCache", this->Output("InnerCache"));
}
if (this->HasOutput("XShape")) {
// add if for compatibility.
retv->SetInput("Operands", this->Output("XShape")); // for memory save.
} else {
retv->SetInput("Operands", this->Input("Operands"));
}
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"),
......
......@@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache) {
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
......@@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
xshape[i]->set_dtype(inputs[i]->dtype());
}
}
}
void ExpandInferMeta(const MetaTensor& x,
......
......@@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache);
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
......
......@@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache);
std::vector<DenseTensor*> inner_cache,
std::vector<DenseTensor*> xshape);
} // namespace phi
......@@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx,
cache[0].ShareBufferWith(*(inner_cache[0]));
cache[1].ShareBufferWith(*(inner_cache[1]));
}
EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_A,
......
......@@ -458,7 +458,7 @@ DenseTensor PerformContraction(
}
// reduction
DenseTensor trans_t;
if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
if (use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
......@@ -467,7 +467,7 @@ DenseTensor PerformContraction(
dev_ctx, t, perm, all_labels, ellipsis, label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
if (cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
......@@ -598,6 +598,11 @@ void EinsumKernelImpl(const Context& dev_ctx,
out);
// Reshape Procedure
} else if (inputs.size() == 1) {
if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if
// loading the program from v2.3.0
(*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward
// we can only see cached tensor.
}
auto reduce_A = PerformReduction<T, Context>(dev_ctx,
*inputs[0],
label2perms[0],
......@@ -626,7 +631,8 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache) {
std::vector<DenseTensor*> cache,
std::vector<DenseTensor*> xshape) {
std::vector<char> tmp;
// for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
// may have nullptr and the cache.size() is not equal to inputs.size(). refer
......
......@@ -18,7 +18,7 @@ namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
}
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
......@@ -37,7 +37,9 @@ class TestEinsumBinary(OpTest):
self.outputs = {
'Out': out,
"InnerCache": [('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))]
for i in range(len(self.operands))],
"XShape": [('xshape_' + str(i), np.array([1.0]))
for i in range(len(self.operands))],
}
def init_input(self):
......@@ -46,14 +48,13 @@ class TestEinsumBinary(OpTest):
self.inputs.append(np.random.random(s).astype(t))
def set_mandatory(self):
self.disable = False
self.shapes = [(10, 10, 20), (20, 6)]
self.types = [np.float64, np.float64]
self.equation = "mij,jk->ki"
def test_check_output(self):
if not self.disable:
self.check_output(no_check_set=["InnerCache"])
self.check_output(no_check_set=["InnerCache", "XShape"])
def test_grad(self):
if not self.disable:
......
......@@ -802,9 +802,10 @@ def gen_einsum_op(equation, *operands):
if _in_legacy_dygraph():
# dygraph
return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
return _C_ops.einsum(operands,
len(operands), len(operands), 'equation',
equation)[0]
# static graph
for inp in operands:
check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
check_type(equation, 'equation', str, 'einsum')
......@@ -816,11 +817,16 @@ def gen_einsum_op(equation, *operands):
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
xshape = [
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
helper.append_op(
type='einsum',
inputs={'Operands': operands},
outputs={'Out': out,
"InnerCache": caches},
"InnerCache": caches,
"XShape": xshape},
attrs=attrs)
return out
......
......@@ -547,7 +547,7 @@
- api : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
param : [x, equation]
......
- backward_api : abs_double_grad
forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad)
output : Tensor(grad_out_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : abs_double_grad
data_transform:
skip_transform : grad_x_grad
- backward_api : abs_grad
forward : abs (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......@@ -447,12 +459,12 @@
skip_transform : out_w, out_w_grad
- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x_shape.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
param : [x_shape]
kernel :
func : einsum_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册