Unverified commit 32cae24c authored by xiongkun, committed by GitHub

Make einsum_v2 support multi-operands (#42327)

* Extend the Python einsum interface so that einsum_v2 supports multiple operands, and switch it to the default.

* add opt_einsum dependency

* add yaml and support eager mode

* fix issues raised in code review
Parent 21d94dd3
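A minimal sketch of the new behavior (shapes and variable names here are illustrative, not from the patch): with more than two operands, einsum_v2 plans a pairwise contraction order via opt_einsum and evaluates each pair with the two-operand EinsumOp.

import paddle

# Three matrices contracted in one call; internally einsum_v2 plans a
# pairwise contraction order with opt_einsum and evaluates each pair with
# the two-operand EinsumOp.
a = paddle.randn([3, 4])
b = paddle.randn([4, 5])
c = paddle.randn([5, 6])
out = paddle.einsum("ij,jk,kl->il", a, b, c)  # out.shape == [3, 6]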
......@@ -18,7 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -85,7 +85,7 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor,
PD_INFER_META(phi::EinsumInferShape));
PD_INFER_META(phi::EinsumInferMeta));
REGISTER_OPERATOR(einsum, ops::EinsumOp, ops::EinsumOpMaker,
EinsumInferShapeFunctor,
......
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/einsum_impl.h"
namespace phi {
......@@ -398,6 +399,45 @@ void EighInferMeta(const MetaTensor& x,
out_v->set_dims(input_dim);
}
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out) {
// Collect the following information to prepare for einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
std::vector<char> all_labels;
std::vector<int> broadcast_dims;
std::vector<int> output_dims;
std::vector<std::vector<int>> ellipsis_dims(2);
std::vector<DDim> input_dims;
for (auto& i : inputs) {
input_dims.push_back(i->dims());
}
std::string right;
ParseEinsumEquation(equation,
input_dims,
&labelshape,
&labeltype,
&all_labels,
&label2perms,
&ellipsis_dims,
&broadcast_dims,
&output_dims,
&right);
VLOG(3) << "Einsum Infershape: input dims:"
<< paddle::string::join_strings(input_dims, "\n");
VLOG(3) << "Einsum Infershape: equation:" << equation;
VLOG(3) << "Einsum Infershape: all_labels:"
<< paddle::string::join_strings(all_labels, ",");
VLOG(3) << "Einsum Infershape: output dims:"
<< paddle::string::join_strings(output_dims, ",");
VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
}
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out) {
......
......@@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor* out_w,
MetaTensor* out_v);
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out);
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out);
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
......@@ -21,6 +20,7 @@
#include "paddle/utils/string/string_helper.h"
namespace phi {
// Check the validity of the einsum equation.
// 1. every label must be between 'a' and 'z'.
// 2. dimensions bound to the same label must be equal.
......@@ -302,45 +302,6 @@ inline static void ParseEinsumEquation(
}
}
inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out) {
// Collect the following information to prepare for einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
std::vector<char> all_labels;
std::vector<int> broadcast_dims;
std::vector<int> output_dims;
std::vector<std::vector<int>> ellipsis_dims(2);
std::vector<DDim> input_dims;
for (auto& i : inputs) {
input_dims.push_back(i->dims());
}
std::string right;
ParseEinsumEquation(equation,
input_dims,
&labelshape,
&labeltype,
&all_labels,
&label2perms,
&ellipsis_dims,
&broadcast_dims,
&output_dims,
&right);
VLOG(3) << "Einsum Infershape: input dims:"
<< paddle::string::join_strings(input_dims, "\n");
VLOG(3) << "Einsum Infershape: equation:" << equation;
VLOG(3) << "Einsum Infershape: all_labels:"
<< paddle::string::join_strings(all_labels, ",");
VLOG(3) << "Einsum Infershape: output dims:"
<< paddle::string::join_strings(output_dims, ",");
VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
}
template <typename T>
std::vector<T> GetLabelIndexByType(const std::vector<char>& all_labels,
const LabelMap& type,
......@@ -394,6 +355,13 @@ DenseTensor PerformReduction(const Context& dev_ctx,
return Sum<T, Context>(dev_ctx, tensor, indices, tensor.dtype(), true);
}
inline bool is_no_need_transpose(const std::vector<int>& axis) {
for (size_t i = 0; i < axis.size(); ++i) {
if (i != static_cast<size_t>(axis[i])) return false;
}
return true;
}
template <typename T, typename Context>
DenseTensor PerformTranspose(const Context& dev_ctx,
const DenseTensor& tensor,
......@@ -401,12 +369,6 @@ DenseTensor PerformTranspose(const Context& dev_ctx,
const std::vector<char>& all_labels,
const std::vector<int>& ellipsis,
const LabelMap& label2type) {
auto is_no_need_transpose = [](std::vector<int>& axis) {
for (size_t i = 0; i < axis.size(); ++i) {
if (i != size_t(axis[i])) return false;
}
return true;
};
auto axis = GetLabelIndexByType<int>(
all_labels, label2type, label2perm, ellipsis, LabelType::ALL_TYPE);
VLOG(5) << "PerformTranspose: " << paddle::string::join_strings(axis, ",");
......@@ -496,9 +458,9 @@ void TransposeToOutput(const Context& dev_ctx,
axis.push_back(it - all_labels.begin() + offset);
}
}
if (is_no_need_transpose(axis)) return output->ShareBufferWith(to_trans);
VLOG(5) << "call TransposeToOutput: with axis: "
<< paddle::string::join_strings(axis, ",");
if (axis.size() == 0) return output->ShareBufferWith(to_trans);
return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import contextlib
import unittest
import paddle
from paddle.fluid import core
import os
os.environ['FLAGS_new_einsum'] = "1"
def error_trans(func, *args, **kargs):
"""
Translate a C++ exception into a Python exception,
because einsum_v2 raises different exceptions than einsum_v1.
"""
try:
out = func(*args, **kargs)
except ValueError as e:
if "Same label have different shapes" in str(e):
raise AssertionError("Invalid operands: label i "
"corresponds to non-broadcastable dimensions.")
class TestErrors(unittest.TestCase):
def setUp(self):
pass
def test_diagonalize_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
with self.assertRaisesRegex(AssertionError,
('Duplicate labels are not supported.')):
paddle.einsum('...ii->...i', a)
with self.assertRaisesRegex(AssertionError,
('Duplicate labels are not supported.')):
paddle.einsum('i...i', a)
with self.assertRaisesRegex(AssertionError,
('Duplicate labels are not supported.')):
paddle.einsum('i...i->i...', a)
def test_param_errors(self):
a = np.arange(4 * 3 * 4 * 4).reshape(4, 3, 4, 4).astype('float')
a = paddle.to_tensor(a)
with self.assertRaisesRegex(
AssertionError,
("Required at least one operand in Einsum API, but received 0 ")):
paddle.einsum('ijk')
with self.assertRaisesRegex(AssertionError, (
'Invalid equation: multiple `->` were found.')):
paddle.einsum('i -> j -> k', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the number of operands is 2, "
"but found 3 segments in the label equation.")):
paddle.einsum('i,j,k', a, a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the number of operands is 2, "
"but found 1 segments in the label equation.")):
paddle.einsum('ij -> k', a, a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the number of operands is 1, "
"but found 2 segments in the label equation.")):
paddle.einsum('i, -> k', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the label string '' misses dimensions.")):
paddle.einsum('->', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: the label string 'i' misses dimensions.")):
paddle.einsum('i', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: _ is not a valid label, "
"which should be letters.")):
paddle.einsum('i_', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: `.` is found outside of an ellipsis.")):
paddle.einsum('i..j', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: `.` is found outside of an ellipsis.")):
paddle.einsum('...k...', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: missing ellipsis in output labels.")):
paddle.einsum('i...->i', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid equation: duplicate output labels are found.")):
paddle.einsum('i...->i...i', a)
with self.assertRaisesRegex(AssertionError, (
"Invalid operands: label i "
"corresponds to non-broadcastable dimensions.")):
error_trans(paddle.einsum, 'ij...,ji...', a, a)
class TestEinsum(unittest.TestCase):
@classmethod
def setUpClass(cls):
np.random.seed(12345)
cls.TEST_SAMPLES = {
"a": np.random.rand(1, 1),
"b": np.random.rand(1),
"x": np.random.rand(5),
"y": np.random.rand(7),
"A": np.random.rand(4, 5),
"B": np.random.rand(2, 5),
"C": np.random.rand(3, 7),
"D": np.random.rand(3, 4, 5),
"E": np.random.rand(3, 5, 2),
"F": np.random.rand(2, 4, 5, 3),
"G": np.random.rand(4, 2, 5),
"H": np.random.rand(3, 2, 4),
"I": np.random.rand(2, 2),
"J": np.random.rand(1, 3, 5),
"K": np.random.rand(1, 2, 3, 4),
}
def _get_place(self, force_to_use_cpu=False):
if force_to_use_cpu:
return core.CPUPlace()
else:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
return core.CPUPlace()
def check_output_equal(self, actual, expect, rtol=1.e-5, atol=1.e-8):
error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
self.assertTrue(
np.allclose(
actual, expect, rtol=rtol, atol=atol),
error_msg.format(paddle.get_device(), expect, actual,
self.__class__.__name__))
def setUp(self):
self.sample = {"paradigm": "i->", "data": ["x"]}
def test_forward(self):
operands = [
TestEinsum.TEST_SAMPLES[operand] for operand in self.sample["data"]
]
expected_result = np.einsum(self.sample["paradigm"], *operands)
equation = self.sample["paradigm"]
with paddle.fluid.dygraph.guard(
self._get_place(force_to_use_cpu=False)):
pd_operands = [paddle.to_tensor(operand) for operand in operands]
result = paddle.einsum(equation, *pd_operands)
self.check_output_equal(result.numpy(), expected_result)
with paddle.fluid.dygraph.guard(self._get_place(force_to_use_cpu=True)):
pd_operands = [paddle.to_tensor(operand) for operand in operands]
result = paddle.einsum(equation, *pd_operands)
self.check_output_equal(result.numpy(), expected_result)
class TestEinsumVectorDot(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,i->", "data": ["x", "x"]}
class TestEinsumVectorMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,i->i", "data": ["x", "x"]}
class TestEinsumVectorOuter(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i,j->ij", "data": ["x", "y"]}
class TestEinsumMatrixTranspose(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij->ji", "data": ["A"]}
class TestEinsumMatrixRowSum(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij->j", "data": ["A"]}
class TestEinsumMatrixColSum(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij->i", "data": ["A"]}
class TestEinsumMatrixEleMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,ij->ij", "data": ["A", "A"]}
class TestEinsumDegenerateMatrixVecMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,j", "data": ["a", "b"]}
class TestEinsumMatrixVecMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,j->i", "data": ["A", "x"]}
class TestEinsumMatrixMul(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,kj->ik", "data": ["A", "B"]}
class TestEinsumMatrixOuter(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,kl->ijkl", "data": ["A", "C"]}
class TestEinsumTensorBMM(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "bij,bjk->bik", "data": ["D", "E"]}
class TestEinsumTensorContract1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,jk->i", "data": ["D", "A"]}
class TestEinsumTensorContract2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,lk->ijl", "data": ["D", "B"]}
class TestEinsumTensorContract3(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "abcd,dfg->abcfg", "data": ["F", "D"]}
class TestEinsumTensorContract4(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,jk->ik", "data": ["D", "A"]}
class TestEinsumTensorContract5(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk,jk->ij", "data": ["D", "A"]}
class TestEinsumTensorContract6(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ik, ijk->j", "data": ["A", "G"]}
class TestEinsumTensorContract7(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijk, ik->jk", "data": ["G", "A"]}
class TestEinsumEllipsis1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "i...->...", "data": ["G"]}
class TestEinsumEllipsis2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ij,...i->j...", "data": ["A", "H"]}
class TestEinsumEllipsis3(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "k...,jk", "data": ["F", "I"]}
class TestEinsumTestEinsumBilinear(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "bn,anm,bm->ba", "data": ["B", "E", "I"]}
class TestEinsumTestEinsumOthers1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijkl, lmn->kmn", "data": ["F", "H"]}
class TestEinsumTestEinsumOthers2(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "ijkl, lmn->ijn", "data": ["F", "H"]}
class TestEinsumBatch1(TestEinsum):
def setUp(self):
self.sample = {"paradigm": "blq,bhlk->bhlqk", "data": ["J", "K"]}
class TestNumpyTests(unittest.TestCase):
def setUp(self):
pass
def _get_place(self, force_to_use_cpu=False):
if force_to_use_cpu:
return core.CPUPlace()
else:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
return core.CPUPlace()
def check_output_equal(self, actual, expect, rtol=1.e-5, atol=1.e-8):
error_msg = 'Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}'
self.assertTrue(
np.allclose(
actual, expect, rtol=rtol, atol=atol),
error_msg.format(paddle.get_device(), expect, actual,
self.__class__.__name__))
def check_output(self, eqn, *ops):
expect = np.einsum(eqn, *ops)
with paddle.fluid.dygraph.guard(
self._get_place(force_to_use_cpu=False)):
pd_operands = [paddle.to_tensor(op) for op in ops]
actual = paddle.einsum(eqn, *pd_operands)
self.check_output_equal(actual.numpy(), expect)
def test_sums(self):
for n in range(1, 17):
a = np.arange(n).astype('float')
self.check_output("i->", a)
for n in range(1, 17):
a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
self.check_output("...i->...", a)
for n in range(1, 17):
a = np.arange(2 * n).reshape(2, n).astype('float')
self.check_output("i...->...", a)
for n in range(1, 17):
a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
self.check_output("i...->...", a)
for n in range(1, 17):
a = np.arange(3 * n).reshape(3, n).astype('float')
b = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
self.check_output("..., ...", a, b)
for n in range(1, 17):
a = np.arange(2 * 3 * n).reshape(2, 3, n).astype('float')
b = np.arange(n).astype('float')
self.check_output("...i, ...i", a, b)
for n in range(1, 11):
a = np.arange(n * 3 * 2).reshape(n, 3, 2).astype('float')
b = np.arange(n).astype('float')
self.check_output("i..., i...", a, b)
for n in range(1, 17):
a = (np.arange(3) + 1).astype('float')
b = (np.arange(n) + 1).astype('float')
self.check_output("i,j", a, b)
for n in range(1, 17):
a = np.arange(4 * n).reshape(4, n).astype('float')
b = np.arange(n).astype('float')
self.check_output("ij, j", a, b)
for n in range(1, 17):
a = np.arange(4 * n).reshape(4, n).astype('float')
b = np.arange(n).astype('float')
self.check_output("ji,j", a.T, b.T)
for n in range(1, 17):
a = np.arange(4 * n).reshape(4, n).astype('float')
b = np.arange(n * 6).reshape(n, 6).astype('float')
self.check_output("ij,jk", a, b)
a = np.arange(12).reshape(3, 4).astype('float')
b = np.arange(20).reshape(4, 5).astype('float')
c = np.arange(30).reshape(5, 6).astype('float')
self.check_output("ij,jk,kl", a, b, c)
a = np.arange(60).reshape(3, 4, 5).astype('float')
b = np.arange(24).reshape(4, 3, 2).astype('float')
self.check_output("ijk, jil -> kl", a, b)
for n in range(1, 25):
a = np.arange(n).astype('float')
self.check_output("...,...", a, a)
self.check_output("i,i", a, a)
# TODO(@xiongkun): explicit broadcast in EinsumOp is not supported; using einsum like this is not recommended.
#p = np.ones((10, 2)).astype('float')
#q = np.ones((1, 2)).astype('float')
#self.check_output('ij,ij->j', p, q)
# TODO(@xiongkun): explicit label broadcast in EinsumOp is not supported; using einsum like this is not recommended.
#x = np.array([2., 3.]).astype('float')
#y = np.array([4.]).astype('float')
#self.check_output("i, i", x, y)
# TODO(@xiongkun): explicit label broadcast in EinsumOp is not supported; using einsum like this is not recommended.
#p = np.ones((1, 5)) / 2
#q = np.ones((5, 5)) / 2
#self.check_output("...ij,...jk->...ik", p, p)
#self.check_output("...ij,...jk->...ik", p, q)
x = np.eye(2).astype('float')
y = np.ones(2).astype('float')
self.check_output("ji,i->", x, y)
self.check_output("i,ij->", y, x)
self.check_output("ij,i->", x, y)
def test_large_nops(self):
pass
# TODO(@xiongkun): explicit broadcast in EinsumOp is not supported; using einsum like this is not recommended.
#a = np.arange(4 * 3 * 1 * 4).reshape(4, 3, 1, 4).astype('float')
#self.check_output('a...b,b...c,c...d', a, a, a)
#self.check_output('a...b,b...c,c...a', a, a, a)
#self.check_output('a...b,b...c,c...a', a, a, a)
#self.check_output('...ab,...ba,...ab,...ab', a, a, a, a)
def test_static_graph(self):
paddle.enable_static()
fluid = paddle.fluid
if fluid.core.is_compiled_with_cuda():
self.place = fluid.CUDAPlace(0)
else:
self.place = fluid.CPUPlace()
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
a = paddle.static.data(
name='a', shape=[3, None, None, None], dtype='float')
b = paddle.static.data(
name='b', shape=[2, None, None, None], dtype='float')
c = paddle.static.data(
name='c', shape=[None, None, 2, None], dtype='float')
d = paddle.static.data(
name='d', shape=[None, None, 5], dtype='float')
e = paddle.static.data(
name='e', shape=[None, 2, None], dtype='float')
outs = []
outs.append(paddle.einsum("ibnd,jbnd->bnij", a, b))
outs.append(paddle.einsum('...ik, ...j', c, d))
outs.append(paddle.einsum('...kj, ...ik', d, e))
outs.append(paddle.einsum('ijk..., ikj', c, e))
outs.append(paddle.einsum('ijk..., ikj->...ij', c, e))
exe = fluid.Executor(self.place)
exe.run(startup)
a = np.arange(72).reshape(3, 2, 3, 4).astype('float')
b = np.arange(48).reshape(2, 2, 3, 4).astype('float')
c = np.arange(48).reshape(2, 3, 2, 4).astype('float')
d = np.arange(30).reshape(2, 3, 5).astype('float')
e = np.arange(12).reshape(2, 2, 3).astype('float')
feeds = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
actual = exe.run(main, feed=feeds, fetch_list=[outs])
expect = []
expect.append(np.einsum("ibnd,jbnd->bnij", a, b))
expect.append(np.einsum('...ik, ...j', c, d))
expect.append(np.einsum('...kj, ...ik', d, e))
expect.append(np.einsum('ijk..., ikj', c, e))
expect.append(np.einsum('ijk..., ikj->...ij', c, e))
for a, e in zip(actual, expect):
self.check_output_equal(a, e)
if __name__ == "__main__":
unittest.main()
......@@ -24,6 +24,10 @@ from ..fluid.framework import _in_legacy_dygraph
from paddle import _C_ops
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
import collections
import string
import opt_einsum
from paddle.common_ops_import import dygraph_only
......@@ -664,7 +668,138 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
return plan
def preprocess(equation, *operands):
"""
Check the equation, raise errors on invalid input, and generate default right-hand-side labels when they are missing.
"""
equation = equation.replace(" ", "")
nop = len(operands)
assert nop > 0, "Required at least one operand in Einsum API, but received %s " % nop
# Part the equation to left hand side and right hand side
lhs, *rhs = equation.lower().split('->')
assert len(rhs) < 2, "Invalid equation: multiple `->` were found."
labels = parse_labels(lhs, operands)
# Note, we distinguish between 'ij->' and 'ij' by setting rhs to '' and None
rhs = rhs[0] if rhs else None
if rhs is None:
rhs = rhs_inference(lhs)
assert len(lhs.split(',')) == len(operands), (
f"Invalid equation: the number of operands is {len(operands)}, "
f"but found {len(lhs.split(','))} segments in the label equation.")
assert not ('...' in lhs and '...' not in rhs
), f'Invalid equation: missing ellipsis in output labels.'
assert not (len(list(filter(has_duplicated_labels, lhs.split(',')))) > 0
), f'Duplicate labels are not supported.'
assert not has_duplicated_labels(
rhs), f'Invalid equation: duplicate output labels are found.'
return lhs, rhs, labels
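# Illustrative example (not part of the original patch): for two matrices,
# preprocess("ij, jk", a, b) strips spaces, infers the default output labels
# from the free (non-repeated) labels, and returns
# ("ij,jk", "ik", <per-operand label strings from parse_labels>).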
def parse_fake_shape(equation, operands, labels):
"""
This shape is only used for operand planning and may differ from the original shape.
For example:
... is replaced by 1
-1 is replaced by 1
Returns
-------
list of shapes
"""
shaped = collections.namedtuple('shaped', ['shape'])
def fake_shape(label, op):
assert len(op.shape) == len(
label
), "length of shape and length of label must be the same, but received %d != %d" % (
len(op.shape), len(label))
fakes = [s for i, (l, s) in enumerate(zip(label, op.shape)) if l != '.']
fakes = list(map(abs, fakes)) # make -1 -> 1
if '.' in label:
fakes.insert(label.index('.'), 1)
return shaped(fakes)
out = list(map(fake_shape, labels, operands))
return out
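# Illustrative example (not part of the original patch), assuming parse_labels
# marks every ellipsis-covered dimension with '.': an operand of shape
# (2, 3, 4, 5) labeled "i..j" keeps the labeled dims, drops the '.' dims, and
# inserts a single 1 where the ellipsis started, yielding shaped(shape=[2, 1, 5]);
# a -1 (unknown) dim is mapped to 1 by abs().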
def rhs_inference(lhs):
def is_free(key):
return cnt.get(key) == 1 and key not in ['.', ',']
cnt = collections.Counter(lhs)
rhs = "..." if '...' in lhs else ""
rhs = rhs + "".join(filter(is_free, sorted(cnt.elements())))
return rhs
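# Illustrative examples (not part of the original patch):
#   rhs_inference("ij,jk")     -> "ik"    (free labels i and k survive)
#   rhs_inference("i...,...j") -> "...ij" (the ellipsis is kept in front)
#   rhs_inference("ii")        -> ""      (repeated labels are reduced away)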
def gen_equation_for_opteinsum(lhs, rhs):
"""
1. Generate rhs by inference when it is None.
2. Replace '...' with an unused single letter so opt_einsum can parse the equation.
"""
def get_used_label(counter):
used = set(counter.elements())
for c in string.ascii_lowercase:
if c not in used: return c
raise ValueError(
"All labels from `a` to `z` are already in use; no unused label is available for einsum optimization."
)
cnt = collections.Counter(lhs)
broadcast_label = get_used_label(cnt)
if rhs is None:
rhs = rhs_inference(lhs)
lhs = lhs.replace("...", broadcast_label)
rhs = rhs.replace("...", broadcast_label)
return lhs + "->" + rhs, broadcast_label
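# Illustrative example (not part of the original patch): with
# lhs = "i...,...j" and rhs = "...ij", the first unused lowercase letter is
# 'a', so this returns ("ia,aj->aij", "a").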
def einsum_v2(equation, *operands):
"""
einsum v2 implementation.
1. Computation is backed by the C++ EinsumOp.
2. V2 creates EinsumOp to do the computation, so only light verification work is done in Python.
3. V2 uses opt_einsum.contract_path to optimize the multi-operand einsum.
"""
n_op = len(operands)
lhs, rhs, labels = preprocess(equation, *operands)
if n_op <= 2:
return gen_einsum_op(lhs + '->' + rhs, *operands)
shapes = parse_fake_shape(lhs, operands, labels)
opt_equation, broadcast_label = gen_equation_for_opteinsum(lhs, rhs)
_, cons = opt_einsum.contract_path(opt_equation, *shapes, einsum_call=True)
var_list = list(operands)
for path in cons:
(a, b), _, eq, *__ = path
assert a > b, "Assume the first operand index is larger than the second; opt_einsum guarantees this order."
var_s = [var_list.pop(a), var_list.pop(b)]
eq = eq.replace(broadcast_label, "...")
var_list.append(gen_einsum_op(eq, *var_s))
assert len(var_list) == 1, "There must be exactly one element left in var_list, but received %d." % len(var_list)
return var_list[0]
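# Illustrative usage (not part of the original patch): a three-operand
# contraction goes through the opt_einsum planning branch above.
#   a = paddle.randn([3, 4]); b = paddle.randn([4, 5]); c = paddle.randn([5, 6])
#   paddle.einsum("ij,jk,kl->il", a, b, c)  # shape [3, 6]
# Each planned step pops two tensors from var_list (the larger index first so
# the smaller index stays valid) and appends their pairwise einsum result.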
def gen_einsum_op(equation, *operands):
"""
Python interface for EinsumOp.
"""
assert len(operands) <= 2, "Only support two operands in EinsumOp."
if in_dygraph_mode():
return _C_ops.final_state_einsum(operands, equation)
if _in_legacy_dygraph():
# dygraph
return _C_ops.einsum(operands, 'equation', equation)
......
......@@ -561,6 +561,16 @@
func : eigh
backward : eigh_grad
- api : einsum
args : (Tensor[] x, str equation)
output : Tensor
infer_meta :
func : EinsumInferMeta
param : [x, equation]
kernel :
func : einsum
backward : einsum_grad
- api : elementwise_pow
args : (Tensor x, Tensor y)
output : Tensor(out)
......
......@@ -475,6 +475,16 @@
data_transform:
skip_transform : out_w, out_w_grad
- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
kernel :
func : einsum_grad
- backward_api : elementwise_pow_grad
forward : elementwise_pow(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis=-1)
......
......@@ -6,3 +6,4 @@ six
decorator
astor
paddle_bfloat==0.1.2
opt_einsum==3.3.0