Unverified commit cf198dc9, authored by xiongkun, committed by GitHub

[EinsumOp] Polish forward logic and backward logic for optimization (#42603)

* change logic for optimization

* modify
Parent 02e5c4be
......@@ -148,14 +148,16 @@ void EinsumGradKernel(const Context& dev_ctx,
     right = splits[1].substr(1);
     auto equation_for_A =
-        right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]);
+        ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]);
     auto equation_for_B =
         right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]);
     auto operands_for_A = std::vector<const DenseTensor*>();
     auto operands_for_B = std::vector<const DenseTensor*>();
     DenseTensor dA, dB;
-    operands_for_A.push_back(&out_grad);
+    // dA = einsum(B, dC)
     operands_for_A.push_back(x[1]);
+    operands_for_A.push_back(&out_grad);
+    // dB = einsum(dC, A)
     operands_for_B.push_back(&out_grad);
     operands_for_B.push_back(x[0]);
......
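The operand order for dA changes from (dC, B) to (B, dC), with the equation string reordered to match, and the new comments spell out the identities dA = einsum(B, dC) and dB = einsum(dC, A). A quick NumPy check of those identities for the equation "ij,jk->ik" (an illustrative sketch; np.einsum stands in for the Paddle kernel):

    import numpy as np

    # Gradient identities used by EinsumGradKernel, verified numerically:
    #   dA = einsum(ops[1] + "," + right + "->" + labels_of_A, B, dC)
    #   dB = einsum(right + "," + ops[0] + "->" + labels_of_B, dC, A)
    rng = np.random.default_rng(0)
    A = rng.standard_normal((3, 4))
    B = rng.standard_normal((4, 5))
    dC = rng.standard_normal((3, 5))        # upstream gradient of C = A @ B
    dA = np.einsum("jk,ik->ij", B, dC)      # equals dC @ B.T
    dB = np.einsum("ik,ij->jk", dC, A)      # equals A.T @ dC
    assert np.allclose(dA, dC @ B.T)
    assert np.allclose(dB, A.T @ dC)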
......@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
+#include <set>
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/matmul_kernel.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
......@@ -55,7 +56,8 @@ inline static void ValidationCheck(const std::string& equation) {
 enum LabelType {
   ALL_TYPE = 0,
   Batch = 1,    // ABO
-  Free,         // AO, BO
+  AO,           // AO -- free label
+  BO,           // BO -- free label
   Contraction,  // AB
   Reduction,    // A, B
 };
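Splitting Free into AO and BO makes the four-way taxonomy explicit: for a two-operand equation, Batch labels appear in both operands and the output (ABO), AO and BO are free labels that appear in one operand plus the output, Contraction labels appear in both operands but not the output (AB), and Reduction labels appear in a single operand only. A rough Python restatement for the equation 'mij,mjk->mik' (an illustrative helper, not Paddle code):

    # Classify each label of a two-operand einsum the way GlobalInfo does.
    def classify(op_a, op_b, right):
        types = {}
        for c in sorted(set(op_a + op_b + right)):
            in_a, in_b, in_out = c in op_a, c in op_b, c in right
            if in_a and in_b:
                types[c] = "Batch" if in_out else "Contraction"
            elif in_out:
                types[c] = "AO" if in_a else "BO"
            else:
                types[c] = "Reduction"
        return types

    print(classify("mij", "mjk", "mik"))
    # -> {'i': 'AO', 'j': 'Contraction', 'k': 'BO', 'm': 'Batch'}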
......@@ -125,18 +127,32 @@ inline std::vector<char> union_labels(const std::vector<char>& a,
   return res;
 }
 
+// Apply transforms to all_labels and get another all_labels
+inline std::vector<char> TransformLabelsOrder(
+    const std::vector<char>& all_labels,
+    const LabelMap& type,
+    std::vector<LabelType> new_order) {
+  std::vector<char> ret;
+  for (auto cnt_type : new_order) {
+    std::vector<char> tmp;
+    for (int c : all_labels) {
+      if (type[c] == cnt_type) tmp.push_back(c);
+    }
+    std::sort(tmp.begin(), tmp.end());
+    ret.insert(ret.end(), tmp.begin(), tmp.end());
+  }
+  return ret;
+}
+
 inline static void GlobalInfo(const std::vector<std::string>& op_labels,
                               const std::string& right,
                               LabelMap* label2type,
                               std::vector<char>* sorted_labels) {
-  // sorted_labels: ['.', <right>, <left only label>]
-  VLOG(5) << "GlobalInfo: "
-          << paddle::string::join_strings(*sorted_labels, ",");
   std::vector<char> all;
   LabelMap counter(0);
   for (auto& ch : right) {  // char
     int c = ch;
-    (*label2type)[c] = LabelType::Free;
+    (*label2type)[c] = LabelType::BO;
   }
   for (auto& op : op_labels) {
......@@ -146,39 +162,36 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
         all.push_back(ch);
       }
       counter[c] += 1;
-      if ((*label2type)[c] != LabelType::Free && counter[c] == 2)
+      if ((*label2type)[c] != LabelType::BO && counter[c] == 2)
         (*label2type)[c] = LabelType::Contraction;
       else if (counter[c] == 2)
         (*label2type)[c] = LabelType::Batch;
     }
   }
+  // BO stands in for Free above, so we still need to find the AO labels.
+  for (int c : op_labels[0]) {
+    if ((*label2type)[c] == LabelType::BO) (*label2type)[c] = LabelType::AO;
+  }
   (*label2type)['.'] = LabelType::Batch;
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Batch)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Free)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Contraction)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Reduction)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
+  VLOG(5) << "GlobalInfo: sorted_labels before: "
+          << paddle::string::join_strings(*sorted_labels, ",");
+  *sorted_labels = TransformLabelsOrder(all,
+                                        *label2type,
+                                        {LabelType::Batch,
+                                         LabelType::AO,
+                                         LabelType::BO,
+                                         LabelType::Contraction,
+                                         LabelType::Reduction});
   if (counter[static_cast<int>('.')] > 0) {
     std::vector<char> tmp;
     tmp.push_back('.');
     // push '.' in the front
     *sorted_labels = union_labels(tmp, *sorted_labels);
-    VLOG(5) << "GlobalInfo: sorted_labels after: "
-            << paddle::string::join_strings(*sorted_labels, ",");
   }
+  VLOG(5) << "GlobalInfo: sorted_labels after: "
+          << paddle::string::join_strings(*sorted_labels, ",");
 }
 
 inline static void InferLabelShape(const std::vector<std::string>& op_labels,
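The four repetitive std::for_each passes are replaced by TransformLabelsOrder, which buckets labels by type in the requested order (Batch, AO, BO, Contraction, Reduction) and sorts within each bucket. A small Python sketch of that reordering (a hypothetical mirror of the C++ helper):

    # Stable bucketing of labels by type, as TransformLabelsOrder does.
    def transform_labels_order(all_labels, label2type, new_order):
        ret = []
        for t in new_order:
            ret.extend(sorted(c for c in all_labels if label2type[c] == t))
        return ret

    label2type = {'m': 'Batch', 'i': 'AO', 'k': 'BO', 'j': 'Contraction'}
    print(transform_labels_order("ijkm", label2type,
                                 ["Batch", "AO", "BO", "Contraction", "Reduction"]))
    # -> ['m', 'i', 'k', 'j']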
......@@ -289,17 +302,20 @@ inline static void ParseEinsumEquation(
   *right = results[1].substr(1);
   ReplaceEllipsis(*right);
   auto op_labels = paddle::string::split_string(left, ",");
+  // split_string("i,") -> ["i"], but we expect 2 op_labels.
+  if (left[left.size() - 1] == ',') op_labels.push_back("");
   std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis);
   GlobalInfo(op_labels, *right, labeltype, all_labels);
   InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims);
-  VLOG(5) << "Einsum Infershape: right:" << right;
-  VLOG(5) << "Einsum Infershape: op_labels:"
-          << paddle::string::join_strings(op_labels, "\n");
+  VLOG(5) << "Einsum Infershape: right:" << *right;
+  VLOG(5) << "Einsum Infershape: left :"
+          << paddle::string::join_strings(op_labels, '\n');
   InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims);
   for (size_t i = 0; i < inputs.size(); ++i) {
     InferLabelPerm(
         op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i]));
   }
+  VLOG(5) << "Einsum Infershape: end";
 }
 
 template <typename T>
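The new push_back handles equations whose last operand is a scalar, e.g. "i,->i": as the added comment notes, paddle::string::split_string drops the trailing empty token, so the kernel re-adds an empty label string to keep one entry per input. The contrast with Python's str.split shows the edge case (illustrative):

    left = "i,"
    print(left.split(","))   # Python keeps the empty token: ['i', '']
    # paddle::string::split_string("i,", ",") returns only ["i"], so
    # ParseEinsumEquation appends "" to give the scalar operand an empty label.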
......@@ -327,10 +343,12 @@ std::vector<T> GetShapeByType(const std::vector<char>& all_labels,
                                const LabelMap& perm,
                                const LabelMap& label2shape,
                                const std::vector<int>& ellipsis,
-                               LabelType filter) {
+                               std::set<LabelType> filter) {
   std::vector<T> res;
   for (T c : all_labels) {
-    if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) {
+    if ((filter.count(LabelType::ALL_TYPE) ||
+         filter.count(LabelType(type[c]))) &&
+        perm[c] != -1) {
       if (c == '.')
         res.insert(res.end(), ellipsis.begin(), ellipsis.end());
       else
......@@ -390,7 +408,8 @@ DenseTensor PerformContraction(
     const LabelMap& label2type,
     const LabelMap& label2shape,
     const std::vector<std::vector<int>>& ellipsis_dims,
-    const std::vector<int>& broadcast_dims) {
+    const std::vector<int>& broadcast_dims,
+    std::vector<DenseTensor*> cache) {
   // Get All the Batches, so perm is
   auto all_valid = LabelMap(1);
   auto recover_dim = GetShapeByType<int>(all_labels,
......@@ -398,36 +417,74 @@ DenseTensor PerformContraction(
                                          all_valid,
                                          label2shape,
                                          broadcast_dims,
-                                         LabelType::Batch);
+                                         {LabelType::Batch});
   auto preprocess = [&](const DenseTensor& t,
                         const LabelMap& perm,
-                        const std::vector<int>& ellipsis) -> DenseTensor {
-    auto frees = GetShapeByType<int>(
-        all_labels, label2type, perm, label2shape, ellipsis, LabelType::Free);
+                        const std::vector<int>& ellipsis,
+                        int operand_idx) -> DenseTensor {
+    // reshape
+    auto frees = GetShapeByType<int>(all_labels,
+                                     label2type,
+                                     perm,
+                                     label2shape,
+                                     ellipsis,
+                                     {LabelType::AO, LabelType::BO});
     auto conts = GetShapeByType<int>(all_labels,
                                      label2type,
                                      perm,
                                      label2shape,
                                      ellipsis,
-                                     LabelType::Contraction);
-    auto trans_t = PerformTranspose<T, Context>(
-        dev_ctx, t, perm, all_labels, ellipsis, label2type);
-    auto mul_dims = GetShapeByType<int>(
-        all_labels, label2type, perm, label2shape, ellipsis, LabelType::Batch);
+                                     {LabelType::Contraction});
+    std::vector<char> reordered_all_labels = all_labels;
+    if (operand_idx == 1) {
+      reordered_all_labels = TransformLabelsOrder(all_labels,
+                                                  label2type,
+                                                  {LabelType::Batch,
+                                                   LabelType::Contraction,
+                                                   LabelType::AO,
+                                                   LabelType::BO,
+                                                   LabelType::Reduction});
+    }
+    // reduction
+    DenseTensor trans_t;
+    if (cache[operand_idx]->IsInitialized()) {
+      trans_t.ShareBufferWith(*(cache[operand_idx]));
+    } else {
+      auto reduct_t = PerformReduction<T, Context>(
+          dev_ctx, t, perm, all_labels, ellipsis, label2type);
+      trans_t = PerformTranspose<T, Context>(
+          dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
+      cache[operand_idx]->ShareBufferWith(trans_t);
+    }
+    auto mul_dims = GetShapeByType<int>(all_labels,
+                                        label2type,
+                                        perm,
+                                        label2shape,
+                                        ellipsis,
+                                        {LabelType::Batch});
     recover_dim.insert(recover_dim.end(), frees.begin(), frees.end());
-    mul_dims.push_back(
-        std::accumulate(frees.begin(), frees.end(), 1, std::multiplies<int>()));
-    mul_dims.push_back(
-        std::accumulate(conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    if (operand_idx == 0) {
+      mul_dims.push_back(std::accumulate(
+          frees.begin(), frees.end(), 1, std::multiplies<int>()));
+      mul_dims.push_back(std::accumulate(
+          conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    } else {
+      mul_dims.push_back(std::accumulate(
+          conts.begin(), conts.end(), 1, std::multiplies<int>()));
+      mul_dims.push_back(std::accumulate(
+          frees.begin(), frees.end(), 1, std::multiplies<int>()));
+    }
     VLOG(5) << "PerformContraction: mul_dims: "
             << paddle::string::join_strings(mul_dims, ",");
     trans_t.Resize(make_ddim(mul_dims));
     return trans_t;
   };
-  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]);
-  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]);
+  // Reduction, Reshape and Matmul
+  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0);
+  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1);
   auto after_contraction =
-      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, true);
+      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, false);
   VLOG(5) << "PerformContraction: recover_dim: "
           << paddle::string::join_strings(recover_dim, ",");
   after_contraction.Resize(make_ddim(recover_dim));
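The heart of the change: reduction, transpose, and reshape now all happen inside preprocess, with operand B reordered to [Batch, Contraction, AO, BO, Reduction] so its flattened 3-D view is [batch, contraction, free], while operand A stays [batch, free, contraction]. That is why Matmul no longer needs transpose_y=true, and why GetShapeByType now takes a set of label types (the free dims are AO plus BO). A NumPy sketch of the layout under these assumptions (shapes and labels illustrative):

    import numpy as np

    # Equation "mpqr,mrqs->mps": m is Batch, p/s are free (AO/BO), q and r contract.
    rng = np.random.default_rng(0)
    m, p, q, r, s = 2, 3, 4, 5, 6
    A = rng.standard_normal((m, p, q, r))               # already [batch, free, cont]
    B = rng.standard_normal((m, r, q, s))
    lhs = A.reshape(m, p, q * r)                        # [batch, free_A, cont]
    rhs = B.transpose(0, 2, 1, 3).reshape(m, q * r, s)  # [batch, cont, free_B]
    out = np.matmul(lhs, rhs)                           # no transpose flags needed
    assert np.allclose(out, np.einsum("mpqr,mrqs->mps", A, B))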
......@@ -465,10 +522,11 @@ void TransposeToOutput(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void EinsumKernel(const Context& dev_ctx,
-                  const std::vector<const DenseTensor*>& inputs,
-                  const std::string& equation,
-                  DenseTensor* out) {
+void EinsumKernelImpl(const Context& dev_ctx,
+                      const std::vector<const DenseTensor*>& inputs,
+                      const std::string& equation,
+                      DenseTensor* out,
+                      std::vector<DenseTensor*> cache) {
   ValidationCheck(equation);
   // collect the following informations to prepare einsum.
   LabelMap labelshape(0);
......@@ -498,22 +556,18 @@ void EinsumKernel(const Context& dev_ctx,
   if (inputs.size() == 2) {
     auto& A = inputs[0];
     auto& B = inputs[1];
-    // Reduce Procedure
-    auto reduce_A = PerformReduction<T, Context>(
-        dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype);
-    auto reduce_B = PerformReduction<T, Context>(
-        dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype);
-    // Contract Procedure
+    // Reduction and Contract Procedure
     dev_ctx.template Alloc<T>(out);
     auto after_contraction = PerformContraction<T, Context>(dev_ctx,
-                                                            reduce_A,
-                                                            reduce_B,
+                                                            *A,
+                                                            *B,
                                                             label2perms,
                                                             all_labels,
                                                             labeltype,
                                                             labelshape,
                                                             ellipsis_dims,
-                                                            broadcast_dims);
+                                                            broadcast_dims,
+                                                            cache);
     TransposeToOutput<T, Context>(dev_ctx,
                                   after_contraction,
                                   right,
......@@ -545,4 +599,18 @@ void EinsumKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void EinsumKernel(const Context& dev_ctx,
+                  const std::vector<const DenseTensor*>& inputs,
+                  const std::string& equation,
+                  DenseTensor* out) {
+  std::vector<DenseTensor> cache(inputs.size());  // set empty; TA, TB, TdC
+  std::vector<DenseTensor*> cache_tensor(
+      inputs.size());  // set empty; TA, TB, TdC
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    cache_tensor[i] = &cache[i];
+  }
+  EinsumKernelImpl<T, Context>(dev_ctx, inputs, equation, out, cache_tensor);
+}
+
 } // namespace phi
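EinsumKernel is now a thin wrapper that allocates one cache slot per input (the comments hint at TA, TB, TdC) and forwards to EinsumKernelImpl; inside preprocess, a slot that is already initialized short-circuits the reduction and transpose. A minimal Python sketch of that memoization pattern (hypothetical names, NumPy standing in for DenseTensor):

    import numpy as np

    cache = [None, None]                    # one slot per operand, starts empty

    def preprocess(idx, x):
        # mirrors: if (cache[operand_idx]->IsInitialized()) ShareBufferWith(...)
        if cache[idx] is None:
            cache[idx] = np.ascontiguousarray(x.T)  # stand-in for reduce + transpose
        return cache[idx]

    x = np.arange(6.0).reshape(2, 3)
    a = preprocess(0, x)                    # computes and caches
    b = preprocess(0, x)                    # reuses the cached buffer
    assert a is b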
......@@ -464,5 +464,19 @@ class TestNumpyTests(unittest.TestCase):
         self.check_output_equal(a, e)
 
+
+class TestStaticGraphShape(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+    def tearDown(self):
+        paddle.disable_static()
+
+    def test_shape(self):
+        A = paddle.static.data(name='x', shape=[-1])
+        B = paddle.static.data(name='y', shape=[384])
+        C = paddle.einsum('i,d->id', A, B)
+        self.assertEqual(C.shape, (-1, 384))
+
+
 if __name__ == "__main__":
-    u
     unittest.main()