Unverified commit 71b046cd, authored by xiongkun, committed by GitHub

[EinsumOp] Optimize the backward speed of EinsumOp (#42663)

* change logic for optimize

* modify

* optimize the backward speed of EinsumOp

* add cache optimizer for einsum op

* EinsumOp: fix new dygraph mode error

* fix bug

* change Cache->InnerCache

* fix code

* fix

* add nan inf utils for einsum op

* add as_extra

* Compatible with v2.3 EinsumOp

* remove dispensable
Parent commit: e5fc68b2
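For orientation, the operator touched here is exercised through `paddle.einsum`; the backward pass of the two-operand case is what the `InnerCache` output introduced below speeds up. A minimal dygraph sketch of that path (shapes, names, and the equation are illustrative only, not taken from this PR):

```python
import paddle

# Illustrative two-operand einsum; the gradient of this call goes through
# EinsumGradKernel, which this commit optimizes via cached transposes.
x = paddle.randn([64, 32])
y = paddle.randn([32, 128])
x.stop_gradient = False
y.stop_gradient = False

out = paddle.einsum('ij,jk->ik', x, y)  # forward EinsumOp
out.sum().backward()                    # backward EinsumGradOp
```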
@@ -110,4 +110,10 @@ void CheckTensorHasNanOrInf(
   }
 }

+void CheckTensorHasNanOrInf(const std::string& api_name,
+                            const TupleOfTensorAndVector& tensors) {
+  CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
+  CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
+}
+
 }  // namespace egr
@@ -31,6 +31,7 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
 using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
 using TupleOfSixTensors =
     std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
+using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;

 void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
@@ -52,6 +53,9 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
 void CheckTensorHasNanOrInf(const std::string& api_name,
                             const std::vector<Tensor>& tensors);

+void CheckTensorHasNanOrInf(const std::string& api_name,
+                            const TupleOfTensorAndVector& tensors);
+
 void CheckTensorHasNanOrInf(
     const std::string& api_name,
     const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
......
@@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Operands", "(TensorList), The input tensor of einsum op.")
         .AsDuplicable();
     AddOutput("Out", "(Tensor), The output tensor of einsum op.");
+    AddOutput(
+        "InnerCache",
+        "(Tensor), The cache of the forward transpose tensors: tA and tB.")
+        .AsDuplicable()
+        .AsExtra()
+        .AsIntermediate();
     AddAttr<std::string>("equation",
                          "(string) A einsum equation. such as `ij,jk->ik`"
                          "There must have `->` and the number of operands in "
@@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
   void Apply(GradOpPtr<T> retv) const override {
     retv->SetType("einsum_grad");
     retv->SetInput("Operands", this->Input("Operands"));
+    retv->SetInput("InnerCache", this->Output("InnerCache"));
     retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     retv->SetAttrMap(this->Attrs());
     retv->SetOutput(framework::GradVarName("Operands"),
......
@@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x,
 void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                      const std::string& equation,
-                     MetaTensor* out) {
+                     MetaTensor* out,
+                     std::vector<MetaTensor*> inner_cache) {
   // collect the following informations to prepare einsum.
   LabelMap labelshape(0);
   LabelMap labeltype(LabelType::Reduction);
......
@@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x,
 void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                      const std::string& equation,
-                     MetaTensor* out);
+                     MetaTensor* out,
+                     std::vector<MetaTensor*> inner_cache);

 void ExpandInferMeta(const MetaTensor& x,
                      const IntArray& shape,
......
@@ -17,4 +17,5 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/einsum_impl.h"

-PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
+PD_REGISTER_KERNEL(
+    einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {}
@@ -21,6 +21,7 @@ namespace phi {
 template <typename T, typename Context>
 void EinsumGradKernel(const Context& dev_ctx,
                       const std::vector<const DenseTensor*>& x,
+                      const std::vector<const DenseTensor*>& inner_cache,
                       const DenseTensor& out_grad,
                       const std::string& equation,
                       std::vector<DenseTensor*> x_grad);
......
@@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx,
                   const std::string& equation,
                   DenseTensor* out);

+template <typename T, typename Context>
+void EinsumKernelRaw(const Context& dev_ctx,
+                     const std::vector<const DenseTensor*>& inputs,
+                     const std::string& equation,
+                     DenseTensor* out,
+                     std::vector<DenseTensor*> cache);
+
 }  // namespace phi
@@ -18,5 +18,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/einsum_grad_impl.h"

-PD_REGISTER_KERNEL(
-    einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
+PD_REGISTER_KERNEL(einsum_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::EinsumGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
@@ -18,4 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/einsum_impl.h"

-PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
+PD_REGISTER_KERNEL(einsum,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::EinsumKernelRaw,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include "paddle/fluid/platform/profiler.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/impl/einsum_impl.h"
 #include "paddle/phi/kernels/tile_kernel.h"
@@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
   }
   t.Resize(make_ddim(resize_dims));
   DenseTensor after_tile;
-  TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
+  if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) {
+        return x == 1;
+      })) {
+    after_tile = t;
+  } else {
+    TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
+  }
   size_t n_ellipsis_idx = op_label.find(".", 0);
   if (n_ellipsis_idx != std::string::npos) {
     // may be we need reduce. broadcast_dims is not equal to ellipsis dims.
@@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
 template <typename T, typename Context>
 void EinsumGradKernel(const Context& dev_ctx,
                       const std::vector<const DenseTensor*>& x,
+                      const std::vector<const DenseTensor*>& inner_cache,
                       const DenseTensor& out_grad,
                       const std::string& equation,
                       std::vector<DenseTensor*> x_grad) {
-  VLOG(5) << "Start EisumGradKernel:";
+  VLOG(5) << "Start EinsumGradKernel:";
   LabelMap labelshape(0);
   LabelMap labeltype(LabelType::Reduction);
   std::vector<LabelMap> label2perms(x.size(), LabelMap(-1));
@@ -162,8 +170,33 @@ void EinsumGradKernel(const Context& dev_ctx,
     operands_for_B.push_back(x[0]);

     DenseTensor before_tile;
-    EinsumKernel<T, Context>(dev_ctx, operands_for_A, equation_for_A, &dA);
-    EinsumKernel<T, Context>(dev_ctx, operands_for_B, equation_for_B, &dB);
+
+    std::vector<DenseTensor> cache(3);  // set empty; TA, TB, TdC
+    if (inner_cache.size() >
+        0) {  // for compatibility, we can load and run v2.3 EinsumOp.
+      cache[0].ShareBufferWith(*(inner_cache[0]));
+      cache[1].ShareBufferWith(*(inner_cache[1]));
+    }
+
+    EinsumKernelImpl<T, Context>(dev_ctx,
+                                 all_labels,
+                                 operands_for_A,
+                                 equation_for_A,
+                                 &dA,
+                                 {&cache[1], &cache[2]},
+                                 false);
+
+    EinsumKernelImpl<T, Context>(dev_ctx,
+                                 all_labels,
+                                 operands_for_B,
+                                 equation_for_B,
+                                 &dB,
+                                 {&cache[2], &cache[0]},
+                                 false);
+
+    // release the cache tensor dTC to save memory right now. they are useless
+    // now.
+    cache.clear();
+
     *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
                                                        labeltype,
                                                        labelshape,
......
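In the hunk above, the backward no longer calls `EinsumKernel` twice from scratch; it calls `EinsumKernelImpl` with the forward's cached transposed operands (the `InnerCache` tensors tA and tB), so the per-operand reduction and transpose work is reused. The gradient contractions themselves follow the usual einsum identities; a rough numpy sketch of those identities for the plain matmul equation (illustrative only, not the kernel's code):

```python
import numpy as np

# For out = einsum('ij,jk->ik', A, B):
#   dA = einsum('ik,jk->ij', dOut, B)   # contracts dOut with (transposed) B
#   dB = einsum('ij,ik->jk', A, dOut)   # contracts (transposed) A with dOut
A = np.random.rand(2, 3)
B = np.random.rand(3, 4)
dOut = np.random.rand(2, 4)

dA = np.einsum('ik,jk->ij', dOut, B)
dB = np.einsum('ij,ik->jk', A, dOut)

# The same contractions written as matrix products:
assert np.allclose(dA, dOut @ B.T)
assert np.allclose(dB, A.T @ dOut)
```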
@@ -137,7 +137,6 @@ inline std::vector<char> TransformLabelsOrder(
     std::vector<char> tmp;
     for (int c : all_labels) {
       if (type[c] == cnt_type) tmp.push_back(c);
-      std::sort(tmp.begin(), tmp.end());
     }
     ret.insert(ret.end(), tmp.begin(), tmp.end());
   }
@@ -176,6 +175,15 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
   (*label2type)['.'] = LabelType::Batch;

+  if (sorted_labels->size()) {
+    std::set<char> exist(all.begin(), all.end());
+    all.clear();
+    std::for_each(
+        sorted_labels->begin(), sorted_labels->end(), [&exist, &all](char c) {
+          if (exist.count(c)) all.push_back(c);
+        });
+  }
+
   *sorted_labels = TransformLabelsOrder(all,
                                         *label2type,
                                         {LabelType::Batch,
@@ -409,7 +417,8 @@ DenseTensor PerformContraction(
     const LabelMap& label2shape,
     const std::vector<std::vector<int>>& ellipsis_dims,
     const std::vector<int>& broadcast_dims,
-    std::vector<DenseTensor*> cache) {
+    std::vector<DenseTensor*> cache,
+    bool use_cache) {
   // Get All the Batches, so perm is
   auto all_valid = LabelMap(1);
   auto recover_dim = GetShapeByType<int>(all_labels,
@@ -447,14 +456,17 @@ DenseTensor PerformContraction(
     }
     // reduction
     DenseTensor trans_t;
-    if (cache[operand_idx]->IsInitialized()) {
+    if (use_cache && cache[operand_idx] != nullptr &&
+        cache[operand_idx]->IsInitialized()) {
       trans_t.ShareBufferWith(*(cache[operand_idx]));
+      VLOG(5) << "Cache Used!";
     } else {
       auto reduct_t = PerformReduction<T, Context>(
           dev_ctx, t, perm, all_labels, ellipsis, label2type);
       trans_t = PerformTranspose<T, Context>(
           dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
-      cache[operand_idx]->ShareBufferWith(trans_t);
+      if (cache[operand_idx] != nullptr)
+        cache[operand_idx]->ShareBufferWith(trans_t);
     }
     auto mul_dims = GetShapeByType<int>(all_labels,
                                         label2type,
@@ -515,18 +527,23 @@ void TransposeToOutput(const Context& dev_ctx,
       axis.push_back(it - all_labels.begin() + offset);
     }
   }
-  if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans);
+  if (is_no_need_transpose(axis)) {
+    output->ShareBufferWith(to_trans);
+    return;
+  }
   VLOG(5) << "call TransposeToOutput: with axis: "
           << paddle::string::join_strings(axis, ",");
-  return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
+  TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
 }

 template <typename T, typename Context>
 void EinsumKernelImpl(const Context& dev_ctx,
+                      const std::vector<char>& forward_all_labels,
                       const std::vector<const DenseTensor*>& inputs,
                       const std::string& equation,
                       DenseTensor* out,
-                      std::vector<DenseTensor*> cache) {
+                      std::vector<DenseTensor*> cache,
+                      bool is_forward = true) {
   ValidationCheck(equation);
   // collect the following informations to prepare einsum.
   LabelMap labelshape(0);
@@ -542,6 +559,9 @@ void EinsumKernelImpl(const Context& dev_ctx,
     input_dims.push_back(i->dims());
   }
   std::string right;
+  if (!is_forward) {
+    all_labels = forward_all_labels;
+  }
   ParseEinsumEquation(equation,
                       input_dims,
                       &labelshape,
@@ -557,7 +577,6 @@ void EinsumKernelImpl(const Context& dev_ctx,
     auto& A = inputs[0];
     auto& B = inputs[1];
     // Reduction and Contract Procedure
-    dev_ctx.template Alloc<T>(out);
     auto after_contraction = PerformContraction<T, Context>(dev_ctx,
                                                             *A,
                                                             *B,
@@ -567,7 +586,8 @@ void EinsumKernelImpl(const Context& dev_ctx,
                                                             labelshape,
                                                             ellipsis_dims,
                                                             broadcast_dims,
-                                                            cache);
+                                                            cache,
+                                                            !is_forward);
     TransposeToOutput<T, Context>(dev_ctx,
                                   after_contraction,
                                   right,
@@ -599,18 +619,37 @@ void EinsumKernelImpl(const Context& dev_ctx,
   }
 }

+template <typename T, typename Context>
+void EinsumKernelRaw(const Context& dev_ctx,
+                     const std::vector<const DenseTensor*>& inputs,
+                     const std::string& equation,
+                     DenseTensor* out,
+                     std::vector<DenseTensor*> cache) {
+  std::vector<char> tmp;
+  // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
+  // may have nullptr and the cache.size() is not equal to inputs.size(). refer
+  // to BuildPhiKernelContext for details.
+  int diff = inputs.size() - cache.size();
+  for (int i = 0; i < diff; ++i) {
+    cache.push_back(nullptr);
+  }
+  EinsumKernelImpl<T, Context>(
+      dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true);
+}
+
 template <typename T, typename Context>
 void EinsumKernel(const Context& dev_ctx,
                   const std::vector<const DenseTensor*>& inputs,
                   const std::string& equation,
                   DenseTensor* out) {
-  std::vector<DenseTensor> cache(inputs.size());  // set empty; TA, TB, TdC
+  std::vector<char> place_holder;
   std::vector<DenseTensor*> cache_tensor(
       inputs.size());  // set empty; TA, TB, TdC
   for (size_t i = 0; i < inputs.size(); ++i) {
-    cache_tensor[i] = &cache[i];
+    cache_tensor[i] = nullptr;
   }
-  EinsumKernelImpl<T, Context>(dev_ctx, inputs, equation, out, cache_tensor);
+  EinsumKernelImpl<T, Context>(
+      dev_ctx, place_holder, inputs, equation, out, cache_tensor, true);
 }
 }  // namespace phi
@@ -17,14 +17,15 @@ limitations under the License. */
 namespace phi {

 KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
+  return KernelSignature(
+      "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
 }

 KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("einsum_grad",
-                         {"Operands", {"Out@GRAD"}},
+                         {"Operands", "InnerCache", "Out@GRAD"},
                          {"equation"},
-                         {{"Operands@GRAD"}});
+                         {"Operands@GRAD"});
 }
 }  // namespace phi
......
@@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest):
             self.operands.append(("x" + str(idx), inp))
         self.inputs = {"Operands": self.operands}
         self.attrs = {"equation": self.equation}
-        self.outputs = {'Out': out}
+        self.outputs = {
+            'Out': out,
+            "InnerCache": [('cache_' + str(i), np.array([1.0]))
+                           for i in range(len(self.operands))]
+        }

     def init_input(self):
         self.inputs = []
@@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest):

     def test_check_output(self):
         if not self.disable:
-            self.check_output()
+            self.check_output(no_check_set=["InnerCache"])

     def test_grad(self):
         if not self.disable:
......
@@ -35,4 +35,5 @@ no_check_set_white_list = [
     'eigh',
     'eigvalsh',
     'class_center_sample',
+    'einsum',
 ]
@@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands):
     """
     assert len(operands) <= 2, "Only support two operands in EinsumOp."
     if in_dygraph_mode():
-        return _C_ops.final_state_einsum(operands, equation)
+        return _C_ops.final_state_einsum(operands, equation)[0]
     if _in_legacy_dygraph():
         # dygraph
-        return _C_ops.einsum(operands, 'equation', equation)
+        return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]

     # static graph
     for inp in operands:
         check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
@@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands):
     out = helper.create_variable_for_type_inference(dtype=operands[0].dtype)
     attrs = dict()
     attrs['equation'] = equation
+    caches = [
+        helper.create_variable_for_type_inference(dtype=operands[0].dtype)
+        for i in range(len(operands))
+    ]
     helper.append_op(
         type='einsum',
         inputs={'Operands': operands},
-        outputs={'Out': out},
-        attrs=attrs, )
+        outputs={'Out': out,
+                 "InnerCache": caches},
+        attrs=attrs)
     return out
......
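With the change above, the static-graph branch of `gen_einsum_op` also creates one `InnerCache` variable per operand and registers it as an extra op output, while the caller still receives only `Out`. A rough static-graph usage sketch (program and variable names are illustrative, not taken from this PR):

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    a = paddle.static.data(name='a', shape=[8, 16], dtype='float32')
    b = paddle.static.data(name='b', shape=[16, 4], dtype='float32')
    # Only 'Out' comes back to the caller; the InnerCache variables remain
    # intermediate outputs inside the program.
    c = paddle.einsum('ij,jk->ik', a, b)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
out, = exe.run(main_prog,
               feed={'a': np.random.rand(8, 16).astype('float32'),
                     'b': np.random.rand(16, 4).astype('float32')},
               fetch_list=[c])
```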
@@ -585,7 +585,7 @@
 - api : einsum
   args : (Tensor[] x, str equation)
-  output : Tensor
+  output : Tensor, Tensor[]{x.size()}
   infer_meta :
     func : EinsumInferMeta
     param : [x, equation]
......
@@ -224,16 +224,18 @@ class BaseAPI(object):
         if len(temp_list) == 1:
             out_type, out_name, size_expr = parse_output_item(temp_list[0])
-            return [out_type], [out_name], size_expr
+            return [out_type], [out_name], [size_expr]
         else:
             out_type_list = []
             out_name_list = []
+            out_size_expr_list = []
             for output_item in temp_list:
                 out_type, out_name, size_expr = parse_output_item(output_item)
                 out_type_list.append(out_type)
                 out_name_list.append(out_name)
+                out_size_expr_list.append(size_expr)

-            return out_type_list, out_name_list, size_expr
+            return out_type_list, out_name_list, out_size_expr_list

     def parse_infer_meta(self, infer_meta_config):
         infer_meta = infer_meta_config
......
@@ -111,10 +111,10 @@ class ForwardAPI(BaseAPI):
 {code_indent}  {return_type} api_output{inplace_assign};"""

         if return_type == 'std::vector<Tensor>':
-            assert self.outputs['out_size_expr'] is not None, \
+            assert self.outputs['out_size_expr'][0] is not None, \
                 f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
             output_create = output_create + f"""
-{code_indent}  auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);"""
+{code_indent}  auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);"""

         else:
             output_create = output_create + f"""
......
@@ -552,8 +552,8 @@
     skip_transform : out_w, out_w_grad

 - backward_api : einsum_grad
-  forward : einsum (Tensor[] x, str equation) -> Tensor(out)
-  args : (Tensor[] x, Tensor out_grad, str equation)
+  forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
+  args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
   output : Tensor[](x_grad){x.size()}
   infer_meta :
     func : UnchangedMultiInferMeta
......