From 6e1c14e357e4dc88a6c484cc79529aff3d8911c7 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 31 Oct 2022 17:45:56 +0800 Subject: [PATCH] [Einsum] Einsum support repeated labels. (#47290) * add unittest for einsum-v2-trace and diagonal * repeat labels. * einsum support repeated labels. * forward is ok for diagonal and undiagonalized. TODO: check backward is ok by our theorem. * backward is ok! * fix by PR suggestions. * fix ci error * fix ci error * fix ci warning --- paddle/phi/infermeta/unary.cc | 4 +- .../phi/kernels/cpu/diagonal_grad_kernel.cc | 5 +- paddle/phi/kernels/cpu/diagonal_kernel.cc | 2 + paddle/phi/kernels/diagonal_kernel.h | 14 + .../phi/kernels/fill_diagonal_tensor_kernel.h | 16 + .../phi/kernels/gpu/diagonal_grad_kernel.cu | 7 +- paddle/phi/kernels/gpu/diagonal_kernel.cu | 6 +- .../gpu/fill_diagonal_tensor_grad_kernel.cu | 1 + .../gpu/fill_diagonal_tensor_kernel.cu | 1 + paddle/phi/kernels/impl/einsum_grad_impl.h | 27 +- paddle/phi/kernels/impl/einsum_impl.h | 283 ++++++++++++------ .../fluid/tests/unittests/test_einsum_op.py | 49 +++ .../fluid/tests/unittests/test_einsum_v2.py | 128 ++++---- python/paddle/tensor/einsum.py | 8 - 14 files changed, 399 insertions(+), 152 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 150da6d59b..1e4c226a9a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -835,6 +835,7 @@ void EinsumInferMeta(const std::vector& inputs, for (auto& i : inputs) { input_dims.push_back(i->dims()); } + std::vector input_strs; std::string right; ParseEinsumEquation(equation, input_dims, @@ -845,7 +846,8 @@ void EinsumInferMeta(const std::vector& inputs, &ellipsis_dims, &broadcast_dims, &output_dims, - &right); + &right, + &input_strs); VLOG(3) << "Einsum Infershape: input dims:" << paddle::string::join_strings(input_dims, "\n"); diff --git a/paddle/phi/kernels/cpu/diagonal_grad_kernel.cc b/paddle/phi/kernels/cpu/diagonal_grad_kernel.cc index 5671e70c96..f5d6ee2dce 100644 --- a/paddle/phi/kernels/cpu/diagonal_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/diagonal_grad_kernel.cc @@ -90,4 +90,7 @@ PD_REGISTER_KERNEL(diagonal_grad, float, double, int, - int64_t) {} + int64_t, + bool, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/diagonal_kernel.cc b/paddle/phi/kernels/cpu/diagonal_kernel.cc index 8ea5826ba2..f125802c19 100644 --- a/paddle/phi/kernels/cpu/diagonal_kernel.cc +++ b/paddle/phi/kernels/cpu/diagonal_kernel.cc @@ -88,4 +88,6 @@ PD_REGISTER_KERNEL(diagonal, double, int, int64_t, + phi::dtype::complex, + phi::dtype::complex, bool) {} diff --git a/paddle/phi/kernels/diagonal_kernel.h b/paddle/phi/kernels/diagonal_kernel.h index 2d866d4e30..fc8844edc9 100644 --- a/paddle/phi/kernels/diagonal_kernel.h +++ b/paddle/phi/kernels/diagonal_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" namespace phi { @@ -39,4 +40,17 @@ void DiagonalKernel(const Context& dev_ctx, int axis1, int axis2, DenseTensor* out); + +template +DenseTensor Diagonal(const Context& dev_ctx, + const DenseTensor& x, + int offset, + int axis1, + int axis2) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + DiagonalInferMeta(x, offset, axis1, axis2, &meta_out); + DiagonalKernel(dev_ctx, x, offset, axis1, axis2, &dense_out); + return dense_out; +} } // namespace phi diff --git a/paddle/phi/kernels/fill_diagonal_tensor_kernel.h b/paddle/phi/kernels/fill_diagonal_tensor_kernel.h index 
9d6c8da93e..c3fe394a7f 100644 --- a/paddle/phi/kernels/fill_diagonal_tensor_kernel.h +++ b/paddle/phi/kernels/fill_diagonal_tensor_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/binary.h" namespace phi { @@ -27,6 +28,21 @@ void FillDiagonalTensorKernel(const Context& ctx, int dim2, DenseTensor* out); +template +DenseTensor FillDiagonalTensor(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int64_t offset, + int dim1, + int dim2) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + FillDiagonalTensorInferMeta(x, y, offset, dim1, dim2, &meta_out); + FillDiagonalTensorKernel( + ctx, x, y, offset, dim1, dim2, &dense_out); + return dense_out; +} + void CalMatDims(phi::DDim out_dims, int dim1, int dim2, diff --git a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu index a5c0e05959..1fd1e44699 100644 --- a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu @@ -166,4 +166,9 @@ PD_REGISTER_KERNEL(diagonal_grad, float, double, int, - int64_t) {} + int64_t, + bool, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/diagonal_kernel.cu b/paddle/phi/kernels/gpu/diagonal_kernel.cu index 2e4ae59199..169cb3f2c7 100644 --- a/paddle/phi/kernels/gpu/diagonal_kernel.cu +++ b/paddle/phi/kernels/gpu/diagonal_kernel.cu @@ -163,4 +163,8 @@ PD_REGISTER_KERNEL(diagonal, double, int, int64_t, - bool) {} + bool, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu index 0e302b23ee..04f03e3aae 100644 --- a/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/fill_diagonal_tensor_grad_kernel.cu @@ -109,6 +109,7 @@ PD_REGISTER_KERNEL(fill_diagonal_tensor_grad, int8_t, uint8_t, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex, bool) {} diff --git a/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu b/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu index 739a8666e3..33c06e339b 100644 --- a/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu +++ b/paddle/phi/kernels/gpu/fill_diagonal_tensor_kernel.cu @@ -131,6 +131,7 @@ PD_REGISTER_KERNEL(fill_diagonal_tensor, int8_t, uint8_t, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex, bool) {} diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index 992b7572c1..bf27f3ef2b 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -20,15 +20,20 @@ #include "paddle/utils/string/string_helper.h" namespace phi { + template DenseTensor PerformTileAndReduction(const Context& dev_ctx, const LabelMap& label2type, const LabelMap& label2shape, const std::vector& broadcast_dims, const std::vector& ellipsis_dims, - std::string op_label, // value pass - DenseTensor& t) { // NOLINT - ReplaceEllipsis(op_label); + std::string equ, // value pass + DenseTensor& t) { // NOLINT + auto tmp_label = equ; + ReplaceEllipsis(tmp_label); + auto tmp_union = unique_labels(tmp_label); + auto op_label = std::string(tmp_union.begin(), tmp_union.end()); + VLOG(5) << "Start PerformTileAndReduction" << equ; DenseTensor ret; std::vector repeat_times; 
std::vector resize_dims; @@ -61,6 +66,8 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, })) { after_tile = t; } else { + VLOG(4) << "do TileKernel with repeat_times=" + << paddle::string::join_strings(repeat_times, ","); TileKernel(dev_ctx, t, repeat_times, &after_tile); } size_t n_ellipsis_idx = op_label.find(".", 0); @@ -92,7 +99,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, VLOG(5) << "PermformTileAndReduction: recover shape: " << paddle::string::join_strings(recover_shape, ","); ret.Resize(make_ddim(recover_shape)); - return ret; + // undiagonalize by einsum equation. only contain undiagonal operations. + DenseTensor out; + VLOG(5) << "Undiagonal by einsum with args: " << op_label + "->" + equ; + EinsumKernel(dev_ctx, {&ret}, op_label + "->" + equ, &out); + return out; } template @@ -115,6 +126,7 @@ void EinsumGradKernel(const Context& dev_ctx, for (auto& i : x) { input_dims.push_back(i->dims()); } + std::vector input_strs; std::string right; ParseEinsumEquation(equation, input_dims, @@ -125,13 +137,15 @@ void EinsumGradKernel(const Context& dev_ctx, &ellipsis_dims, &broadcast_dims, &output_dims, - &right); + &right, + &input_strs); auto gather_labels_except_reduction = [&labeltype](std::string all) { std::string res(""); for (auto c : all) if (labeltype[static_cast(c)] != LabelType::Reduction) res += c; - return res; + auto tmp_unique = unique_labels(res); + return std::string(tmp_unique.begin(), tmp_unique.end()); }; if (x.size() == 1) { // Unary auto splits = paddle::string::split_string(equation, "->"); @@ -141,6 +155,7 @@ void EinsumGradKernel(const Context& dev_ctx, auto new_operands = std::vector(); new_operands.push_back(&out_grad); DenseTensor before_tile; + VLOG(5) << "new_equation is " << new_equation; EinsumKernel(dev_ctx, new_operands, new_equation, &before_tile); *(x_grad[0]) = PerformTileAndReduction(dev_ctx, labeltype, diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index dafb967ae8..392949e065 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -18,6 +18,9 @@ #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/diagonal_kernel.h" +#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" @@ -89,6 +92,9 @@ class LabelMap { if (label == '.') i = N - 1; return map[i]; } + bool exist(char label) { return !is_default(label); } + + private: // non-exist is present by is_default bool is_default(char label) { return (*this)[static_cast(label)] == default_value; @@ -117,8 +123,9 @@ inline static void ReplaceEllipsis(std::string& s) { // NOLINT } } -inline std::vector union_labels(const std::vector& a, - const std::vector& b) { +template +inline std::vector union_labels(const CharIterable1& a, + const CharIterable2& b) { LabelMap counter(0); std::vector res; auto f = [&](char c) { @@ -132,6 +139,11 @@ inline std::vector union_labels(const std::vector& a, return res; } +template +inline std::vector unique_labels(const CharIterable& a) { + return union_labels(a, CharIterable()); +} + // Apply transforms to all_labels and get another all_labels inline std::vector TransformLabelsOrder( const std::vector& all_labels, @@ -160,9 +172,9 @@ inline static void GlobalInfo(const std::vector& op_labels, } for (auto& op : op_labels) { - 
for (auto& ch : op) { // char + for (auto& ch : unique_labels(op)) { // char int c = ch; - if (counter.is_default(c)) { + if (!counter.exist(c)) { all.push_back(ch); } counter[c] += 1; @@ -238,7 +250,7 @@ inline static void InferLabelShape(const std::vector& op_labels, v = op_dim[dim_ptr]; dim_ptr++; } - } else if (labelshape->is_default(c) || (*labelshape)[c] == -1) { + } else if (!labelshape->exist(c) || (*labelshape)[c] == -1) { (*labelshape)[c] = op_dim[dim_ptr]; dim_ptr++; } else if (op_dim[dim_ptr] != -1) { @@ -270,12 +282,15 @@ inline static void InferLabelShape(const std::vector& op_labels, << paddle::string::join_strings(*broadcast_dims, ","); } -inline static void InferLabelPerm(const std::string& op, +template +inline static void InferLabelPerm(const CharIterable& op, int n_broadcast, LabelMap* label2perm) { int cur = 0; for (int c : op) { - (*label2perm)[c] = cur; + if (!label2perm->exist( + c)) // can appear repeatly. we just record the first position. + (*label2perm)[c] = cur; if (c == '.') { cur += n_broadcast; } else { @@ -308,15 +323,21 @@ inline static void ParseEinsumEquation( std::vector>* ellipsis_dims, std::vector* broadcast_dims, std::vector* output_dims, - std::string* right) { + std::string* right, + std::vector* input_strs) { + VLOG(5) << "Start ParseEinsumEquation"; auto results = paddle::string::split_string(equation, "->"); auto left = results[0]; ReplaceEllipsis(left); *right = results[1].substr(1); ReplaceEllipsis(*right); auto op_labels = paddle::string::split_string(left, ","); - // split_string("i,") -> ["i"], we expect 2 op_labels. - if (left[left.size() - 1] == ',') op_labels.push_back(""); + // split_string("i,") -> ["i"], we push back a "". + // split_string("->") -> [], we push back a "". + if (op_labels.size() == 0) + op_labels.push_back(""); + else if (left[left.size() - 1] == ',') + op_labels.push_back(""); std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis); GlobalInfo(op_labels, *right, labeltype, all_labels); InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims); @@ -327,8 +348,8 @@ inline static void ParseEinsumEquation( for (size_t i = 0; i < inputs.size(); ++i) { InferLabelPerm( op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i])); + (*input_strs).push_back(std::move(op_labels[i])); } - VLOG(5) << "Einsum Infershape: end"; } template @@ -371,20 +392,124 @@ std::vector GetShapeByType(const std::vector& all_labels, return res; } +inline static std::vector perm_moveto(int n, int from, int to) { + // a permution means moving `from` to `to`. + /* + f => t permtation + -------------------- + 0 1 2 3 4 5 + 5 => 2 : 0 2 5 2 3 4 + 2 => 5 : 0 1 3 4 5 2 + we can conclude the following rules. + */ + if (from < 0) from = n + from; + if (to < 0) to = n + to; + std::vector res(n); + for (int i = 0; i < n; ++i) { + res[i] = i; + } + res[to] = from; + auto offset = from > to ? -1 : 1; + auto start = from > to ? to + 1 : from; + auto end = from > to ? from : to - 1; + for (int i = start; i <= end; ++i) { + res[i] += offset; + } + return res; +} + template -DenseTensor PerformReduction(const Context& dev_ctx, - const DenseTensor& tensor, - const LabelMap& label2perm, - const std::vector& all_labels, - const std::vector& ellipsis, - const LabelMap& label2type) { +DenseTensor Undiagonal(const Context& dev_ctx, + const DenseTensor& tensor, + size_t insert_pos, + size_t axis) { + // tensor with shape (3, 4, 5, 2, 1), insert_pos = 5, axis = 2. 
+ // output is (3, 4, 5, 2, 1, 5) + VLOG(5) << "Start undiagonal with args: insert_pos = " << insert_pos + << ", axis = " << axis; + std::vector shape(tensor.dims().size() + 1); + int point = 0; // point to the tensor.dims() + for (size_t i = 0; i < shape.size(); ++i) { + if (i == insert_pos) + shape[i] = tensor.dims()[axis]; + else + shape[i] = tensor.dims()[point++]; + } + auto zeros = Full(dev_ctx, shape, 0); + auto diags = Transpose( + dev_ctx, tensor, perm_moveto(tensor.dims().size(), axis, -1)); + return FillDiagonalTensor( + dev_ctx, zeros, diags, 0, insert_pos, axis + (insert_pos <= axis)); +} + +template +DenseTensor PerformUndiagonal(const Context& dev_ctx, + const DenseTensor& tensor, + int n_broadcast, + const std::string& equ) { + // if the equ is 'iijjkij', then the tensor must be 'ijk', so we have enough + // information to do un-diagonal with equ. + auto res = tensor; + LabelMap label2perm(-1); + InferLabelPerm(equ, n_broadcast, &label2perm); + // Un-Diagonal + int tot = + equ.size() + n_broadcast + (equ.find(".") != std::string::npos ? -1 : 0); + int cur = tot - 1; + for (auto it = equ.rbegin(); it != equ.rend(); ++it) { + char c = *it; + if (c == '.') { + cur -= n_broadcast; + } else { + if (cur != label2perm[c]) { + // do diagonal, followed by movedim(). + auto insert_pos = cur - tot + res.dims().size() + 1; + res = Undiagonal(dev_ctx, res, insert_pos, label2perm[c]); + } + --cur; + } + } + return res; +} + +template +DenseTensor PerformDiagonalAndReduction(const Context& dev_ctx, + const DenseTensor& tensor, + const std::string& equ, + const LabelMap& label2perm, + const std::vector& all_labels, + const std::vector& ellipsis, + const LabelMap& label2type) { + auto res = tensor; + // Diagonal + int tot = equ.size() + ellipsis.size() + + (equ.find(".") != std::string::npos ? -1 : 0); + int cur = tot - 1; + for (auto it = equ.rbegin(); it != equ.rend(); ++it) { + char c = *it; + if (c == '.') { + cur -= ellipsis.size(); + } else { + if (cur != label2perm[c]) { + // do diagonal, followed by movedim(). 
+ VLOG(5) << "Do diagonal with shape=" + << paddle::string::join_strings(vectorize(res.dims()), ',') + << ", axis1=" << cur << ", axis2=" << label2perm[c]; + res = Diagonal(dev_ctx, res, 0, cur, label2perm[c]); + res = Transpose( + dev_ctx, res, perm_moveto(res.dims().size(), -1, label2perm[c])); + } + --cur; + } + } + // reduction auto indices = GetLabelIndexByType( all_labels, label2type, label2perm, ellipsis, LabelType::Reduction); - VLOG(5) << "call PerformReduction: with axis: " + VLOG(5) << "call PerformDiagonalAndReduction: with axis: " << paddle::string::join_strings(indices, ","); - if (indices.size() == 0) return tensor; + if (indices.size() == 0) return res; return Sum( - dev_ctx, tensor, phi::IntArray(indices), tensor.dtype(), true); + dev_ctx, res, phi::IntArray(indices), res.dtype(), true); } inline bool is_no_need_transpose(const std::vector& axis) { @@ -415,8 +540,8 @@ DenseTensor PerformTranspose(const Context& dev_ctx, template DenseTensor PerformContraction( const Context& dev_ctx, - const DenseTensor& A, - const DenseTensor& B, + const std::vector& operands, + const std::vector& input_strs, const std::vector& label2perm, const std::vector& all_labels, const LabelMap& label2type, @@ -467,8 +592,14 @@ DenseTensor PerformContraction( trans_t.ShareBufferWith(*(cache[operand_idx])); VLOG(5) << "Cache Used!"; } else { - auto reduct_t = PerformReduction( - dev_ctx, t, perm, all_labels, ellipsis, label2type); + auto reduct_t = + PerformDiagonalAndReduction(dev_ctx, + t, + input_strs[operand_idx], + perm, + all_labels, + ellipsis, + label2type); trans_t = PerformTranspose( dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type); if (cache[operand_idx] != nullptr) @@ -499,10 +630,19 @@ DenseTensor PerformContraction( }; // Reduction, Reshape and Matmul - auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0); - auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1); - auto after_contraction = - Matmul(dev_ctx, trans_a, trans_b, false, false); + DenseTensor after_contraction; + if (operands.size() == 2) { + auto trans_a = + preprocess(*(operands[0]), label2perm[0], ellipsis_dims[0], 0); + auto trans_b = + preprocess(*(operands[1]), label2perm[1], ellipsis_dims[1], 1); + after_contraction = + Matmul(dev_ctx, trans_a, trans_b, false, false); + } else if (operands.size() == 1) { + after_contraction = + preprocess(*(operands[0]), label2perm[0], ellipsis_dims[0], 0); + } + if (recover_dim.size() == 0) recover_dim.push_back(1); VLOG(5) << "PerformContraction: recover_dim: " << paddle::string::join_strings(recover_dim, ","); after_contraction.Resize(make_ddim(recover_dim)); @@ -510,12 +650,11 @@ DenseTensor PerformContraction( } template -void TransposeToOutput(const Context& dev_ctx, - const DenseTensor& to_trans, - const std::string& right, - const std::vector& all_labels, - int n_broadcast_dims, - DenseTensor* output) { +DenseTensor TransposeToOutput(const Context& dev_ctx, + const DenseTensor& to_trans, + const std::vector& right, + const std::vector& all_labels, + int n_broadcast_dims) { std::vector axis; int offset = 0; if (std::find(all_labels.begin(), all_labels.end(), '.') != @@ -534,12 +673,11 @@ void TransposeToOutput(const Context& dev_ctx, } } if (is_no_need_transpose(axis)) { - output->ShareBufferWith(to_trans); - return; + return to_trans; } VLOG(5) << "call TransposeToOutput: with axis: " << paddle::string::join_strings(axis, ","); - TransposeKernel(dev_ctx, to_trans, axis, output); + return Transpose(dev_ctx, to_trans, axis); } template 
@@ -550,6 +688,7 @@ void EinsumKernelImpl(const Context& dev_ctx, DenseTensor* out, std::vector cache, bool is_forward = true) { + VLOG(5) << "Start EinsumKernelImpl"; ValidationCheck(equation); // collect the following informations to prepare einsum. LabelMap labelshape(0); @@ -564,6 +703,7 @@ void EinsumKernelImpl(const Context& dev_ctx, for (auto& i : inputs) { input_dims.push_back(i->dims()); } + std::vector input_strs; std::string right; if (!is_forward) { all_labels = forward_all_labels; @@ -577,57 +717,32 @@ void EinsumKernelImpl(const Context& dev_ctx, &ellipsis_dims, &broadcast_dims, &output_dims, - &right); - out->Resize(make_ddim(output_dims)); - if (inputs.size() == 2) { - auto& A = inputs[0]; - auto& B = inputs[1]; - // Reduction and Contract Procedure - auto after_contraction = PerformContraction(dev_ctx, - *A, - *B, - label2perms, - all_labels, - labeltype, - labelshape, - ellipsis_dims, - broadcast_dims, - cache, - !is_forward); - TransposeToOutput(dev_ctx, - after_contraction, - right, - all_labels, - broadcast_dims.size(), - out); - // Reshape Procedure - } else if (inputs.size() == 1) { - if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if - // loading the program from v2.3.0 - (*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward - // we can only see cached tensor. - } - auto reduce_A = PerformReduction(dev_ctx, - *inputs[0], - label2perms[0], - all_labels, - ellipsis_dims[0], - labeltype); - std::vector right_labels; - for (auto c : right) right_labels.push_back(c); - right_labels = union_labels(right_labels, all_labels); - *out = PerformTranspose(dev_ctx, - reduce_A, - label2perms[0], - right_labels, - broadcast_dims, - labeltype); - out->Resize(make_ddim(output_dims)); - } else { + &right, + &input_strs); + if (inputs.size() > 2) { PADDLE_THROW(phi::errors::InvalidArgument( "EinsumOp kernel only support len(operands) between (0, 2]. Use " "opt_einsum first to convert multi-variable to binary-variable.")); } + auto after_contraction = PerformContraction(dev_ctx, + inputs, + input_strs, + label2perms, + all_labels, + labeltype, + labelshape, + ellipsis_dims, + broadcast_dims, + cache, + !is_forward); + *out = TransposeToOutput(dev_ctx, + after_contraction, + unique_labels(right), + all_labels, + broadcast_dims.size()); + *out = PerformUndiagonal( + dev_ctx, *out, broadcast_dims.size(), right); + out->Resize(make_ddim(output_dims)); } template diff --git a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py index 9db367a233..bb48cd31dd 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_op.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py @@ -154,5 +154,54 @@ class TestEinsumWithBroadcast6(TestEinsumBinary): self.equation = "i,i->" +class TestEinsumWithDiagonal(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(10, 10)] + self.types = [np.float64] + self.equation = "ii->" + + +class TestEinsumWithDiagonal2(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(10, 3, 10)] + self.types = [np.float64] + self.equation = "iji->j" + + +class TestEinsumWithDiagonal3(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(5, 3, 2, 1, 4, 5)] + self.types = [np.float64] + self.equation = "a...a->..." + + +class TestEinsumWithDiagonal4(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(5, 3, 2, 1, 4, 5)] + self.types = [np.float64] + self.equation = "a...a->a..." 
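For reference, the diagonal-style equations exercised by the tests above reduce to explicit diagonal/trace calls; a short NumPy sketch (illustrative only, not part of the patch):

```python
import numpy as np

# 'a...a->...': a label repeated around an ellipsis is a diagonal over the
# first and last axes; the repeated axis is then summed because it is absent
# from the output labels.
x = np.random.rand(5, 3, 2, 1, 4, 5)
assert np.allclose(np.einsum('a...a->...', x),
                   np.diagonal(x, axis1=0, axis2=5).sum(-1))

# 'iji->j' on a (10, 3, 10) tensor is the trace over axes 0 and 2.
y = np.random.rand(10, 3, 10)
assert np.allclose(np.einsum('iji->j', y), np.trace(y, axis1=0, axis2=2))
```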
+ + +class TestEinsumWithDiagonal5(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(8, 8, 8)] + self.types = [np.float64] + self.equation = "aaa->a" + + +class TestEinsumWithDiagonal6(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(3, 5, 7, 3), (5, 7, 5, 7)] + self.types = [np.float64, np.float64] + self.equation = "ijki,jkjk->ik" + + +class TestEinsumWithDiagonal8(TestEinsumBinary): + def set_mandatory(self): + self.shapes = [(3, 5, 7, 3), (5, 7, 5, 7)] + self.types = [np.float64, np.float64] + self.equation = "ijki,jkjk->" + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_einsum_v2.py b/python/paddle/fluid/tests/unittests/test_einsum_v2.py index e7b041124c..c7d2f9c76b 100644 --- a/python/paddle/fluid/tests/unittests/test_einsum_v2.py +++ b/python/paddle/fluid/tests/unittests/test_einsum_v2.py @@ -41,22 +41,6 @@ class TestErrors(unittest.TestCase): def setUp(self): pass - def test_diagonalize_errors(self): - a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float') - a = paddle.to_tensor(a) - with self.assertRaisesRegex( - AssertionError, ('Duplicate labels are not supported.') - ): - paddle.einsum('...ii->...i', a) - with self.assertRaisesRegex( - AssertionError, ('Duplicate labels are not supported.') - ): - paddle.einsum('i...i', a) - with self.assertRaisesRegex( - AssertionError, ('Duplicate labels are not supported.') - ): - paddle.einsum('i...i->i...', a) - def test_param_errors(self): a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float') a = paddle.to_tensor(a) @@ -126,11 +110,6 @@ class TestErrors(unittest.TestCase): ("Invalid equation: missing ellipsis in output labels."), ): paddle.einsum('i...->i', a) - with self.assertRaisesRegex( - AssertionError, - ("Invalid equation: duplicate output labels are found."), - ): - paddle.einsum('i...->i...i', a) with self.assertRaisesRegex( AssertionError, ( @@ -162,6 +141,13 @@ class TestEinsum(unittest.TestCase): "I": np.random.rand(2, 2), "J": np.random.rand(1, 3, 5), "K": np.random.rand(1, 2, 3, 4), + "X": np.random.rand(5, 5), + "L": np.random.rand(5, 10, 5), + "M": np.random.rand(5, 3, 2, 1, 4, 5), + "N": np.random.rand(5, 5, 5), + "O": np.random.rand(3, 5, 7, 3), + "P": np.random.rand(5, 7, 5, 7), + "S": np.random.rand(4, 3, 4, 4), } def _get_place(self, force_to_use_cpu=False): @@ -207,14 +193,54 @@ class TestEinsum(unittest.TestCase): self.check_output_equal(result.numpy(), expected_result) -class TestEinsumVectorDot(TestEinsum): +class TestEinsumTraceDiag1(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "ii->", "data": ["X"]} + + +class TestEinsumTraceDiag2(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "iji->j", "data": ["L"]} + + +class TestEinsumTraceDiag3(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "a...a->...", "data": ["M"]} + + +class TestEinsumTraceDiag4(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "a...a->a...", "data": ["M"]} + + +class TestEinsumTraceDiag5(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "aaa->a", "data": ["N"]} + + +# Numpy don't support i->ii, but paddle.einsum support. 
+# class TestEinsumTraceDiag6(TestEinsum): +# def setUp(self): +# self.sample = {"paradigm": "i->iii", "data": ["x"]} + +# class TestEinsumTraceDiag7(TestEinsum): +# def setUp(self): +# self.sample = {"paradigm": "i...->i...i", "data": ["S"]} + + +class TestEinsumTraceDiag2Ops(TestEinsum): + def setUp(self): + self.sample = {"paradigm": "ijki,jkjk->ik", "data": ["O", "P"]} + + +class TestEinsumIdentity(TestEinsum): def setUp(self): - self.sample = {"paradigm": "i,i->", "data": ["x", "x"]} + self.sample = {"paradigm": "...->...", "data": ["N"]} -class TestEinsumVectorMul(TestEinsum): +class TestEinsumElementwiseProduct(TestEinsum): def setUp(self): - self.sample = {"paradigm": "i,i->i", "data": ["x", "x"]} + self.sample = {"paradigm": "...,...->...", "data": ["N", "N"]} class TestEinsumVectorOuter(TestEinsum): @@ -436,37 +462,12 @@ class TestNumpyTests(unittest.TestCase): self.check_output("...,...", a, a) self.check_output("i,i", a, a) - # TODO(@xiongkun): explict broadcast in EinsumOp is not supported, it's not recommend to use einsum like this. - # p = np.ones((10, 2)).astype('float') - # q = np.ones((1, 2)).astype('float') - # self.check_output('ij,ij->j', p, q) - - # TODO(@xiongkun): explict-label-broadcast in EinsumOp is not supported, it's not recommend to use einsum like this. - # x = np.array([2., 3.]).astype('float') - # y = np.array([4.]).astype('float') - # self.check_output("i, i", x, y) - - # TODO(@xiongkun): explict-label-broadcast in EinsumOp is not supported, it's not recommend to use einsum like this. - # p = np.ones((1, 5)) / 2 - # q = np.ones((5, 5)) / 2 - # self.check_output("...ij,...jk->...ik", p, p) - # self.check_output("...ij,...jk->...ik", p, q) - x = np.eye(2).astype('float') y = np.ones(2).astype('float') self.check_output("ji,i->", x, y) self.check_output("i,ij->", y, x) self.check_output("ij,i->", x, y) - def test_large_nops(self): - pass - # TODO(@xiongkun): explict broadcast in EinsumOp is not supported, it's not recommend to use einsum like this. - # a = np.arange(4 * 3 * 1 * 4).reshape(4, 3, 1, 4).astype('float') - # self.check_output('a...b,b...c,c...d', a, a, a) - # self.check_output('a...b,b...c,c...a', a, a, a) - # self.check_output('a...b,b...c,c...a', a, a, a) - # self.check_output('...ab,...ba,...ab,...ab', a, a, a, a) - def test_static_graph(self): paddle.enable_static() fluid = paddle.fluid @@ -569,5 +570,32 @@ class TestComplex(unittest.TestCase): c = paddle.einsum('xy,yz->xz', a, b) +class TestSimpleUndiagonal(unittest.TestCase): + """ + EinsumOp support undiagonalize. + """ + + def test_shape(self): + paddle.disable_static() + A = paddle.to_tensor(np.array([1.0, 2.0])) + A_expect = paddle.to_tensor([[1.0, 0.0], [0.0, 2.0]]) + A_actual = paddle.einsum('i->ii', A) + np.array_equal(A_expect.numpy(), A_actual.numpy()) + + +class TestSimpleUndiagonal2(unittest.TestCase): + """ + EinsumOp support undiagonalize. + """ + + def test_shape(self): + paddle.disable_static() + A = paddle.to_tensor(np.array([1.0, 2.0])) + B = paddle.to_tensor(np.array([1.0, 1.0])) + A_expect = paddle.to_tensor([[2.0, 0.0], [0.0, 4.0]]) + A_actual = paddle.einsum('i,j->ii', A, B) + np.array_equal(A_expect.numpy(), A_actual.numpy()) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py index 5c792f8fe0..19a63d515b 100644 --- a/python/paddle/tensor/einsum.py +++ b/python/paddle/tensor/einsum.py @@ -727,14 +727,6 @@ def preprocess(equation, *operands): '...' in lhs and '...' 
not in rhs ), 'Invalid equation: missing ellipsis in output labels.' - assert not ( - len(list(filter(has_duplicated_labels, lhs.split(',')))) > 0 - ), 'Duplicate labels are not supported.' - - assert not has_duplicated_labels( - rhs - ), 'Invalid equation: duplicate output labels are found.' - return lhs, rhs, labels -- GitLab
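With the duplicate-label assertions removed above, equations such as 'i->ii' now reach the kernel's undiagonal path instead of being rejected in Python. The expected values from the TestSimpleUndiagonal cases can be reproduced by hand in NumPy, since np.einsum rejects repeated output labels (illustrative sketch, not part of the patch):

```python
import numpy as np

A = np.array([1.0, 2.0])
B = np.array([1.0, 1.0])

# 'i->ii': scatter the vector onto the diagonal of a zero matrix.
ref1 = np.zeros((2, 2))
np.fill_diagonal(ref1, A)            # [[1, 0], [0, 2]]

# 'i,j->ii': j is contracted away (sum over B), then the result is
# placed on the diagonal ("undiagonalized").
ref2 = np.zeros((2, 2))
np.fill_diagonal(ref2, A * B.sum())  # [[2, 0], [0, 4]]
```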