未验证 提交 71b046cd 编写于 作者: X xiongkun 提交者: GitHub

[EinsumOp] Optimize the backward speed of EinsumOp (#42663)

* change logic for optimize

* modifty

* optimize the backward speed of EinsumOp

* add cache optimizer for einsum op

* EinsumOp: fix new dygraph mode error

* fix bug

* change Cache->InnerCache

* fix code

* fix

* add nan inf utils for einsum op

* add as_extra

* Compatible with v2.3 EinsumOp

* remove dispensable
上级 e5fc68b2
......@@ -110,4 +110,10 @@ void CheckTensorHasNanOrInf(
}
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
}
} // namespace egr
......@@ -31,6 +31,7 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
......@@ -52,6 +53,9 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors);
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
......
......@@ -33,6 +33,13 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Operands", "(TensorList), The input tensor of einsum op.")
.AsDuplicable();
AddOutput("Out", "(Tensor), The output tensor of einsum op.");
AddOutput(
"InnerCache",
"(Tensor), The cache of the forward transpose tensors: tA and tB.")
.AsDuplicable()
.AsExtra()
.AsIntermediate();
AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in "
......@@ -72,6 +79,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands"));
retv->SetInput("InnerCache", this->Output("InnerCache"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"),
......
......@@ -401,7 +401,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out) {
MetaTensor* out,
std::vector<MetaTensor*> inner_cache) {
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
......
......@@ -82,7 +82,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out);
MetaTensor* out,
std::vector<MetaTensor*> inner_cache);
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
......
......@@ -17,4 +17,5 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
PD_REGISTER_KERNEL(
einsum, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, float, double) {}
......@@ -21,6 +21,7 @@ namespace phi {
template <typename T, typename Context>
void EinsumGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& inner_cache,
const DenseTensor& out_grad,
const std::string& equation,
std::vector<DenseTensor*> x_grad);
......
......@@ -24,4 +24,11 @@ void EinsumKernel(const Context& dev_ctx,
const std::string& equation,
DenseTensor* out);
template <typename T, typename Context>
void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache);
} // namespace phi
......@@ -18,5 +18,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_grad_impl.h"
PD_REGISTER_KERNEL(
einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
PD_REGISTER_KERNEL(einsum_grad,
GPU,
ALL_LAYOUT,
phi::EinsumGradKernel,
float,
double,
phi::dtype::float16) {}
......@@ -18,4 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
PD_REGISTER_KERNEL(einsum,
GPU,
ALL_LAYOUT,
phi::EinsumKernelRaw,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/kernels/tile_kernel.h"
......@@ -55,7 +56,13 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
}
t.Resize(make_ddim(resize_dims));
DenseTensor after_tile;
if (std::all_of(repeat_times.begin(), repeat_times.end(), [](int x) {
return x == 1;
})) {
after_tile = t;
} else {
TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
}
size_t n_ellipsis_idx = op_label.find(".", 0);
if (n_ellipsis_idx != std::string::npos) {
// may be we need reduce. broadcast_dims is not equal to ellipsis dims.
......@@ -91,10 +98,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
template <typename T, typename Context>
void EinsumGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& inner_cache,
const DenseTensor& out_grad,
const std::string& equation,
std::vector<DenseTensor*> x_grad) {
VLOG(5) << "Start EisumGradKernel:";
VLOG(5) << "Start EinsumGradKernel:";
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(x.size(), LabelMap(-1));
......@@ -162,8 +170,33 @@ void EinsumGradKernel(const Context& dev_ctx,
operands_for_B.push_back(x[0]);
DenseTensor before_tile;
EinsumKernel<T, Context>(dev_ctx, operands_for_A, equation_for_A, &dA);
EinsumKernel<T, Context>(dev_ctx, operands_for_B, equation_for_B, &dB);
std::vector<DenseTensor> cache(3); // set empty; TA, TB, TdC
if (inner_cache.size() >
0) { // for compatibility, we can load and run v2.3 EinsumOp.
cache[0].ShareBufferWith(*(inner_cache[0]));
cache[1].ShareBufferWith(*(inner_cache[1]));
}
EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_A,
equation_for_A,
&dA,
{&cache[1], &cache[2]},
false);
EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_B,
equation_for_B,
&dB,
{&cache[2], &cache[0]},
false);
// release the cache tensor dTC to save memory right now. they are useless
// now.
cache.clear();
*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
labelshape,
......
......@@ -137,7 +137,6 @@ inline std::vector<char> TransformLabelsOrder(
std::vector<char> tmp;
for (int c : all_labels) {
if (type[c] == cnt_type) tmp.push_back(c);
std::sort(tmp.begin(), tmp.end());
}
ret.insert(ret.end(), tmp.begin(), tmp.end());
}
......@@ -176,6 +175,15 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
(*label2type)['.'] = LabelType::Batch;
if (sorted_labels->size()) {
std::set<char> exist(all.begin(), all.end());
all.clear();
std::for_each(
sorted_labels->begin(), sorted_labels->end(), [&exist, &all](char c) {
if (exist.count(c)) all.push_back(c);
});
}
*sorted_labels = TransformLabelsOrder(all,
*label2type,
{LabelType::Batch,
......@@ -409,7 +417,8 @@ DenseTensor PerformContraction(
const LabelMap& label2shape,
const std::vector<std::vector<int>>& ellipsis_dims,
const std::vector<int>& broadcast_dims,
std::vector<DenseTensor*> cache) {
std::vector<DenseTensor*> cache,
bool use_cache) {
// Get All the Batches, so perm is
auto all_valid = LabelMap(1);
auto recover_dim = GetShapeByType<int>(all_labels,
......@@ -447,13 +456,16 @@ DenseTensor PerformContraction(
}
// reduction
DenseTensor trans_t;
if (cache[operand_idx]->IsInitialized()) {
if (use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
} else {
auto reduct_t = PerformReduction<T, Context>(
dev_ctx, t, perm, all_labels, ellipsis, label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
......@@ -515,18 +527,23 @@ void TransposeToOutput(const Context& dev_ctx,
axis.push_back(it - all_labels.begin() + offset);
}
}
if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans);
if (is_no_need_transpose(axis)) {
output->ShareBufferWith(to_trans);
return;
}
VLOG(5) << "call TransposeToOutput: with axis: "
<< paddle::string::join_strings(axis, ",");
return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
}
template <typename T, typename Context>
void EinsumKernelImpl(const Context& dev_ctx,
const std::vector<char>& forward_all_labels,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache) {
std::vector<DenseTensor*> cache,
bool is_forward = true) {
ValidationCheck(equation);
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
......@@ -542,6 +559,9 @@ void EinsumKernelImpl(const Context& dev_ctx,
input_dims.push_back(i->dims());
}
std::string right;
if (!is_forward) {
all_labels = forward_all_labels;
}
ParseEinsumEquation(equation,
input_dims,
&labelshape,
......@@ -557,7 +577,6 @@ void EinsumKernelImpl(const Context& dev_ctx,
auto& A = inputs[0];
auto& B = inputs[1];
// Reduction and Contract Procedure
dev_ctx.template Alloc<T>(out);
auto after_contraction = PerformContraction<T, Context>(dev_ctx,
*A,
*B,
......@@ -567,7 +586,8 @@ void EinsumKernelImpl(const Context& dev_ctx,
labelshape,
ellipsis_dims,
broadcast_dims,
cache);
cache,
!is_forward);
TransposeToOutput<T, Context>(dev_ctx,
after_contraction,
right,
......@@ -599,18 +619,37 @@ void EinsumKernelImpl(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache) {
std::vector<char> tmp;
// for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
// may have nullptr and the cache.size() is not equal to inputs.size(). refer
// to BuildPhiKernelContext for details.
int diff = inputs.size() - cache.size();
for (int i = 0; i < diff; ++i) {
cache.push_back(nullptr);
}
EinsumKernelImpl<T, Context>(
dev_ctx, tmp, inputs, equation, out, cache, /*forward=*/true);
}
template <typename T, typename Context>
void EinsumKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out) {
std::vector<DenseTensor> cache(inputs.size()); // set empty; TA, TB, TdC
std::vector<char> place_holder;
std::vector<DenseTensor*> cache_tensor(
inputs.size()); // set empty; TA, TB, TdC
for (size_t i = 0; i < inputs.size(); ++i) {
cache_tensor[i] = &cache[i];
cache_tensor[i] = nullptr;
}
EinsumKernelImpl<T, Context>(dev_ctx, inputs, equation, out, cache_tensor);
EinsumKernelImpl<T, Context>(
dev_ctx, place_holder, inputs, equation, out, cache_tensor, true);
}
} // namespace phi
......@@ -17,14 +17,15 @@ limitations under the License. */
namespace phi {
KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
}
KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("einsum_grad",
{"Operands", {"Out@GRAD"}},
{"Operands", "InnerCache", "Out@GRAD"},
{"equation"},
{{"Operands@GRAD"}});
{"Operands@GRAD"});
}
} // namespace phi
......
......@@ -34,7 +34,11 @@ class TestEinsumBinary(OpTest):
self.operands.append(("x" + str(idx), inp))
self.inputs = {"Operands": self.operands}
self.attrs = {"equation": self.equation}
self.outputs = {'Out': out}
self.outputs = {
'Out': out,
"InnerCache": [('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))]
}
def init_input(self):
self.inputs = []
......@@ -49,7 +53,7 @@ class TestEinsumBinary(OpTest):
def test_check_output(self):
if not self.disable:
self.check_output()
self.check_output(no_check_set=["InnerCache"])
def test_grad(self):
if not self.disable:
......
......@@ -35,4 +35,5 @@ no_check_set_white_list = [
'eigh',
'eigvalsh',
'class_center_sample',
'einsum',
]
......@@ -798,11 +798,12 @@ def gen_einsum_op(equation, *operands):
"""
assert len(operands) <= 2, "Only support two operands in EinsumOp."
if in_dygraph_mode():
return _C_ops.final_state_einsum(operands, equation)
return _C_ops.final_state_einsum(operands, equation)[0]
if _in_legacy_dygraph():
# dygraph
return _C_ops.einsum(operands, 'equation', equation)
return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
# static graph
for inp in operands:
check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
......@@ -811,11 +812,16 @@ def gen_einsum_op(equation, *operands):
out = helper.create_variable_for_type_inference(dtype=operands[0].dtype)
attrs = dict()
attrs['equation'] = equation
caches = [
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
helper.append_op(
type='einsum',
inputs={'Operands': operands},
outputs={'Out': out},
attrs=attrs, )
outputs={'Out': out,
"InnerCache": caches},
attrs=attrs)
return out
......
......@@ -585,7 +585,7 @@
- api : einsum
args : (Tensor[] x, str equation)
output : Tensor
output : Tensor, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
param : [x, equation]
......
......@@ -224,16 +224,18 @@ class BaseAPI(object):
if len(temp_list) == 1:
out_type, out_name, size_expr = parse_output_item(temp_list[0])
return [out_type], [out_name], size_expr
return [out_type], [out_name], [size_expr]
else:
out_type_list = []
out_name_list = []
out_size_expr_list = []
for output_item in temp_list:
out_type, out_name, size_expr = parse_output_item(output_item)
out_type_list.append(out_type)
out_name_list.append(out_name)
out_size_expr_list.append(size_expr)
return out_type_list, out_name_list, size_expr
return out_type_list, out_name_list, out_size_expr_list
def parse_infer_meta(self, infer_meta_config):
infer_meta = infer_meta_config
......
......@@ -111,10 +111,10 @@ class ForwardAPI(BaseAPI):
{code_indent} {return_type} api_output{inplace_assign};"""
if return_type == 'std::vector<Tensor>':
assert self.outputs['out_size_expr'] is not None, \
assert self.outputs['out_size_expr'][0] is not None, \
f"{api_name}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr']}, kernel_backend, &api_output);"""
{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);"""
else:
output_create = output_create + f"""
......
......@@ -552,8 +552,8 @@
skip_transform : out_w, out_w_grad
- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, str equation)
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : UnchangedMultiInferMeta
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册