Unverified · Commit 6e1c14e3 · authored by xiongkun · committed by GitHub

[Einsum] Einsum support repeated labels. (#47290)

* add unittest for einsum-v2-trace and diagonal

* repeat labels.

* einsum supports repeated labels.

* forward is OK for diagonal and undiagonalized cases.
TODO: check that backward is correct by our theorem.

* backward is ok!

* fix per PR suggestions.

* fix ci error

* fix ci error

* fix ci warning
Parent 266283b2
......@@ -835,6 +835,7 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
for (auto& i : inputs) {
input_dims.push_back(i->dims());
}
std::vector<std::string> input_strs;
std::string right;
ParseEinsumEquation(equation,
input_dims,
......@@ -845,7 +846,8 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
&ellipsis_dims,
&broadcast_dims,
&output_dims,
&right);
&right,
&input_strs);
VLOG(3) << "Einsum Infershape: input dims:"
<< paddle::string::join_strings(input_dims, "\n");
......
......@@ -90,4 +90,7 @@ PD_REGISTER_KERNEL(diagonal_grad,
float,
double,
int,
int64_t) {}
int64_t,
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -88,4 +88,6 @@ PD_REGISTER_KERNEL(diagonal,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......@@ -39,4 +40,17 @@ void DiagonalKernel(const Context& dev_ctx,
int axis1,
int axis2,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Diagonal(const Context& dev_ctx,
const DenseTensor& x,
int offset,
int axis1,
int axis2) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
DiagonalInferMeta(x, offset, axis1, axis2, &meta_out);
DiagonalKernel<T, Context>(dev_ctx, x, offset, axis1, axis2, &dense_out);
return dense_out;
}
} // namespace phi
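For readers of this diff: the new `Diagonal` helper wraps `DiagonalInferMeta` + `DiagonalKernel` into a functional form. Its semantics are assumed here to mirror `np.diagonal`; a minimal NumPy sketch (illustrative only, not the phi implementation):

```python
import numpy as np

# Assumed equivalence for illustration: extract the (axis1, axis2) diagonal
# and append it as the last axis, as np.diagonal does.
x = np.arange(2 * 3 * 2).reshape(2, 3, 2)
d = np.diagonal(x, offset=0, axis1=0, axis2=2)
print(d.shape)  # (3, 2): axes 0 and 2 are replaced by one trailing diag axis
```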
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
namespace phi {
......@@ -27,6 +28,21 @@ void FillDiagonalTensorKernel(const Context& ctx,
int dim2,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor FillDiagonalTensor(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
int64_t offset,
int dim1,
int dim2) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
FillDiagonalTensorInferMeta(x, y, offset, dim1, dim2, &meta_out);
FillDiagonalTensorKernel<T, Context>(
ctx, x, y, offset, dim1, dim2, &dense_out);
return dense_out;
}
void CalMatDims(phi::DDim out_dims,
int dim1,
int dim2,
......
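`FillDiagonalTensor` is the inverse operation: it writes values onto a chosen diagonal. A minimal NumPy sketch of the assumed semantics (offset=0 case only; `fill_diagonal_tensor` below is a hypothetical Python analog, not the phi kernel):

```python
import numpy as np

# Hypothetical Python analog of phi::FillDiagonalTensor (assumed semantics):
# scatter `y` onto the (dim1, dim2) diagonal of a copy of `x`, with the
# diagonal values carried in y's last axis.
def fill_diagonal_tensor(x, y, dim1=0, dim2=1):
    out = x.copy()
    n = min(out.shape[dim1], out.shape[dim2])
    idx = [slice(None)] * out.ndim
    for i in range(n):
        idx[dim1] = idx[dim2] = i
        out[tuple(idx)] = y[..., i]
    return out

# e.g. the 'i->ii' undiagonalization used by Undiagonal in einsum_impl.h:
print(fill_diagonal_tensor(np.zeros((2, 2)), np.array([1.0, 2.0])))
# [[1. 0.]
#  [0. 2.]]
```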
......@@ -166,4 +166,9 @@ PD_REGISTER_KERNEL(diagonal_grad,
float,
double,
int,
int64_t) {}
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -163,4 +163,8 @@ PD_REGISTER_KERNEL(diagonal,
double,
int,
int64_t,
bool) {}
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -109,6 +109,7 @@ PD_REGISTER_KERNEL(fill_diagonal_tensor_grad,
int8_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
......@@ -131,6 +131,7 @@ PD_REGISTER_KERNEL(fill_diagonal_tensor,
int8_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
bool) {}
......@@ -20,15 +20,20 @@
#include "paddle/utils/string/string_helper.h"
namespace phi {
template <typename T, typename Context>
DenseTensor PerformTileAndReduction(const Context& dev_ctx,
const LabelMap& label2type,
const LabelMap& label2shape,
const std::vector<int>& broadcast_dims,
const std::vector<int>& ellipsis_dims,
std::string op_label, // value pass
DenseTensor& t) { // NOLINT
ReplaceEllipsis(op_label);
std::string equ, // passed by value
DenseTensor& t) { // NOLINT
auto tmp_label = equ;
ReplaceEllipsis(tmp_label);
auto tmp_union = unique_labels(tmp_label);
auto op_label = std::string(tmp_union.begin(), tmp_union.end());
VLOG(5) << "Start PerformTileAndReduction" << equ;
DenseTensor ret;
std::vector<int> repeat_times;
std::vector<int> resize_dims;
......@@ -61,6 +66,8 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
})) {
after_tile = t;
} else {
VLOG(4) << "do TileKernel with repeat_times="
<< paddle::string::join_strings(repeat_times, ",");
TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
}
size_t n_ellipsis_idx = op_label.find(".", 0);
......@@ -92,7 +99,11 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx,
VLOG(5) << "PermformTileAndReduction: recover shape: "
<< paddle::string::join_strings(recover_shape, ",");
ret.Resize(make_ddim(recover_shape));
return ret;
// undiagonalize via an einsum equation that contains only undiagonal ops.
DenseTensor out;
VLOG(5) << "Undiagonal by einsum with args: " << op_label + "->" + equ;
EinsumKernel<T, Context>(dev_ctx, {&ret}, op_label + "->" + equ, &out);
return out;
}
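The new tail of `PerformTileAndReduction` encodes the backward rule for repeated labels: the gradient of a diagonal-extraction einsum is the matching undiagonalization einsum applied to the upstream gradient. A numeric sanity check of that identity, as a NumPy sketch (illustrative only, not part of the PR):

```python
import numpy as np

# For y = g * einsum('ii->', x) (a scaled trace), the analytic gradient is g
# scattered onto the diagonal -- exactly the 'i->ii' undiagonal step the
# kernel appends after tile-and-reduce. Verify by finite differences.
x, g, eps = np.random.rand(4, 4), 2.0, 1e-6
grad = g * np.eye(4)

def y(a):
    return g * np.einsum('ii->', a)

for (i, j) in [(0, 0), (0, 1)]:
    dx = np.zeros_like(x); dx[i, j] = eps
    fd = (y(x + dx) - y(x)) / eps          # finite-difference gradient entry
    assert np.isclose(fd, grad[i, j], atol=1e-4)
```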
template <typename T, typename Context>
......@@ -115,6 +126,7 @@ void EinsumGradKernel(const Context& dev_ctx,
for (auto& i : x) {
input_dims.push_back(i->dims());
}
std::vector<std::string> input_strs;
std::string right;
ParseEinsumEquation(equation,
input_dims,
......@@ -125,13 +137,15 @@ void EinsumGradKernel(const Context& dev_ctx,
&ellipsis_dims,
&broadcast_dims,
&output_dims,
&right);
&right,
&input_strs);
auto gather_labels_except_reduction = [&labeltype](std::string all) {
std::string res("");
for (auto c : all)
if (labeltype[static_cast<int>(c)] != LabelType::Reduction) res += c;
return res;
auto tmp_unique = unique_labels(res);
return std::string(tmp_unique.begin(), tmp_unique.end());
};
if (x.size() == 1) { // Unary
auto splits = paddle::string::split_string(equation, "->");
......@@ -141,6 +155,7 @@ void EinsumGradKernel(const Context& dev_ctx,
auto new_operands = std::vector<const DenseTensor*>();
new_operands.push_back(&out_grad);
DenseTensor before_tile;
VLOG(5) << "new_equation is " << new_equation;
EinsumKernel<T, Context>(dev_ctx, new_operands, new_equation, &before_tile);
*(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
labeltype,
......
......@@ -18,6 +18,9 @@
#include "glog/logging.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/diagonal_kernel.h"
#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
......@@ -89,6 +92,9 @@ class LabelMap {
if (label == '.') i = N - 1;
return map[i];
}
bool exist(char label) { return !is_default(label); }
private:
// non-existence is represented by is_default
bool is_default(char label) {
return (*this)[static_cast<int>(label)] == default_value;
......@@ -117,8 +123,9 @@ inline static void ReplaceEllipsis(std::string& s) { // NOLINT
}
}
inline std::vector<char> union_labels(const std::vector<char>& a,
const std::vector<char>& b) {
template <typename CharIterable1, typename CharIterable2>
inline std::vector<char> union_labels(const CharIterable1& a,
const CharIterable2& b) {
LabelMap counter(0);
std::vector<char> res;
auto f = [&](char c) {
......@@ -132,6 +139,11 @@ inline std::vector<char> union_labels(const std::vector<char>& a,
return res;
}
template <typename CharIterable>
inline std::vector<char> unique_labels(const CharIterable& a) {
return union_labels(a, CharIterable());
}
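`unique_labels` is an order-preserving de-duplication, implemented in the header as a union with an empty sequence. A Python analog for illustration:

```python
# Illustrative Python analog of unique_labels above: keep the first
# occurrence of each label, in order.
def unique_labels(s):
    seen, out = set(), []
    for c in s:
        if c not in seen:
            seen.add(c)
            out.append(c)
    return ''.join(out)

assert unique_labels('iji') == 'ij'   # 'iji' diagonalizes to operand 'ij'
```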
// Apply transforms to all_labels and get another all_labels
inline std::vector<char> TransformLabelsOrder(
const std::vector<char>& all_labels,
......@@ -160,9 +172,9 @@ inline static void GlobalInfo(const std::vector<std::string>& op_labels,
}
for (auto& op : op_labels) {
for (auto& ch : op) { // char
for (auto& ch : unique_labels(op)) { // char
int c = ch;
if (counter.is_default(c)) {
if (!counter.exist(c)) {
all.push_back(ch);
}
counter[c] += 1;
......@@ -238,7 +250,7 @@ inline static void InferLabelShape(const std::vector<std::string>& op_labels,
v = op_dim[dim_ptr];
dim_ptr++;
}
} else if (labelshape->is_default(c) || (*labelshape)[c] == -1) {
} else if (!labelshape->exist(c) || (*labelshape)[c] == -1) {
(*labelshape)[c] = op_dim[dim_ptr];
dim_ptr++;
} else if (op_dim[dim_ptr] != -1) {
......@@ -270,12 +282,15 @@ inline static void InferLabelShape(const std::vector<std::string>& op_labels,
<< paddle::string::join_strings(*broadcast_dims, ",");
}
inline static void InferLabelPerm(const std::string& op,
template <class CharIterable>
inline static void InferLabelPerm(const CharIterable& op,
int n_broadcast,
LabelMap* label2perm) {
int cur = 0;
for (int c : op) {
(*label2perm)[c] = cur;
if (!label2perm->exist(
c)) // labels can appear repeatedly; record only the first position.
(*label2perm)[c] = cur;
if (c == '.') {
cur += n_broadcast;
} else {
......@@ -308,15 +323,21 @@ inline static void ParseEinsumEquation(
std::vector<std::vector<int>>* ellipsis_dims,
std::vector<int>* broadcast_dims,
std::vector<int>* output_dims,
std::string* right) {
std::string* right,
std::vector<std::string>* input_strs) {
VLOG(5) << "Start ParseEinsumEquation";
auto results = paddle::string::split_string(equation, "->");
auto left = results[0];
ReplaceEllipsis(left);
*right = results[1].substr(1);
ReplaceEllipsis(*right);
auto op_labels = paddle::string::split_string(left, ",");
// split_string("i,") -> ["i"], we expect 2 op_labels.
if (left[left.size() - 1] == ',') op_labels.push_back("");
// split_string("i,") -> ["i"], we push back a "".
// split_string("->") -> [], we push back a "".
if (op_labels.size() == 0)
op_labels.push_back("");
else if (left[left.size() - 1] == ',')
op_labels.push_back("");
std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis);
GlobalInfo(op_labels, *right, labeltype, all_labels);
InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims);
......@@ -327,8 +348,8 @@ inline static void ParseEinsumEquation(
for (size_t i = 0; i < inputs.size(); ++i) {
InferLabelPerm(
op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i]));
(*input_strs).push_back(std::move(op_labels[i]));
}
VLOG(5) << "Einsum Infershape: end";
}
template <typename T>
......@@ -371,20 +392,124 @@ std::vector<T> GetShapeByType(const std::vector<char>& all_labels,
return res;
}
inline static std::vector<int> perm_moveto(int n, int from, int to) {
// a permutation means moving `from` to `to`.
/*
f => t permutation
--------------------
0 1 2 3 4 5
5 => 2 : 0 1 5 2 3 4
2 => 5 : 0 1 3 4 5 2
the rules below follow from these examples.
*/
if (from < 0) from = n + from;
if (to < 0) to = n + to;
std::vector<int> res(n);
for (int i = 0; i < n; ++i) {
res[i] = i;
}
res[to] = from;
auto offset = from > to ? -1 : 1;
auto start = from > to ? to + 1 : from;
auto end = from > to ? from : to - 1;
for (int i = start; i <= end; ++i) {
res[i] += offset;
}
return res;
}
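A standalone re-implementation of `perm_moveto`, assumed to match the C++ above, plus the two worked examples from the comment (illustrative sketch):

```python
# The returned list is the axis permutation that moves axis `from_` to
# position `to`; negative indices wrap as in the C++ version.
def perm_moveto(n, from_, to):
    from_, to = from_ % n, to % n
    res = list(range(n))
    res.pop(from_)
    res.insert(to, from_)
    return res

assert perm_moveto(6, 5, 2) == [0, 1, 5, 2, 3, 4]
assert perm_moveto(6, 2, 5) == [0, 1, 3, 4, 5, 2]
```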
template <typename T, typename Context>
DenseTensor PerformReduction(const Context& dev_ctx,
const DenseTensor& tensor,
const LabelMap& label2perm,
const std::vector<char>& all_labels,
const std::vector<int>& ellipsis,
const LabelMap& label2type) {
DenseTensor Undiagonal(const Context& dev_ctx,
const DenseTensor& tensor,
size_t insert_pos,
size_t axis) {
// tensor with shape (3, 4, 5, 2, 1), insert_pos = 5, axis = 2.
// output is (3, 4, 5, 2, 1, 5)
VLOG(5) << "Start undiagonal with args: insert_pos = " << insert_pos
<< ", axis = " << axis;
std::vector<int> shape(tensor.dims().size() + 1);
int point = 0; // point to the tensor.dims()
for (size_t i = 0; i < shape.size(); ++i) {
if (i == insert_pos)
shape[i] = tensor.dims()[axis];
else
shape[i] = tensor.dims()[point++];
}
auto zeros = Full<T, Context>(dev_ctx, shape, 0);
auto diags = Transpose<T, Context>(
dev_ctx, tensor, perm_moveto(tensor.dims().size(), axis, -1));
return FillDiagonalTensor<T, Context>(
dev_ctx, zeros, diags, 0, insert_pos, axis + (insert_pos <= axis));
}
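A shape-level NumPy sketch of what `Undiagonal` is assumed to do, matching the comment's example: a new axis of length `dims[axis]` is inserted at `insert_pos`, and the input is scattered along the `(axis, insert_pos)` diagonal of a zero tensor, so extracting that diagonal recovers the input (illustrative only):

```python
import numpy as np

# (3, 4, 5, 2, 1) with insert_pos=5, axis=2 becomes (3, 4, 5, 2, 1, 5).
t = np.random.rand(3, 4, 5, 2, 1)
out = np.zeros(t.shape + (t.shape[2],))
for i in range(t.shape[2]):            # scatter t along the (2, 5) diagonal
    out[:, :, i, :, :, i] = t[:, :, i, :, :]
back = np.diagonal(out, axis1=2, axis2=5)    # shape (3, 4, 2, 1, 5)
assert np.allclose(np.moveaxis(back, -1, 2), t)
```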
template <typename T, typename Context>
DenseTensor PerformUndiagonal(const Context& dev_ctx,
const DenseTensor& tensor,
int n_broadcast,
const std::string& equ) {
// if equ is 'iijjkij', the tensor must be 'ijk', so equ carries enough
// information to undo the diagonalization.
auto res = tensor;
LabelMap label2perm(-1);
InferLabelPerm(equ, n_broadcast, &label2perm);
// Un-Diagonal
int tot =
equ.size() + n_broadcast + (equ.find(".") != std::string::npos ? -1 : 0);
int cur = tot - 1;
for (auto it = equ.rbegin(); it != equ.rend(); ++it) {
char c = *it;
if (c == '.') {
cur -= n_broadcast;
} else {
if (cur != label2perm[c]) {
// do diagonal, followed by movedim().
auto insert_pos = cur - tot + res.dims().size() + 1;
res = Undiagonal<T, Context>(dev_ctx, res, insert_pos, label2perm[c]);
}
--cur;
}
}
return res;
}
template <typename T, typename Context>
DenseTensor PerformDiagonalAndReduction(const Context& dev_ctx,
const DenseTensor& tensor,
const std::string& equ,
const LabelMap& label2perm,
const std::vector<char>& all_labels,
const std::vector<int>& ellipsis,
const LabelMap& label2type) {
auto res = tensor;
// Diagonal
int tot = equ.size() + ellipsis.size() +
(equ.find(".") != std::string::npos ? -1 : 0);
int cur = tot - 1;
for (auto it = equ.rbegin(); it != equ.rend(); ++it) {
char c = *it;
if (c == '.') {
cur -= ellipsis.size();
} else {
if (cur != label2perm[c]) {
// do diagonal, followed by movedim().
VLOG(5) << "Do diagonal with shape="
<< paddle::string::join_strings(vectorize<int>(res.dims()), ',')
<< ", axis1=" << cur << ", axis2=" << label2perm[c];
res = Diagonal<T, Context>(dev_ctx, res, 0, cur, label2perm[c]);
res = Transpose<T, Context>(
dev_ctx, res, perm_moveto(res.dims().size(), -1, label2perm[c]));
}
--cur;
}
}
// reduction
auto indices = GetLabelIndexByType<int64_t>(
all_labels, label2type, label2perm, ellipsis, LabelType::Reduction);
VLOG(5) << "call PerformReduction: with axis: "
VLOG(5) << "call PerformDiagonalAndReduction: with axis: "
<< paddle::string::join_strings(indices, ",");
if (indices.size() == 0) return tensor;
if (indices.size() == 0) return res;
return Sum<T, Context>(
dev_ctx, tensor, phi::IntArray(indices), tensor.dtype(), true);
dev_ctx, res, phi::IntArray(indices), res.dtype(), true);
}
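The forward decomposition implemented here: repeated input labels are handled by diagonal extraction, then reduction labels are summed out. A NumPy check of the same decomposition for `iji->j` (illustrative sketch, not the kernel code):

```python
import numpy as np

# 'iji->j': fold the two 'i' axes into one diagonal axis, then sum it out.
x = np.random.rand(5, 3, 5)
diag = np.diagonal(x, axis1=0, axis2=2)      # shape (3, 5), diag axis last
assert np.allclose(diag.sum(axis=-1), np.einsum('iji->j', x))
```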
inline bool is_no_need_transpose(const std::vector<int>& axis) {
......@@ -415,8 +540,8 @@ DenseTensor PerformTranspose(const Context& dev_ctx,
template <typename T, typename Context>
DenseTensor PerformContraction(
const Context& dev_ctx,
const DenseTensor& A,
const DenseTensor& B,
const std::vector<const DenseTensor*>& operands,
const std::vector<std::string>& input_strs,
const std::vector<LabelMap>& label2perm,
const std::vector<char>& all_labels,
const LabelMap& label2type,
......@@ -467,8 +592,14 @@ DenseTensor PerformContraction(
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
} else {
auto reduct_t = PerformReduction<T, Context>(
dev_ctx, t, perm, all_labels, ellipsis, label2type);
auto reduct_t =
PerformDiagonalAndReduction<T, Context>(dev_ctx,
t,
input_strs[operand_idx],
perm,
all_labels,
ellipsis,
label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (cache[operand_idx] != nullptr)
......@@ -499,10 +630,19 @@ DenseTensor PerformContraction(
};
// Reduction, Reshape and Matmul
auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0], 0);
auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1], 1);
auto after_contraction =
Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, false);
DenseTensor after_contraction;
if (operands.size() == 2) {
auto trans_a =
preprocess(*(operands[0]), label2perm[0], ellipsis_dims[0], 0);
auto trans_b =
preprocess(*(operands[1]), label2perm[1], ellipsis_dims[1], 1);
after_contraction =
Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, false);
} else if (operands.size() == 1) {
after_contraction =
preprocess(*(operands[0]), label2perm[0], ellipsis_dims[0], 0);
}
if (recover_dim.size() == 0) recover_dim.push_back(1);
VLOG(5) << "PerformContraction: recover_dim: "
<< paddle::string::join_strings(recover_dim, ",");
after_contraction.Resize(make_ddim(recover_dim));
......@@ -510,12 +650,11 @@ DenseTensor PerformContraction(
}
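For context on the binary branch: after per-operand diagonal/reduction and transpose, both operands are viewed as `(batch, M, K)` and `(batch, K, N)`, so the whole contraction is one batched matmul, later reshaped back via `recover_dim`. A NumPy sketch of that lowering (illustrative):

```python
import numpy as np

# One batched matmul realizes the contraction 'bij,bjk->bik'.
a, b = np.random.rand(2, 3, 4), np.random.rand(2, 4, 5)
assert np.allclose(a @ b, np.einsum('bij,bjk->bik', a, b))
```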
template <typename T, typename Context>
void TransposeToOutput(const Context& dev_ctx,
const DenseTensor& to_trans,
const std::string& right,
const std::vector<char>& all_labels,
int n_broadcast_dims,
DenseTensor* output) {
DenseTensor TransposeToOutput(const Context& dev_ctx,
const DenseTensor& to_trans,
const std::vector<char>& right,
const std::vector<char>& all_labels,
int n_broadcast_dims) {
std::vector<int> axis;
int offset = 0;
if (std::find(all_labels.begin(), all_labels.end(), '.') !=
......@@ -534,12 +673,11 @@ void TransposeToOutput(const Context& dev_ctx,
}
}
if (is_no_need_transpose(axis)) {
output->ShareBufferWith(to_trans);
return;
return to_trans;
}
VLOG(5) << "call TransposeToOutput: with axis: "
<< paddle::string::join_strings(axis, ",");
TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
return Transpose<T, Context>(dev_ctx, to_trans, axis);
}
template <typename T, typename Context>
......@@ -550,6 +688,7 @@ void EinsumKernelImpl(const Context& dev_ctx,
DenseTensor* out,
std::vector<DenseTensor*> cache,
bool is_forward = true) {
VLOG(5) << "Start EinsumKernelImpl";
ValidationCheck(equation);
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
......@@ -564,6 +703,7 @@ void EinsumKernelImpl(const Context& dev_ctx,
for (auto& i : inputs) {
input_dims.push_back(i->dims());
}
std::vector<std::string> input_strs;
std::string right;
if (!is_forward) {
all_labels = forward_all_labels;
......@@ -577,57 +717,32 @@ void EinsumKernelImpl(const Context& dev_ctx,
&ellipsis_dims,
&broadcast_dims,
&output_dims,
&right);
out->Resize(make_ddim(output_dims));
if (inputs.size() == 2) {
auto& A = inputs[0];
auto& B = inputs[1];
// Reduction and Contract Procedure
auto after_contraction = PerformContraction<T, Context>(dev_ctx,
*A,
*B,
label2perms,
all_labels,
labeltype,
labelshape,
ellipsis_dims,
broadcast_dims,
cache,
!is_forward);
TransposeToOutput<T, Context>(dev_ctx,
after_contraction,
right,
all_labels,
broadcast_dims.size(),
out);
// Reshape Procedure
} else if (inputs.size() == 1) {
if (cache[0] != nullptr) { // For compatibility: cache may be nullptr when
// loading a program from v2.3.0
(*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because in backward
// we can only see the cached tensor.
}
auto reduce_A = PerformReduction<T, Context>(dev_ctx,
*inputs[0],
label2perms[0],
all_labels,
ellipsis_dims[0],
labeltype);
std::vector<char> right_labels;
for (auto c : right) right_labels.push_back(c);
right_labels = union_labels(right_labels, all_labels);
*out = PerformTranspose<T, Context>(dev_ctx,
reduce_A,
label2perms[0],
right_labels,
broadcast_dims,
labeltype);
out->Resize(make_ddim(output_dims));
} else {
&right,
&input_strs);
if (inputs.size() > 2) {
PADDLE_THROW(phi::errors::InvalidArgument(
"EinsumOp kernel only support len(operands) between (0, 2]. Use "
"opt_einsum first to convert multi-variable to binary-variable."));
}
auto after_contraction = PerformContraction<T, Context>(dev_ctx,
inputs,
input_strs,
label2perms,
all_labels,
labeltype,
labelshape,
ellipsis_dims,
broadcast_dims,
cache,
!is_forward);
*out = TransposeToOutput<T, Context>(dev_ctx,
after_contraction,
unique_labels(right),
all_labels,
broadcast_dims.size());
*out = PerformUndiagonal<T, Context>(
dev_ctx, *out, broadcast_dims.size(), right);
out->Resize(make_ddim(output_dims));
}
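The reworked `EinsumKernelImpl` pipeline is: `ParseEinsumEquation` → `PerformDiagonalAndReduction` per operand → `PerformContraction` (one Matmul for two operands) → `TransposeToOutput` → `PerformUndiagonal` for repeated output labels. A NumPy sketch of the forward decomposition for `aaa->a` (illustrative; repeated *output* labels would add the undiagonal step instead):

```python
import numpy as np

# Diagonalize twice; nothing is left to contract or reduce.
x = np.random.rand(8, 8, 8)
step = np.diagonal(x, axis1=0, axis2=2)      # fold the outer 'a' pair, (8, 8)
out = np.diagonal(step, axis1=0, axis2=1)    # fold the remaining pair, (8,)
assert np.allclose(out, np.einsum('aaa->a', x))
```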
template <typename T, typename Context>
......
......@@ -154,5 +154,54 @@ class TestEinsumWithBroadcast6(TestEinsumBinary):
self.equation = "i,i->"
class TestEinsumWithDiagonal(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 10)]
self.types = [np.float64]
self.equation = "ii->"
class TestEinsumWithDiagonal2(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(10, 3, 10)]
self.types = [np.float64]
self.equation = "iji->j"
class TestEinsumWithDiagonal3(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(5, 3, 2, 1, 4, 5)]
self.types = [np.float64]
self.equation = "a...a->..."
class TestEinsumWithDiagonal4(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(5, 3, 2, 1, 4, 5)]
self.types = [np.float64]
self.equation = "a...a->a..."
class TestEinsumWithDiagonal5(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(8, 8, 8)]
self.types = [np.float64]
self.equation = "aaa->a"
class TestEinsumWithDiagonal6(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(3, 5, 7, 3), (5, 7, 5, 7)]
self.types = [np.float64, np.float64]
self.equation = "ijki,jkjk->ik"
class TestEinsumWithDiagonal8(TestEinsumBinary):
def set_mandatory(self):
self.shapes = [(3, 5, 7, 3), (5, 7, 5, 7)]
self.types = [np.float64, np.float64]
self.equation = "ijki,jkjk->"
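A quick usage sketch of the repeated labels these tests exercise (illustrative; assumes a Paddle build that contains this PR), checked against NumPy:

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(4, 4))
np.testing.assert_allclose(
    paddle.einsum('ii->', x).numpy(), np.einsum('ii->', x.numpy()), rtol=1e-6
)
```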
if __name__ == "__main__":
unittest.main()
......@@ -41,22 +41,6 @@ class TestErrors(unittest.TestCase):
def setUp(self):
pass
def test_diagonalize_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
with self.assertRaisesRegex(
AssertionError, ('Duplicate labels are not supported.')
):
paddle.einsum('...ii->...i', a)
with self.assertRaisesRegex(
AssertionError, ('Duplicate labels are not supported.')
):
paddle.einsum('i...i', a)
with self.assertRaisesRegex(
AssertionError, ('Duplicate labels are not supported.')
):
paddle.einsum('i...i->i...', a)
def test_param_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
......@@ -126,11 +110,6 @@ class TestErrors(unittest.TestCase):
("Invalid equation: missing ellipsis in output labels."),
):
paddle.einsum('i...->i', a)
with self.assertRaisesRegex(
AssertionError,
("Invalid equation: duplicate output labels are found."),
):
paddle.einsum('i...->i...i', a)
with self.assertRaisesRegex(
AssertionError,
(
......@@ -162,6 +141,13 @@ class TestEinsum(unittest.TestCase):
"I": np.random.rand(2, 2),
"J": np.random.rand(1, 3, 5),
"K": np.random.rand(1, 2, 3, 4),
"X": np.random.rand(5, 5),
"L": np.random.rand(5, 10, 5),
"M": np.random.rand(5, 3, 2, 1, 4, 5),
"N": np.random.rand(5, 5, 5),
"O": np.random.rand(3, 5, 7, 3),
"P": np.random.rand(5, 7, 5, 7),
"S": np.random.rand(4, 3, 4, 4),
}
def _get_place(self, force_to_use_cpu=False):
......@@ -207,14 +193,54 @@ class TestEinsum(unittest.TestCase):
self.check_output_equal(result.numpy(), expected_result)
class TestEinsumVectorDot(TestEinsum):
class TestEinsumTraceDiag1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ii->", "data": ["X"]}
class TestEinsumTraceDiag2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "iji->j", "data": ["L"]}
class TestEinsumTraceDiag3(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "a...a->...", "data": ["M"]}
class TestEinsumTraceDiag4(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "a...a->a...", "data": ["M"]}
class TestEinsumTraceDiag5(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "aaa->a", "data": ["N"]}
# NumPy doesn't support i->ii, but paddle.einsum does.
# class TestEinsumTraceDiag6(TestEinsum):
# def setUp(self):
# self.sample = {"paradigm": "i->iii", "data": ["x"]}
# class TestEinsumTraceDiag7(TestEinsum):
# def setUp(self):
# self.sample = {"paradigm": "i...->i...i", "data": ["S"]}
class TestEinsumTraceDiag2Ops(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijki,jkjk->ik", "data": ["O", "P"]}
class TestEinsumIdentity(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,i->", "data": ["x", "x"]}
self.sample = {"paradigm": "...->...", "data": ["N"]}
class TestEinsumVectorMul(TestEinsum):
class TestEinsumElementwiseProduct(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,i->i", "data": ["x", "x"]}
self.sample = {"paradigm": "...,...->...", "data": ["N", "N"]}
class TestEinsumVectorOuter(TestEinsum):
......@@ -436,37 +462,12 @@ class TestNumpyTests(unittest.TestCase):
self.check_output("...,...", a, a)
self.check_output("i,i", a, a)
# TODO(@xiongkun): explicit broadcast in EinsumOp is not supported; using einsum like this is not recommended.
# p = np.ones((10, 2)).astype('float')
# q = np.ones((1, 2)).astype('float')
# self.check_output('ij,ij->j', p, q)
# TODO(@xiongkun): explicit-label broadcast in EinsumOp is not supported; using einsum like this is not recommended.
# x = np.array([2., 3.]).astype('float')
# y = np.array([4.]).astype('float')
# self.check_output("i, i", x, y)
# TODO(@xiongkun): explicit-label broadcast in EinsumOp is not supported; using einsum like this is not recommended.
# p = np.ones((1, 5)) / 2
# q = np.ones((5, 5)) / 2
# self.check_output("...ij,...jk->...ik", p, p)
# self.check_output("...ij,...jk->...ik", p, q)
x = np.eye(2).astype('float')
y = np.ones(2).astype('float')
self.check_output("ji,i->", x, y)
self.check_output("i,ij->", y, x)
self.check_output("ij,i->", x, y)
def test_large_nops(self):
pass
# TODO(@xiongkun): explicit broadcast in EinsumOp is not supported; using einsum like this is not recommended.
# a = np.arange(4 * 3 * 1 * 4).reshape(4, 3, 1, 4).astype('float')
# self.check_output('a...b,b...c,c...d', a, a, a)
# self.check_output('a...b,b...c,c...a', a, a, a)
# self.check_output('a...b,b...c,c...a', a, a, a)
# self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
def test_static_graph(self):
paddle.enable_static()
fluid = paddle.fluid
......@@ -569,5 +570,32 @@ class TestComplex(unittest.TestCase):
c = paddle.einsum('xy,yz->xz', a, b)
class TestSimpleUndiagonal(unittest.TestCase):
"""
EinsumOp supports undiagonalization.
"""
def test_shape(self):
paddle.disable_static()
A = paddle.to_tensor(np.array([1.0, 2.0]))
A_expect = paddle.to_tensor([[1.0, 0.0], [0.0, 2.0]])
A_actual = paddle.einsum('i->ii', A)
np.testing.assert_allclose(A_expect.numpy(), A_actual.numpy())
class TestSimpleUndiagonal2(unittest.TestCase):
"""
EinsumOp supports undiagonalization.
"""
def test_shape(self):
paddle.disable_static()
A = paddle.to_tensor(np.array([1.0, 2.0]))
B = paddle.to_tensor(np.array([1.0, 1.0]))
A_expect = paddle.to_tensor([[2.0, 0.0], [0.0, 4.0]])
A_actual = paddle.einsum('i,j->ii', A, B)
np.testing.assert_allclose(A_expect.numpy(), A_actual.numpy())
if __name__ == "__main__":
unittest.main()
......@@ -727,14 +727,6 @@ def preprocess(equation, *operands):
'...' in lhs and '...' not in rhs
), 'Invalid equation: missing ellipsis in output labels.'
assert not (
len(list(filter(has_duplicated_labels, lhs.split(',')))) > 0
), 'Duplicate labels are not supported.'
assert not has_duplicated_labels(
rhs
), 'Invalid equation: duplicate output labels are found.'
return lhs, rhs, labels
......