Unverified · Commit cf198dc9 · authored by xiongkun, committed by GitHub

[EinsumOp] Polish forward logic and backward logic for optimize (#42603)

* change logic for optimize

* modify

Parent: 02e5c4be
@@ -148,14 +148,16 @@ void EinsumGradKernel(const Context& dev_ctx,
     right = splits[1].substr(1);
     auto equation_for_A =
-        right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]);
+        ops[1] + "," + right + "->" + gather_labels_except_reduction(ops[0]);
     auto equation_for_B =
         right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]);
     auto operands_for_A = std::vector<const DenseTensor*>();
     auto operands_for_B = std::vector<const DenseTensor*>();
     DenseTensor dA, dB;
-    operands_for_A.push_back(&out_grad);
+    // dA = einsum(B, dC)
     operands_for_A.push_back(x[1]);
+    operands_for_A.push_back(&out_grad);
+    // dB = einsum(dC, A)
     operands_for_B.push_back(&out_grad);
     operands_for_B.push_back(x[0]);
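Note on the reordering above: for C = einsum("ik,kj->ij", A, B), both gradients are themselves einsum calls, which is exactly what EinsumGradKernel assembles; the comments "dA = einsum(B, dC)" and "dB = einsum(dC, A)" name the operand order the new equation strings expect. A minimal numpy sketch of the same identities (illustrative only, not part of the patch):

    import numpy as np

    A = np.random.rand(2, 3)
    B = np.random.rand(3, 4)
    dC = np.ones((2, 4))  # upstream gradient of C = A @ B

    # equation_for_A = ops[1] + "," + right + "->" + labels of A minus its reduction labels
    dA = np.einsum("kj,ij->ik", B, dC)   # dA = einsum(B, dC)
    # equation_for_B = right + "," + ops[0] + "->" + labels of B minus its reduction labels
    dB = np.einsum("ij,ik->kj", dC, A)   # dB = einsum(dC, A)

    assert np.allclose(dA, dC @ B.T) and np.allclose(dB, A.T @ dC)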
...
@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
+#include <set>
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/matmul_kernel.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"
@@ -55,7 +56,8 @@ inline static void ValidationCheck(const std::string& equation) {
 enum LabelType {
   ALL_TYPE = 0,
   Batch = 1,    // ABO
-  Free,         // AO, BO
+  AO,           // AO -- free label
+  BO,           // BO -- free label
   Contraction,  // AB
   Reduction,    // A, B
 };
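A worked reading of the refined taxonomy (my gloss, not patch text): in "mik,mkj->mij", label m appears in A, B and the output (Batch/ABO); i appears only in A and the output (AO); j only in B and the output (BO); k in both operands but not the output (Contraction/AB); a label in a single operand and absent from the output would be Reduction. A short Python sketch of that classification:

    # Illustrative classifier; names are mine, not the kernel's.
    def classify(c, in_a, in_b, in_out):
        if in_a and in_b and in_out: return "Batch"   # ABO
        if in_a and in_out: return "AO"               # free label of A
        if in_b and in_out: return "BO"               # free label of B
        if in_a and in_b: return "Contraction"        # AB
        return "Reduction"                            # A or B only

    a, b, out = "mik", "mkj", "mij"
    for c in sorted(set(a + b)):
        print(c, classify(c, c in a, c in b, c in out))
    # i AO, j BO, k Contraction, m Batch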
@@ -125,18 +127,32 @@ inline std::vector<char> union_labels(const std::vector<char>& a,
   return res;
 }
 
+// Apply transforms to all_labels and get another all_labels
+inline std::vector<char> TransformLabelsOrder(
+    const std::vector<char>& all_labels,
+    const LabelMap& type,
+    std::vector<LabelType> new_order) {
+  std::vector<char> ret;
+  for (auto cnt_type : new_order) {
+    std::vector<char> tmp;
+    for (int c : all_labels) {
+      if (type[c] == cnt_type) tmp.push_back(c);
+      std::sort(tmp.begin(), tmp.end());
+    }
+    ret.insert(ret.end(), tmp.begin(), tmp.end());
+  }
+  return ret;
+}
+
 inline static void GlobalInfo(const std::vector<std::string>& op_labels,
                               const std::string& right,
                               LabelMap* label2type,
                               std::vector<char>* sorted_labels) {
-  // sorted_labels: ['.', <right>, <left only label>]
-  VLOG(5) << "GlobalInfo: "
-          << paddle::string::join_strings(*sorted_labels, ",");
   std::vector<char> all;
   LabelMap counter(0);
   for (auto& ch : right) {  // char
     int c = ch;
-    (*label2type)[c] = LabelType::Free;
+    (*label2type)[c] = LabelType::BO;
   }
 
   for (auto& op : op_labels) {
@@ -146,39 +162,36 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
         all.push_back(ch);
       }
       counter[c] += 1;
-      if ((*label2type)[c] != LabelType::Free && counter[c] == 2)
+      if ((*label2type)[c] != LabelType::BO && counter[c] == 2)
        (*label2type)[c] = LabelType::Contraction;
       else if (counter[c] == 2)
         (*label2type)[c] = LabelType::Batch;
     }
   }
+  // BO is represent Free, so we need find the AO.
+  for (int c : op_labels[0]) {
+    if ((*label2type)[c] == LabelType::BO) (*label2type)[c] = LabelType::AO;
+  }
+
   (*label2type)['.'] = LabelType::Batch;
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Batch)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Free)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Contraction)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) {
-    if ((*label2type)[c] == LabelType::Reduction)
-      sorted_labels->push_back(static_cast<char>(c));
-  });
-  VLOG(5) << "GlobalInfo: sorted_labels before: "
-          << paddle::string::join_strings(*sorted_labels, ",");
+
+  *sorted_labels = TransformLabelsOrder(all,
+                                        *label2type,
+                                        {LabelType::Batch,
+                                         LabelType::AO,
+                                         LabelType::BO,
+                                         LabelType::Contraction,
+                                         LabelType::Reduction});
+
   if (counter[static_cast<int>('.')] > 0) {
     std::vector<char> tmp;
     tmp.push_back('.');
     // push '.' in the front
     *sorted_labels = union_labels(tmp, *sorted_labels);
-    VLOG(5) << "GlobalInfo: sorted_labels after: "
-            << paddle::string::join_strings(*sorted_labels, ",");
   }
+  VLOG(5) << "GlobalInfo: sorted_labels after: "
+          << paddle::string::join_strings(*sorted_labels, ",");
 }
 
 inline static void InferLabelShape(const std::vector<std::string>& op_labels,
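The net effect of the new ordering (my reading of TransformLabelsOrder, for illustration): labels are grouped by type in the requested order and sorted within each group, so for "mik,mkj->mij" the forward order {Batch, AO, BO, Contraction, Reduction} produces m, i, j, k. A sketch of the grouping in Python:

    # Hypothetical mirror of TransformLabelsOrder (illustrative only).
    def transform_labels_order(all_labels, label2type, new_order):
        ret = []
        for t in new_order:
            ret += sorted(c for c in all_labels if label2type[c] == t)
        return ret

    label2type = {"m": "Batch", "i": "AO", "j": "BO", "k": "Contraction"}
    order = ["Batch", "AO", "BO", "Contraction", "Reduction"]
    print(transform_labels_order("mikj", label2type, order))  # ['m', 'i', 'j', 'k']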
@@ -289,17 +302,20 @@ inline static void ParseEinsumEquation(
   *right = results[1].substr(1);
   ReplaceEllipsis(*right);
   auto op_labels = paddle::string::split_string(left, ",");
+  // split_string("i,") -> ["i"], we expect 2 op_labels.
+  if (left[left.size() - 1] == ',') op_labels.push_back("");
   std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis);
   GlobalInfo(op_labels, *right, labeltype, all_labels);
   InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims);
-  VLOG(5) << "Einsum Infershape: right:" << right;
-  VLOG(5) << "Einsum Infershape: op_labels:"
-          << paddle::string::join_strings(op_labels, "\n");
+  VLOG(5) << "Einsum Infershape: right:" << *right;
+  VLOG(5) << "Einsum Infershape: left :"
+          << paddle::string::join_strings(op_labels, '\n');
   InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims);
   for (size_t i = 0; i < inputs.size(); ++i) {
     InferLabelPerm(
         op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i]));
   }
+  VLOG(5) << "Einsum Infershape: end";
 }
 
 template <typename T>
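The trailing-comma guard handles equations whose second operand has no labels, e.g. "i,->i" (vector times a 0-d scalar): paddle::string::split_string drops the empty piece after the comma, so the parser re-appends it to keep two op_labels. Python's str.split keeps the empty piece, which makes the intended behavior easy to check (hedged; numpy stands in for the Paddle API):

    import numpy as np

    left = "i,"
    op_labels = left.split(",")   # ['i', ''] -- two operands, second label-less
    assert op_labels == ["i", ""]
    print(np.einsum("i,->i", np.arange(3.0), np.float64(2.0)))  # [0. 2. 4.]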
@@ -327,10 +343,12 @@ std::vector<T> GetShapeByType(const std::vector<char>& all_labels,
                               const LabelMap& perm,
                               const LabelMap& label2shape,
                               const std::vector<int>& ellipsis,
-                              LabelType filter) {
+                              std::set<LabelType> filter) {
   std::vector<T> res;
   for (T c : all_labels) {
-    if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) {
+    if ((filter.count(LabelType::ALL_TYPE) ||
+         filter.count(LabelType(type[c]))) &&
+        perm[c] != -1) {
       if (c == '.')
         res.insert(res.end(), ellipsis.begin(), ellipsis.end());
       else
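With filter now a std::set, one call can collect dims for several label types at once, e.g. {LabelType::AO, LabelType::BO} gathers all free dims, instead of one call per type. A hedged one-liner of the same idea:

    # Illustrative only; label2shape maps a label to its dim size.
    def get_shape_by_type(all_labels, label2type, label2shape, wanted):
        return [label2shape[c] for c in all_labels if label2type[c] in wanted]

    label2type = {"m": "Batch", "i": "AO", "j": "BO", "k": "Contraction"}
    label2shape = {"m": 5, "i": 2, "j": 4, "k": 3}
    print(get_shape_by_type("mijk", label2type, label2shape, {"AO", "BO"}))  # [2, 4]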
@@ -390,7 +408,8 @@ DenseTensor PerformContraction(
     const LabelMap& label2type,
     const LabelMap& label2shape,
     const std::vector<std::vector<int>>& ellipsis_dims,
-    const std::vector<int>& broadcast_dims) {
+    const std::vector<int>& broadcast_dims,
+    std::vector<DenseTensor*> cache) {
   // Get All the Batches, so perm is
   auto all_valid = LabelMap(1);
   auto recover_dim = GetShapeByType<int>(all_labels,
@@ -398,36 +417,74 @@ DenseTensor PerformContraction(
                                          all_valid,
                                          label2shape,
                                          broadcast_dims,
-                                         LabelType::Batch);
+                                         {LabelType::Batch});
   auto preprocess = [&](const DenseTensor& t,
                         const LabelMap& perm,
-                        const std::vector<int>& ellipsis) -> DenseTensor {
-    auto frees = GetShapeByType<int>(
-        all_labels, label2type, perm, label2shape, ellipsis, LabelType::Free);
+                        const std::vector<int>& ellipsis,
+                        int operand_idx) -> DenseTensor {
+    // reshape
+    auto frees = GetShapeByType<int>(all_labels,
+                                     label2type,
+                                     perm,
+                                     label2shape,
+                                     ellipsis,
+                                     {LabelType::AO, LabelType::BO});
     auto conts = GetShapeByType<int>(all_labels,
                                      label2type,
                                      perm,
                                      label2shape,
                                      ellipsis,
-                                     LabelType::Contraction);
-    auto trans_t = PerformTranspose<T, Context>(
-        dev_ctx, t, perm, all_labels, ellipsis, label2type);
-    auto mul_dims = GetShapeByType<int>(
-        all_labels, label2type, perm, label2shape, ellipsis, LabelType::Batch);
+                                     {LabelType::Contraction});
+    std::vector<char> reordered_all_labels = all_labels;
+    if (operand_idx == 1) {
+      reordered_all_labels = TransformLabelsOrder(all_labels,
+                                                  label2type,
+                                                  {LabelType::Batch,
+                                                   LabelType::Contraction,
+                                                   LabelType::AO,
+                                                   LabelType::BO,
+                                                   LabelType::Reduction});
+    }
+    // reduction
+    DenseTensor trans_t;
+    if (cache[operand_idx]->IsInitialized()) {
+      trans_t.ShareBufferWith(*(cache[operand_idx]));
+    } else {
+      auto reduct_t = PerformReduction<T, Context>(
+          dev_ctx, t, perm, all_labels, ellipsis, label2type);
+      trans_t = PerformTranspose<T, Context>(
+          dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
+      cache[operand_idx]->ShareBufferWith(trans_t);
+    }
+    auto mul_dims = GetShapeByType<int>(all_labels,
+                                        label2type,
+                                        perm,
+                                        label2shape,
+                                        ellipsis,
+                                        {LabelType::Batch});
     recover_dim.insert(recover_dim.end(), frees.begin(), frees.end());
-    mul_dims.push_back(
-        std::accumulate(frees.begin(), frees.end(), 1, std::multiplies<int>()));
-    mul_dims.push_back(
-        std::accumulate(conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    if (operand_idx == 0) {
+      mul_dims.push_back(std::accumulate(
+          frees.begin(), frees.end(), 1, std::multiplies<int>()));
+      mul_dims.push_back(std::accumulate(
+          conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    } else {
+      mul_dims.push_back(std::accumulate(
+          conts.begin(), conts.end(), 1, std::multiplies<int>()));
+      mul_dims.push_back(std::accumulate(
+          frees.begin(), frees.end(), 1, std::multiplies<int>()));
+    }
     VLOG(5) << "PerformContraction: mul_dims: "
             << paddle::string::join_strings(mul_dims, ",");
     trans_t.Resize(make_ddim(mul_dims));
     return trans_t;
   };
-  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]);
-  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]);
+
+  // Reduction, Reshape and Matmul
+  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0);
+  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1);
   auto after_contraction =
-      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, true);
+      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, false);
   VLOG(5) << "PerformContraction: recover_dim: "
           << paddle::string::join_strings(recover_dim, ",");
   after_contraction.Resize(make_ddim(recover_dim));
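Why operand B alone is reordered to {Batch, Contraction, AO, BO, Reduction} and the Matmul flags change from (false, true) to (false, false): A is flattened to [batch, free_A, contract] while B is flattened to [batch, contract, free_B], so a plain batched matmul contracts them without transposing B. A numpy sketch of the layout (mine, for illustration):

    import numpy as np

    m, i, k, j = 5, 2, 3, 4
    A = np.random.rand(m, i, k)   # [batch, AO, contraction]
    B = np.random.rand(m, k, j)   # [batch, contraction, BO]
    out = A @ B                   # batched matmul, no transpose needed
    assert np.allclose(out, np.einsum("mik,mkj->mij", A, B))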
@@ -465,10 +522,11 @@ void TransposeToOutput(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void EinsumKernel(const Context& dev_ctx,
-                  const std::vector<const DenseTensor*>& inputs,
-                  const std::string& equation,
-                  DenseTensor* out) {
+void EinsumKernelImpl(const Context& dev_ctx,
+                      const std::vector<const DenseTensor*>& inputs,
+                      const std::string& equation,
+                      DenseTensor* out,
+                      std::vector<DenseTensor*> cache) {
   ValidationCheck(equation);
   // collect the following informations to prepare einsum.
   LabelMap labelshape(0);
@@ -498,22 +556,18 @@ void EinsumKernel(const Context& dev_ctx,
   if (inputs.size() == 2) {
     auto& A = inputs[0];
     auto& B = inputs[1];
-    // Reduce Procedure
-    auto reduce_A = PerformReduction<T, Context>(
-        dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype);
-    auto reduce_B = PerformReduction<T, Context>(
-        dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype);
-    // Contract Procedure
+    // Reduction and Contract Procedure
     dev_ctx.template Alloc<T>(out);
     auto after_contraction = PerformContraction<T, Context>(dev_ctx,
-                                                            reduce_A,
-                                                            reduce_B,
+                                                            *A,
+                                                            *B,
                                                             label2perms,
                                                             all_labels,
                                                             labeltype,
                                                             labelshape,
                                                             ellipsis_dims,
-                                                            broadcast_dims);
+                                                            broadcast_dims,
+                                                            cache);
     TransposeToOutput<T, Context>(dev_ctx,
                                   after_contraction,
                                   right,
@@ -545,4 +599,18 @@ void EinsumKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void EinsumKernel(const Context& dev_ctx,
+                  const std::vector<const DenseTensor*>& inputs,
+                  const std::string& equation,
+                  DenseTensor* out) {
+  std::vector<DenseTensor> cache(inputs.size());  // set empty; TA, TB, TdC
+  std::vector<DenseTensor*> cache_tensor(
+      inputs.size());  // set empty; TA, TB, TdC
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    cache_tensor[i] = &cache[i];
+  }
+  EinsumKernelImpl<T, Context>(dev_ctx, inputs, equation, out, cache_tensor);
+}
+
 }  // namespace phi
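Reading of the cache plumbing (inferred from the "TA, TB, TdC" comments, not stated in the patch): the wrapper allocates one empty DenseTensor per input and threads pointers through EinsumKernelImpl into PerformContraction, where preprocess() checks IsInitialized() and reuses the buffer via ShareBufferWith() instead of redoing the reduction/transpose; a caller that shares the same cache across einsum invocations (the TA/TB/TdC naming hints at the backward pass) avoids recomputing the preprocessed operands. The memoization shape, sketched in Python (illustrative only):

    # Hypothetical stand-in for the DenseTensor* cache.
    def preprocess(idx, tensor, transform, cache):
        if cache[idx] is None:               # mirrors IsInitialized()
            cache[idx] = transform(tensor)   # mirrors ShareBufferWith(trans_t)
        return cache[idx]

    cache = [None, None]
    first = preprocess(0, [1, 2, 3], lambda t: [x * 2 for x in t], cache)
    second = preprocess(0, [1, 2, 3], lambda t: [x * 2 for x in t], cache)
    assert first is second  # second call is a cache hit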
...
@@ -464,5 +464,19 @@ class TestNumpyTests(unittest.TestCase):
         self.check_output_equal(a, e)
 
 
+class TestStaticGraphShape(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
+    def tearDown(self):
+        paddle.disable_static()
+
+    def test_shape(self):
+        A = paddle.static.data(name='x', shape=[-1])
+        B = paddle.static.data(name='y', shape=[384])
+        C = paddle.einsum('i,d->id', A, B)
+        self.assertEqual(C.shape, (-1, 384))
+
+
 if __name__ == "__main__":
     unittest.main()