diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8cd8d94d6b3893698146396651d82509e96b4406
--- /dev/null
+++ b/paddle/fluid/operators/einsum_op.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/impl/einsum_impl.h"
+
+namespace paddle {
+namespace operators {
+class EinsumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+};
+
+class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Operands", "(TensorList), The input tensors of einsum op.")
+        .AsDuplicable();
+    AddOutput("Out", "(Tensor), The output tensor of einsum op.");
+    AddAttr<std::string>(
+        "equation",
+        "(string) An einsum equation, such as `ij,jk->ik`. "
+        "The equation must contain `->`, and the number of operands in the "
+        "equation must equal the length of `Operands`.");
+    AddComment(R"DOC(
+Einsum Operator.
+
+This operator is used to perform einsum operation for given operands and equation.
+)DOC");
+  }
+};
+
+class EinsumGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto x_name = "Operands";
+    auto x_grad_name = framework::GradVarName(x_name);
+    ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
+    ctx->ShareAllLoD(x_name, x_grad_name);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto dtype = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    return framework::OpKernelType(dtype, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class EinsumGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("einsum_grad");
+    retv->SetInput("Operands", this->Input("Operands"));
+    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    retv->SetAttrMap(this->Attrs());
+    retv->SetOutput(framework::GradVarName("Operands"),
+                    this->InputGrad("Operands", false));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor,
+                            PD_INFER_META(phi::EinsumInferShape));
+
+REGISTER_OPERATOR(einsum, ops::EinsumOp, ops::EinsumOpMaker,
+                  EinsumInferShapeFunctor,
+                  ops::EinsumGradMaker<paddle::framework::OpDesc>,
+                  ops::EinsumGradMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(einsum_grad, ops::EinsumGradOp);
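As a quick reference for the `equation` attribute documented above, a minimal numpy sketch of the semantics (numpy is only a semantic oracle here, exactly as the unit tests below use it):

```python
import numpy as np

# "ij,jk->ik": contract the shared label j, keep the free labels i and k
# in the order they appear after "->" -- plain matrix multiplication.
A = np.random.rand(4, 5)
B = np.random.rand(5, 6)
out = np.einsum("ij,jk->ik", A, B)
assert np.allclose(out, A @ B)
```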
+)DOC"); + } +}; + +class EinsumGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + auto x_name = "Operands"; + auto x_grad_name = framework::GradVarName(x_name); + ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name)); + ctx->ShareAllLoD(x_name, x_grad_name); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class EinsumGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("einsum_grad"); + retv->SetInput("Operands", this->Input("Operands")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("Operands"), + this->InputGrad("Operands", false)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor, + PD_INFER_META(phi::EinsumInferShape)); + +REGISTER_OPERATOR(einsum, ops::EinsumOp, ops::EinsumOpMaker, + EinsumInferShapeFunctor, + ops::EinsumGradMaker, + ops::EinsumGradMaker); + +REGISTER_OPERATOR(einsum_grad, ops::EinsumGradOp); diff --git a/paddle/phi/kernels/cpu/einsum_grad_kernel.cc b/paddle/phi/kernels/cpu/einsum_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..2cfc2f92204fc10df5cccf1ab506bf10e8ae801e --- /dev/null +++ b/paddle/phi/kernels/cpu/einsum_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/einsum_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/einsum_grad_impl.h" + +PD_REGISTER_KERNEL( + einsum_grad, CPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e25a65526d8965c0954c2ba13f6a2c28380d3f2 --- /dev/null +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
diff --git a/paddle/phi/kernels/einsum_grad_kernel.h b/paddle/phi/kernels/einsum_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..5c1970e775825492412239f76a7bf00e22a12d5c
--- /dev/null
+++ b/paddle/phi/kernels/einsum_grad_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EinsumGradKernel(const Context& dev_ctx,
+                      const std::vector<const DenseTensor*>& x,
+                      const DenseTensor& out_grad,
+                      const std::string& equation,
+                      std::vector<DenseTensor*> x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..3d9e8feda748de24822ee95304016a5cd8a5676e
--- /dev/null
+++ b/paddle/phi/kernels/einsum_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EinsumKernel(const Context& dev_ctx,
+                  const std::vector<const DenseTensor*>& inputs,
+                  const std::string& equation,
+                  DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/einsum_grad_kernel.cu b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c8a8745f34522ddc902dd8af9bf098714c398a20
--- /dev/null
+++ b/paddle/phi/kernels/gpu/einsum_grad_kernel.cu
@@ -0,0 +1,22 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/einsum_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/einsum_grad_impl.h"
+
+PD_REGISTER_KERNEL(
+    einsum_grad, GPU, ALL_LAYOUT, phi::EinsumGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d73e154eb40f7934f99e2b3a718e0dbb64c8d6c5
--- /dev/null
+++ b/paddle/phi/kernels/gpu/einsum_kernel.cu
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/einsum_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/einsum_impl.h"
+
+PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, phi::EinsumKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd0143379ce1510e9417cba0209ecb3d251b5f71
--- /dev/null
+++ b/paddle/phi/kernels/impl/einsum_grad_impl.h
@@ -0,0 +1,181 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/impl/einsum_impl.h"
+#include "paddle/phi/kernels/tile_kernel.h"
+#include "paddle/utils/string/string_helper.h"
+
+namespace phi {
+template <typename T, typename Context>
+DenseTensor PerformTileAndReduction(const Context& dev_ctx,
+                                    const LabelMap& label2type,
+                                    const LabelMap& label2shape,
+                                    const std::vector<int>& broadcast_dims,
+                                    const std::vector<int>& ellipsis_dims,
+                                    std::string op_label,  // pass by value
+                                    DenseTensor& t) {  // NOLINT
+  ReplaceEllipsis(op_label);
+  DenseTensor ret;
+  std::vector<int> repeat_times;
+  std::vector<int> resize_dims;
+  std::vector<int> recover_shape;
+  for (int c : op_label) {
+    if (label2type[c] == LabelType::Reduction) {
+      // '.' can't be Reduction, so we don't deal with '.' here.
+      repeat_times.push_back(label2shape[c]);
+      resize_dims.push_back(1);
+      recover_shape.push_back(label2shape[c]);
+    } else {
+      if (c != '.') {
+        resize_dims.push_back(label2shape[c]);
+        repeat_times.push_back(1);
+        recover_shape.push_back(label2shape[c]);
+      } else {
+        int n_dims = broadcast_dims.size();
+        resize_dims.insert(
+            resize_dims.end(), broadcast_dims.begin(), broadcast_dims.end());
+        recover_shape.insert(
+            recover_shape.end(), ellipsis_dims.begin(), ellipsis_dims.end());
+        while (n_dims--) repeat_times.push_back(1);
+      }
+    }
+  }
+  t.Resize(make_ddim(resize_dims));
+  DenseTensor after_tile;
+  TileKernel<T, Context>(dev_ctx, t, repeat_times, &after_tile);
+  size_t n_ellipsis_idx = op_label.find(".", 0);
+  if (n_ellipsis_idx != std::string::npos) {
+    // we may need to reduce: broadcast_dims is not necessarily equal to the
+    // ellipsis dims.
+    std::vector<int64_t> to_reduce;
+    for (size_t i = 0; i < broadcast_dims.size() - ellipsis_dims.size(); ++i)
+      to_reduce.push_back(i + n_ellipsis_idx);
+
+    int new_offset =
+        n_ellipsis_idx + broadcast_dims.size() - ellipsis_dims.size();
+    for (size_t i = 0; i < ellipsis_dims.size(); ++i)
+      if (ellipsis_dims[i] == 1) to_reduce.push_back(i + new_offset);
+
+    VLOG(5) << "PerformTileAndReduction: reduce sum axis: "
+            << paddle::string::join_strings(to_reduce, ",");
+    if (to_reduce.size() != 0) {
+      ret = Sum<T, Context>(dev_ctx,
+                            after_tile,
+                            to_reduce,
+                            after_tile.dtype(),
+                            false);  // do not keep dim.
+    } else {
+      ret = after_tile;
+    }
+  } else {
+    ret = after_tile;
+  }
+  VLOG(5) << "PerformTileAndReduction: recover shape: "
+          << paddle::string::join_strings(recover_shape, ",");
+  ret.Resize(make_ddim(recover_shape));
+  return ret;
+}
+
+template <typename T, typename Context>
+void EinsumGradKernel(const Context& dev_ctx,
+                      const std::vector<const DenseTensor*>& x,
+                      const DenseTensor& out_grad,
+                      const std::string& equation,
+                      std::vector<DenseTensor*> x_grad) {
+  VLOG(5) << "Start EinsumGradKernel:";
+  LabelMap labelshape(0);
+  LabelMap labeltype(LabelType::Reduction);
+  std::vector<LabelMap> label2perms(x.size(), LabelMap(-1));
+  std::vector<char> all_labels;  // order: ABO, AO, BO, AB, Reduce
+  std::vector<std::vector<int>> ellipsis_dims(2);
+  std::vector<int> broadcast_dims;
+  std::vector<int> output_dims;
+
+  std::vector<DDim> input_dims;
+  for (auto& i : x) {
+    input_dims.push_back(i->dims());
+  }
+  std::string right;
+  ParseEinsumEquation(equation,
+                      input_dims,
+                      &labelshape,
+                      &labeltype,
+                      &all_labels,
+                      &label2perms,
+                      &ellipsis_dims,
+                      &broadcast_dims,
+                      &output_dims,
+                      &right);
+
+  auto gather_labels_except_reduction = [&labeltype](std::string all) {
+    std::string res("");
+    for (auto c : all)
+      if (labeltype[static_cast<int>(c)] != LabelType::Reduction) res += c;
+    return res;
+  };
+  if (x.size() == 1) {  // Unary
+    auto splits = paddle::string::split_string(equation, "->");
+    auto left = splits[0];
+    right = splits[1].substr(1);
+    auto new_equation = right + "->" + gather_labels_except_reduction(left);
+    auto new_operands = std::vector<const DenseTensor*>();
+    new_operands.push_back(&out_grad);
+    DenseTensor before_tile;
+    EinsumKernel<T, Context>(dev_ctx, new_operands, new_equation, &before_tile);
+    *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                       labeltype,
+                                                       labelshape,
+                                                       broadcast_dims,
+                                                       ellipsis_dims[0],
+                                                       left,
+                                                       before_tile);
+  } else {
+    auto splits = paddle::string::split_string(equation, "->");
+    auto left = splits[0];
+    auto ops = paddle::string::split_string(left, ",");
+    right = splits[1].substr(1);
+
+    auto equation_for_A =
+        right + "," + ops[1] + "->" + gather_labels_except_reduction(ops[0]);
+    auto equation_for_B =
+        right + "," + ops[0] + "->" + gather_labels_except_reduction(ops[1]);
+    auto operands_for_A = std::vector<const DenseTensor*>();
+    auto operands_for_B = std::vector<const DenseTensor*>();
+    DenseTensor dA, dB;
+    operands_for_A.push_back(&out_grad);
+    operands_for_A.push_back(x[1]);
+    operands_for_B.push_back(&out_grad);
+    operands_for_B.push_back(x[0]);
+
+    DenseTensor before_tile;
+    EinsumKernel<T, Context>(dev_ctx, operands_for_A, equation_for_A, &dA);
+    EinsumKernel<T, Context>(dev_ctx, operands_for_B, equation_for_B, &dB);
+    *(x_grad[0]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                       labeltype,
+                                                       labelshape,
+                                                       broadcast_dims,
+                                                       ellipsis_dims[0],
+                                                       ops[0],
+                                                       dA);
+    *(x_grad[1]) = PerformTileAndReduction<T, Context>(dev_ctx,
+                                                       labeltype,
+                                                       labelshape,
+                                                       broadcast_dims,
+                                                       ellipsis_dims[1],
+                                                       ops[1],
+                                                       dB);
+  }
+}
+}  // namespace phi
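The grad kernel above leans on the identity that the gradient of an einsum is itself an einsum with relabelled operands (`equation_for_A`/`equation_for_B`); `PerformTileAndReduction` then re-broadcasts whatever the forward pass summed out. A small numpy sketch of the binary case, without the reduction/ellipsis handling:

```python
import numpy as np

np.random.seed(0)
A, B = np.random.rand(4, 5), np.random.rand(5, 6)
dOut = np.random.rand(4, 6)  # upstream gradient of "ij,jk->ik"

# dA[i,j] = sum_k dOut[i,k] * B[j,k]  <=>  einsum("ik,jk->ij", dOut, B)
dA = np.einsum("ik,jk->ij", dOut, B)
dB = np.einsum("ik,ij->jk", dOut, A)

# Finite-difference check of one element of dA.
eps = 1e-6
A2 = A.copy()
A2[0, 0] += eps
numeric = ((np.einsum("ij,jk->ik", A2, B) -
            np.einsum("ij,jk->ik", A, B)) * dOut).sum() / eps
assert np.isclose(numeric, dA[0, 0], atol=1e-4)
```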
diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..d4be007a07fc0ed88f966ccfbd309bc7687a4ca1
--- /dev/null
+++ b/paddle/phi/kernels/impl/einsum_impl.h
@@ -0,0 +1,586 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+#include "paddle/utils/string/string_helper.h"
+
+namespace phi {
+// Check the validity of the Einsum equation:
+// 1. each label must be between 'a' and 'z'.
+// 2. the dims of the same label must be the same.
+// 3. the broadcast dims of the two operands must be broadcastable.
+// 4. there must exist a `->`; the default output is completed in python.
+// Maybe we can skip the validation check in C++ and just do it in python.
+inline static void ValidationCheck(const std::string& equation) {
+  auto n_part = paddle::string::split_string(equation, "->").size();
+  PADDLE_ENFORCE_EQ(
+      n_part,
+      2,
+      phi::errors::InvalidArgument(
+          "Required exactly one `->` in the equation of EinsumOp."));
+  size_t pos;
+  auto trimmed_equ = equation;
+  if ((pos = trimmed_equ.find("->", 0)) != std::string::npos) {
+    trimmed_equ.replace(pos, 2, ".");
+  }
+  auto is_valid_char = [](char c) {
+    if (c >= 'a' && c <= 'z') return true;
+    if (c == '.' || c == ',') return true;
+    return false;
+  };
+  for (auto c : trimmed_equ) {
+    if (!is_valid_char(c))
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "Found invalid char in equation. Einsum only accepts `a`-`z` and "
+          "`...`, but got: `%c`",
+          c));
+  }
+}
+
+enum LabelType {
+  ALL_TYPE = 0,
+  Batch = 1,    // ABO
+  Free,         // AO, BO
+  Contraction,  // AB
+  Reduction,    // A, B
+};
+
+// map a label ('a' - 'z') -> int, O(1) speed.
+class LabelMap {
+  constexpr static int N = 26 + 1;  // 'a' - 'z' + '.'; '.' is for broadcast dims
+  int default_value;
+  int map[N];
+
+ public:
+  explicit LabelMap(int default_value = 0) {
+    this->default_value = default_value;
+    for (int i = 0; i < N; ++i) map[i] = default_value;
+  }
+  int& operator[](int label) {
+    int i = label - 'a';
+    if (label == '.') i = N - 1;
+    return map[i];
+  }
+  int operator[](int label) const {
+    int i = label - 'a';
+    if (label == '.') i = N - 1;
+    return map[i];
+  }
+  // a non-existent label is represented by the default value.
+  bool is_default(char label) {
+    return (*this)[static_cast<int>(label)] == default_value;
+  }
+};
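`LabelMap` is just a flat 27-slot array indexed by `'a'..'z'` plus `'.'`. For readers skimming the C++, a Python rendering of the same idea:

```python
class LabelMap:
    """27-slot map keyed by 'a'-'z' plus '.', mirroring the C++ class above."""

    def __init__(self, default=0):
        self.default = default
        self.slots = [default] * 27

    def _idx(self, label):
        return 26 if label == '.' else ord(label) - ord('a')

    def __getitem__(self, label):
        return self.slots[self._idx(label)]

    def __setitem__(self, label, value):
        self.slots[self._idx(label)] = value

    def is_default(self, label):
        return self[label] == self.default
```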
+
+inline std::string label_to_string(const std::vector<char>& all_labels,
+                                   const LabelMap& label2type) {
+  std::string str;
+  for (int a : all_labels) {
+    std::stringstream ss;
+    ss << label2type[a];
+    str += ss.str();
+  }
+  return str;
+}
+
+inline static void ReplaceEllipsis(std::string& s) {  // NOLINT
+  size_t pos;
+  if ((pos = s.find("...", 0)) != std::string::npos) {
+    s.replace(pos, 3, ".");
+  }
+  // remove all the spaces in the expression
+  while ((pos = s.find(" ", 0)) != std::string::npos) {
+    s.replace(pos, 1, "");
+  }
+}
+
+inline std::vector<char> union_labels(const std::vector<char>& a,
+                                      const std::vector<char>& b) {
+  LabelMap counter(0);
+  std::vector<char> res;
+  auto f = [&](char c) {
+    if (counter[static_cast<int>(c)] == 0) {
+      res.push_back(c);
+    }
+    counter[static_cast<int>(c)] += 1;
+  };
+  std::for_each(a.begin(), a.end(), f);
+  std::for_each(b.begin(), b.end(), f);
+  return res;
+}
+
+inline static void GlobalInfo(const std::vector<std::string>& op_labels,
+                              const std::string& right,
+                              LabelMap* label2type,
+                              std::vector<char>* sorted_labels) {
+  // sorted_labels: ['.', <right labels>, <left-only labels>]
+  VLOG(5) << "GlobalInfo: "
+          << paddle::string::join_strings(*sorted_labels, ",");
+  std::vector<char> all;
+  LabelMap counter(0);
+  for (auto& ch : right) {  // char
+    int c = ch;
+    (*label2type)[c] = LabelType::Free;
+  }
+
+  for (auto& op : op_labels) {
+    for (auto& ch : op) {  // char
+      int c = ch;
+      if (counter.is_default(c)) {
+        all.push_back(ch);
+      }
+      counter[c] += 1;
+      if ((*label2type)[c] != LabelType::Free && counter[c] == 2)
+        (*label2type)[c] = LabelType::Contraction;
+      else if (counter[c] == 2)
+        (*label2type)[c] = LabelType::Batch;
+    }
+  }
+  (*label2type)['.'] = LabelType::Batch;
+  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
+    if ((*label2type)[c] == LabelType::Batch)
+      sorted_labels->push_back(static_cast<char>(c));
+  });
+  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
+    if ((*label2type)[c] == LabelType::Free)
+      sorted_labels->push_back(static_cast<char>(c));
+  });
+  std::for_each(all.begin(), all.end(), [sorted_labels, label2type](int c) {
+    if ((*label2type)[c] == LabelType::Contraction)
+      sorted_labels->push_back(static_cast<char>(c));
+  });
+  std::for_each(all.begin(), all.end(), [&sorted_labels, label2type](int c) {
+    if ((*label2type)[c] == LabelType::Reduction)
+      sorted_labels->push_back(static_cast<char>(c));
+  });
+  VLOG(5) << "GlobalInfo: sorted_labels before: "
+          << paddle::string::join_strings(*sorted_labels, ",");
+  if (counter[static_cast<int>('.')] > 0) {
+    std::vector<char> tmp;
+    tmp.push_back('.');
+    // push '.' to the front
+    *sorted_labels = union_labels(tmp, *sorted_labels);
+    VLOG(5) << "GlobalInfo: sorted_labels after: "
+            << paddle::string::join_strings(*sorted_labels, ",");
+  }
+}
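`GlobalInfo` classifies every label by where it occurs: in the output and in both operands (Batch), in the output and one operand (Free), in both operands but not the output (Contraction), and in one operand only (Reduction). A standalone sketch of that rule (ellipsis handling omitted):

```python
from collections import Counter

def classify_labels(ops, right):
    """Reimplements GlobalInfo's label classification in Python (sketch)."""
    cnt = Counter(c for op in ops for c in op)
    types = {}
    for c in sorted(cnt):
        if c in right:
            types[c] = 'Batch' if cnt[c] >= 2 else 'Free'
        else:
            types[c] = 'Contraction' if cnt[c] >= 2 else 'Reduction'
    return types

print(classify_labels(["mij", "jk"], "ki"))
# {'i': 'Free', 'j': 'Contraction', 'k': 'Free', 'm': 'Reduction'}
```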
+
+inline static void InferLabelShape(const std::vector<std::string>& op_labels,
+                                   const std::vector<DDim>& inputs,
+                                   LabelMap* labelshape,
+                                   std::vector<std::vector<int>>* ellipsis_dims,
+                                   std::vector<int>* broadcast_dims) {
+  VLOG(5) << "Start InferLabelShape";
+  int n_broadcast_dims = 0;
+  for (size_t i = 0; i < op_labels.size(); ++i) {
+    VLOG(5) << "oplabels: " << op_labels[i];
+    int valid_indices = std::count_if(op_labels[i].begin(),
+                                      op_labels[i].end(),
+                                      [](char c) { return c != '.'; });
+    int n_ellipsis = inputs[i].size() - valid_indices;
+    VLOG(5) << "valid indices and n_ellipsis: " << valid_indices << " "
+            << n_ellipsis;
+    ellipsis_dims->at(i).resize(n_ellipsis);
+    n_broadcast_dims = std::max(n_broadcast_dims, n_ellipsis);
+  }
+  VLOG(5) << "InferLabelShape: Broadcast ndims:" << n_broadcast_dims;
+  *broadcast_dims = std::vector<int>(n_broadcast_dims, 1);
+
+  for (size_t i = 0; i < op_labels.size(); ++i) {
+    auto& op_str = op_labels[i];
+    auto& op_dim = inputs[i];
+    int dim_ptr = 0;
+    for (int c : op_str) {
+      if (c == '.') {
+        for (auto& v : ellipsis_dims->at(i)) {
+          v = op_dim[dim_ptr];
+          dim_ptr++;
+        }
+      } else if (labelshape->is_default(c) || (*labelshape)[c] == -1) {
+        (*labelshape)[c] = op_dim[dim_ptr];
+        dim_ptr++;
+      } else {
+        PADDLE_ENFORCE_EQ(
+            (*labelshape)[c],
+            op_dim[dim_ptr],
+            phi::errors::InvalidArgument(
+                "The same label has different shapes for label: `%c`", c));
+        dim_ptr++;
+      }
+    }
+  }
+  for (size_t i = 0; i < op_labels.size(); ++i) {
+    VLOG(5) << "InferLabelShape: Ellipsis ndims:"
+            << paddle::string::join_strings(ellipsis_dims->at(i), ",");
+    int idx = n_broadcast_dims - ellipsis_dims->at(i).size();
+    for (auto v : ellipsis_dims->at(i)) {
+      PADDLE_ENFORCE_EQ(
+          v == 1 || broadcast_dims->at(idx) == 1 ||
+              broadcast_dims->at(idx) == v,
+          true,
+          phi::errors::InvalidArgument(
+              "Ellipsis dims can't be broadcast. Please check your operands."));
+      broadcast_dims->at(idx) = std::max(v, broadcast_dims->at(idx));
+      idx += 1;
+    }
+  }
+  VLOG(5) << "InferLabelShape: Broadcast dims:"
+          << paddle::string::join_strings(*broadcast_dims, ",");
+}
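`InferLabelShape` right-aligns each operand's ellipsis dims against the common `broadcast_dims`, with numpy-style broadcasting. For one of the broadcast tests below:

```python
import numpy as np

# "...ij,...i->j...": the first operand's ellipsis covers 0 dims, the
# second's covers 3; they are right-aligned and broadcast like numpy.
x = np.random.rand(10, 11)       # ellipsis dims: ()
y = np.random.rand(3, 4, 5, 10)  # ellipsis dims: (3, 4, 5)
out = np.einsum("...ij,...i->j...", x, y)
print(out.shape)  # (11, 3, 4, 5)
```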
+
+inline static void InferLabelPerm(const std::string& op,
+                                  int n_broadcast,
+                                  LabelMap* label2perm) {
+  int cur = 0;
+  for (int c : op) {
+    (*label2perm)[c] = cur;
+    if (c == '.') {
+      cur += n_broadcast;
+    } else {
+      cur += 1;
+    }
+  }
+}
+
+inline static void InferOutputDims(const std::string& right,
+                                   const std::vector<int>& broadcast_dims,
+                                   const LabelMap& labelshape,
+                                   std::vector<int>* output_dims) {
+  for (int c : right) {
+    if (c == '.') {
+      output_dims->insert(
+          output_dims->end(), broadcast_dims.begin(), broadcast_dims.end());
+    } else {
+      output_dims->push_back(labelshape[c]);
+    }
+  }
+}
+
+inline static void ParseEinsumEquation(
+    const std::string& equation,
+    const std::vector<DDim>& inputs,
+    LabelMap* labelshape,
+    LabelMap* labeltype,
+    std::vector<char>* all_labels,
+    std::vector<LabelMap>* label2perms,
+    std::vector<std::vector<int>>* ellipsis_dims,
+    std::vector<int>* broadcast_dims,
+    std::vector<int>* output_dims,
+    std::string* right) {
+  auto results = paddle::string::split_string(equation, "->");
+  auto left = results[0];
+  ReplaceEllipsis(left);
+  *right = results[1].substr(1);
+  ReplaceEllipsis(*right);
+  auto op_labels = paddle::string::split_string(left, ",");
+  std::for_each(op_labels.begin(), op_labels.end(), ReplaceEllipsis);
+  GlobalInfo(op_labels, *right, labeltype, all_labels);
+  InferLabelShape(op_labels, inputs, labelshape, ellipsis_dims, broadcast_dims);
+  VLOG(5) << "Einsum Infershape: right:" << *right;
+  VLOG(5) << "Einsum Infershape: op_labels:"
+          << paddle::string::join_strings(op_labels, "\n");
+  InferOutputDims(*right, *broadcast_dims, *labelshape, output_dims);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    InferLabelPerm(
+        op_labels[i], ellipsis_dims->at(i).size(), &((*label2perms)[i]));
+  }
+}
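A rough Python analogue of what `ParseEinsumEquation` produces for a concrete equation (no ellipsis handling; `parse` is a hypothetical helper for illustration only):

```python
def parse(equation, shapes):
    """Rough Python analogue of ParseEinsumEquation (sketch, no ellipsis)."""
    left, right = equation.split("->")
    op_labels = left.split(",")
    label_shape = {}
    for labels, shape in zip(op_labels, shapes):
        for c, d in zip(labels, shape):
            # same label must carry the same extent everywhere
            assert label_shape.setdefault(c, d) == d
    output_dims = [label_shape[c] for c in right]
    return op_labels, label_shape, output_dims

print(parse("mij,jk->ki", [(10, 3, 5), (5, 30)]))
# (['mij', 'jk'], {'m': 10, 'i': 3, 'j': 5, 'k': 30}, [30, 3])
```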
+
+inline void EinsumInferShape(const std::vector<const MetaTensor*>& inputs,
+                             const std::string& equation,
+                             MetaTensor* out) {
+  // collect the following information to prepare einsum.
+  LabelMap labelshape(0);
+  LabelMap labeltype(LabelType::Reduction);
+  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
+  std::vector<char> all_labels;
+  std::vector<int> broadcast_dims;
+  std::vector<int> output_dims;
+  std::vector<std::vector<int>> ellipsis_dims(2);
+
+  std::vector<DDim> input_dims;
+  for (auto& i : inputs) {
+    input_dims.push_back(i->dims());
+  }
+  std::string right;
+  ParseEinsumEquation(equation,
+                      input_dims,
+                      &labelshape,
+                      &labeltype,
+                      &all_labels,
+                      &label2perms,
+                      &ellipsis_dims,
+                      &broadcast_dims,
+                      &output_dims,
+                      &right);
+
+  VLOG(3) << "Einsum Infershape: input dims:"
+          << paddle::string::join_strings(input_dims, "\n");
+  VLOG(3) << "Einsum Infershape: equation:" << equation;
+  VLOG(3) << "Einsum Infershape: all_labels:"
+          << paddle::string::join_strings(all_labels, ",");
+  VLOG(3) << "Einsum Infershape: output dims:"
+          << paddle::string::join_strings(output_dims, ",");
+  VLOG(3) << "Label Type is : " << label_to_string(all_labels, labeltype);
+  VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
+}
+
+template <typename T>
+std::vector<T> GetLabelIndexByType(const std::vector<char>& all_labels,
+                                   const LabelMap& type,
+                                   const LabelMap& perm,
+                                   const std::vector<int>& ellipsis,
+                                   LabelType filter) {
+  std::vector<T> res;
+  for (T c : all_labels) {
+    if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) {
+      if (c == '.') {
+        for (size_t i = 0; i < ellipsis.size(); ++i) res.push_back(perm[c] + i);
+      } else {
+        res.push_back(perm[c]);
+      }
+    }
+  }
+  return res;
+}
+
+template <typename T>
+std::vector<T> GetShapeByType(const std::vector<char>& all_labels,
+                              const LabelMap& type,
+                              const LabelMap& perm,
+                              const LabelMap& label2shape,
+                              const std::vector<int>& ellipsis,
+                              LabelType filter) {
+  std::vector<T> res;
+  for (T c : all_labels) {
+    if ((filter == LabelType::ALL_TYPE || type[c] == filter) && perm[c] != -1) {
+      if (c == '.')
+        res.insert(res.end(), ellipsis.begin(), ellipsis.end());
+      else
+        res.push_back(label2shape[c]);
+    }
+  }
+  return res;
+}
+
+template <typename T, typename Context>
+DenseTensor PerformReduction(const Context& dev_ctx,
+                             const DenseTensor& tensor,
+                             const LabelMap& label2perm,
+                             const std::vector<char>& all_labels,
+                             const std::vector<int>& ellipsis,
+                             const LabelMap& label2type) {
+  auto indices = GetLabelIndexByType<int64_t>(
+      all_labels, label2type, label2perm, ellipsis, LabelType::Reduction);
+  VLOG(5) << "call PerformReduction: with axis: "
+          << paddle::string::join_strings(indices, ",");
+  if (indices.size() == 0) return tensor;
+  return Sum<T, Context>(dev_ctx, tensor, indices, tensor.dtype(), true);
+}
+
+template <typename T, typename Context>
+DenseTensor PerformTranspose(const Context& dev_ctx,
+                             const DenseTensor& tensor,
+                             const LabelMap& label2perm,
+                             const std::vector<char>& all_labels,
+                             const std::vector<int>& ellipsis,
+                             const LabelMap& label2type) {
+  auto is_no_need_transpose = [](std::vector<int>& axis) {
+    for (size_t i = 0; i < axis.size(); ++i) {
+      if (i != size_t(axis[i])) return false;
+    }
+    return true;
+  };
+  auto axis = GetLabelIndexByType<int>(
+      all_labels, label2type, label2perm, ellipsis, LabelType::ALL_TYPE);
+  VLOG(5) << "PerformTranspose: " << paddle::string::join_strings(axis, ",");
+  if (is_no_need_transpose(axis)) {
+    return tensor;
+  }
+  auto ret = Transpose<T, Context>(dev_ctx, tensor, axis);
+  VLOG(5) << "PerformTranspose: do_transpose()";
+  return ret;
+}
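`PerformReduction` exploits the fact that a Reduction label (present in only one operand and not in the output) can be summed out before the contraction without changing the result; in numpy terms:

```python
import numpy as np

# In "ijk,kl->jl", label i is Reduction for the first operand: it can be
# summed out up front, shrinking the tensor before the matmul-style step.
x = np.random.rand(10, 3, 5)
y = np.random.rand(5, 30)
ref = np.einsum("ijk,kl->jl", x, y)
two_step = np.einsum("jk,kl->jl", x.sum(axis=0), y)
assert np.allclose(ref, two_step)
```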
+
+template <typename T, typename Context>
+DenseTensor PerformContraction(
+    const Context& dev_ctx,
+    const DenseTensor& A,
+    const DenseTensor& B,
+    const std::vector<LabelMap>& label2perm,
+    const std::vector<char>& all_labels,
+    const LabelMap& label2type,
+    const LabelMap& label2shape,
+    const std::vector<std::vector<int>>& ellipsis_dims,
+    const std::vector<int>& broadcast_dims) {
+  // Get all the batch dims first, so the perm is [batch, free, contraction].
+  auto all_valid = LabelMap(1);
+  auto recover_dim = GetShapeByType<int>(all_labels,
+                                         label2type,
+                                         all_valid,
+                                         label2shape,
+                                         broadcast_dims,
+                                         LabelType::Batch);
+  auto preprocess = [&](const DenseTensor& t,
+                        const LabelMap& perm,
+                        const std::vector<int>& ellipsis) -> DenseTensor {
+    auto frees = GetShapeByType<int>(
+        all_labels, label2type, perm, label2shape, ellipsis, LabelType::Free);
+    auto conts = GetShapeByType<int>(all_labels,
+                                     label2type,
+                                     perm,
+                                     label2shape,
+                                     ellipsis,
+                                     LabelType::Contraction);
+    auto trans_t = PerformTranspose<T, Context>(
+        dev_ctx, t, perm, all_labels, ellipsis, label2type);
+    auto mul_dims = GetShapeByType<int>(
+        all_labels, label2type, perm, label2shape, ellipsis, LabelType::Batch);
+    recover_dim.insert(recover_dim.end(), frees.begin(), frees.end());
+    mul_dims.push_back(std::accumulate(
+        frees.begin(), frees.end(), 1, std::multiplies<int>()));
+    mul_dims.push_back(std::accumulate(
+        conts.begin(), conts.end(), 1, std::multiplies<int>()));
+    VLOG(5) << "PerformContraction: mul_dims: "
+            << paddle::string::join_strings(mul_dims, ",");
+    trans_t.Resize(make_ddim(mul_dims));
+    return trans_t;
+  };
+  auto trans_a = preprocess(A, label2perm[0], ellipsis_dims[0]);
+  auto trans_b = preprocess(B, label2perm[1], ellipsis_dims[1]);
+  auto after_contraction =
+      Matmul<T, Context>(dev_ctx, trans_a, trans_b, false, true);
+  VLOG(5) << "PerformContraction: recover_dim: "
+          << paddle::string::join_strings(recover_dim, ",");
+  after_contraction.Resize(make_ddim(recover_dim));
+  return after_contraction;
+}
+
+template <typename T, typename Context>
+void TransposeToOutput(const Context& dev_ctx,
+                       const DenseTensor& to_trans,
+                       const std::string& right,
+                       const std::vector<char>& all_labels,
+                       int n_broadcast_dims,
+                       DenseTensor* output) {
+  std::vector<int> axis;
+  int offset = 0;
+  if (std::find(all_labels.begin(), all_labels.end(), '.') !=
+      all_labels.end()) {
+    offset = n_broadcast_dims - 1;
+  }
+  for (char c : right) {
+    if (c == '.') {
+      for (int i = 0; i < n_broadcast_dims; ++i) axis.push_back(i);
+    } else {
+      auto it = std::find(all_labels.begin(), all_labels.end(), c);
+      PADDLE_ENFORCE_NE(
+          it,
+          all_labels.end(),
+          phi::errors::InvalidArgument("Must be in all_labels."));
+      axis.push_back(it - all_labels.begin() + offset);
+    }
+  }
+  VLOG(5) << "call TransposeToOutput: with axis: "
+          << paddle::string::join_strings(axis, ",");
+  if (axis.size() == 0) return output->ShareBufferWith(to_trans);
+  return TransposeKernel<T, Context>(dev_ctx, to_trans, axis, output);
+}
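`PerformContraction` lowers every binary einsum onto one batched matmul: transpose each operand to `[batch..., free..., contraction...]`, flatten the free and contraction groups, multiply with `trans_b = true`, and reshape back via `recover_dim`. The same trick in numpy:

```python
import numpy as np

# "mabk,mck->mabc": batch m, frees {a, b} and {c}, contraction k.
x = np.random.rand(20, 2, 3, 4)
y = np.random.rand(20, 5, 4)
ref = np.einsum("mabk,mck->mabc", x, y)

X = x.reshape(20, 2 * 3, 4)  # [batch, flattened frees, contraction]
Y = y.reshape(20, 5, 4)      # [batch, frees, contraction]
out = np.matmul(X, Y.transpose(0, 2, 1))  # A @ B^T, i.e. trans_b = true
assert np.allclose(ref, out.reshape(20, 2, 3, 5))
```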
+
+template <typename T, typename Context>
+void EinsumKernel(const Context& dev_ctx,
+                  const std::vector<const DenseTensor*>& inputs,
+                  const std::string& equation,
+                  DenseTensor* out) {
+  ValidationCheck(equation);
+  // collect the following information to prepare einsum.
+  LabelMap labelshape(0);
+  LabelMap labeltype(LabelType::Reduction);
+  std::vector<LabelMap> label2perms(inputs.size(), LabelMap(-1));
+  std::vector<char> all_labels;  // order: ABO, AO, BO, AB, Reduce
+  std::vector<std::vector<int>> ellipsis_dims(2);
+  std::vector<int> broadcast_dims;
+  std::vector<int> output_dims;
+
+  std::vector<DDim> input_dims;
+  for (auto& i : inputs) {
+    input_dims.push_back(i->dims());
+  }
+  std::string right;
+  ParseEinsumEquation(equation,
+                      input_dims,
+                      &labelshape,
+                      &labeltype,
+                      &all_labels,
+                      &label2perms,
+                      &ellipsis_dims,
+                      &broadcast_dims,
+                      &output_dims,
+                      &right);
+  out->Resize(make_ddim(output_dims));
+  if (inputs.size() == 2) {
+    auto& A = inputs[0];
+    auto& B = inputs[1];
+    // Reduce Procedure
+    auto reduce_A = PerformReduction<T, Context>(
+        dev_ctx, *A, label2perms[0], all_labels, ellipsis_dims[0], labeltype);
+    auto reduce_B = PerformReduction<T, Context>(
+        dev_ctx, *B, label2perms[1], all_labels, ellipsis_dims[1], labeltype);
+    // Contract Procedure
+    dev_ctx.template Alloc<T>(out);
+    auto after_contraction = PerformContraction<T, Context>(dev_ctx,
+                                                            reduce_A,
+                                                            reduce_B,
+                                                            label2perms,
+                                                            all_labels,
+                                                            labeltype,
+                                                            labelshape,
+                                                            ellipsis_dims,
+                                                            broadcast_dims);
+    TransposeToOutput<T, Context>(dev_ctx,
+                                  after_contraction,
+                                  right,
+                                  all_labels,
+                                  broadcast_dims.size(),
+                                  out);
+    // Reshape Procedure
+  } else if (inputs.size() == 1) {
+    auto reduce_A = PerformReduction<T, Context>(dev_ctx,
+                                                 *inputs[0],
+                                                 label2perms[0],
+                                                 all_labels,
+                                                 ellipsis_dims[0],
+                                                 labeltype);
+    std::vector<char> right_labels;
+    for (auto c : right) right_labels.push_back(c);
+    right_labels = union_labels(right_labels, all_labels);
+    *out = PerformTranspose<T, Context>(dev_ctx,
+                                        reduce_A,
+                                        label2perms[0],
+                                        right_labels,
+                                        broadcast_dims,
+                                        labeltype);
+    out->Resize(make_ddim(output_dims));
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "EinsumOp kernel only supports len(operands) between (0, 2]. Use "
+        "opt_einsum first to convert a multi-variable einsum into binary "
+        "ones."));
+  }
+}
+
+}  // namespace phi
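The kernel intentionally stops at two operands; as its error message says, an n-ary einsum is expected to be folded into a chain of binary contractions first (the approach taken by opt_einsum). For example:

```python
import numpy as np

x, y, z = np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)
ref = np.einsum("ij,jk,kl->il", x, y, z)

# Equivalent chain of binary einsums, each within the kernel's reach.
step = np.einsum("ij,jk->ik", x, y)
assert np.allclose(ref, np.einsum("ik,kl->il", step, z))
```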
diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0b3cc3425df453799f18c13595eff188d69480da
--- /dev/null
+++ b/paddle/phi/ops/compat/einsum_sig.cc
@@ -0,0 +1,32 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"});
+}
+
+KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("einsum_grad",
+                         {"Operands", "Out@GRAD"},
+                         {"equation"},
+                         {"Operands@GRAD"});
+}
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(einsum, phi::EinsumOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(einsum_grad, phi::EinsumGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 32d8f5e3847c81bed286c7a0bda07c764312d5f7..c6111391b73b55950d7a8babd5b1f55ed58a0fc9 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1065,6 +1065,7 @@ set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_svd_op PROPERTIES TIMEOUT 80)
+set_tests_properties(test_einsum_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_qr_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120)
 set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/test_einsum_op.py b/python/paddle/fluid/tests/unittests/test_einsum_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..565e43214ea3237a592ba4afa9b793091344fade
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_einsum_op.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+from op_test import OpTest
+
+
+class TestEinsumBinary(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "einsum"
+        self.disable = False
+        self.set_mandatory()
+        self.init_input()
+        np.random.seed(123)
+        out = np.einsum(self.equation, *self.inputs)
+        self.operands = []
+        for idx, inp in enumerate(self.inputs):
+            self.operands.append(("x" + str(idx), inp))
+        self.inputs = {"Operands": self.operands}
+        self.attrs = {"equation": self.equation}
+        self.outputs = {'Out': out}
+
+    def init_input(self):
+        self.inputs = []
+        for t, s in zip(self.types, self.shapes):
+            self.inputs.append(np.random.random(s).astype(t))
+
+    def set_mandatory(self):
+        self.disable = False
+        self.shapes = [(10, 10, 20), (20, 6)]
+        self.types = [np.float64, np.float64]
+        self.equation = "mij,jk->ki"
+
+    def test_check_output(self):
+        if not self.disable:
+            self.check_output()
+
+    def test_grad(self):
+        if not self.disable:
+            self.check_grad([op[0] for op in self.operands], ["Out"])
+
+
+class TestEinsum1(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(20, 3, 3), (20, 3, 3)]
+        self.types = [np.float64, np.float64]
+        self.equation = "mij,mjk->mik"
+
+
+class TestEinsum2(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(20, 3, 3), (20, 3, 3)]
+        self.types = [np.float64, np.float64]
+        self.equation = "mij,mjk->ikm"
+
+
+class TestEinsum3(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 10), (10, 10)]
+        self.types = [np.float64, np.float64]
+        self.equation = "ij,jk->ik"
+
+
+class TestEinsumWithReduction(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 3, 5), (5, 30)]
+        self.types = [np.float64, np.float64]
+        self.equation = "ijk,kl->jl"
+
+
+class TestEinsumWithReduction1(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 3, 3, 5), (10, 5, 10, 10)]
+        self.types = [np.float64, np.float64]
+        self.equation = "mijk,mklh->ljm"
+
+
+class TestEinsumWithUnary(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 10, 3, 5)]
+        self.types = [np.float64]
+        self.equation = "mijk->mi"
+
+
+class TestEinsumWithUnary1(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(5, 10, 3, 3), (3, 6, 3, 10)]
+        self.types = [np.float64, np.float64]
+        self.equation = "imjl,jklm->imk"
+
+
+class TestEinsumWithBroadcast1(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(5, 10, 3, 3)]
+        self.types = [np.float64]
+        self.equation = "i...->..."
+
+
+class TestEinsumWithBroadcast2(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 11), (3, 4, 5, 10)]
+        self.types = [np.float64, np.float64]
+        self.equation = "...ij,...i->j..."
+
+
+class TestEinsumWithBroadcast3(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
+        self.types = [np.float64, np.float64]
+        self.equation = "k...,...jk->...k"
+
+
+class TestEinsumWithBroadcast4(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
+        self.types = [np.float64, np.float64]
+        self.equation = "a...d,...cb->...abcd"
+
+
+class TestEinsumWithBroadcast5(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(3, 2, 2, 10), (10, 3, 2, 2)]
+        self.types = [np.float64, np.float64]
+        self.equation = "...a,a...->..."
+
+
+class TestEinsumWithBroadcast6(TestEinsumBinary):
+    def set_mandatory(self):
+        self.shapes = [(100), (100)]
+        self.types = [np.float64, np.float64]
+        self.equation = "i,i->"
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/einsum.py b/python/paddle/tensor/einsum.py
index 06c2a82fd696d8f84f6ce1b38fad95834d10d3a0..dd11477532d24daec0991cff1f4a99dd0d41641a 100644
--- a/python/paddle/tensor/einsum.py
+++ b/python/paddle/tensor/einsum.py
@@ -20,6 +20,10 @@ from .linalg import dot, matmul, transpose
 from .manipulation import squeeze, unsqueeze, reshape
 from .math import multiply
 from .math import sum as paddle_sum
+from ..fluid.framework import _in_legacy_dygraph
+from paddle import _C_ops
+from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
+from ..fluid.layer_helper import LayerHelper
 from paddle.common_ops_import import dygraph_only
 
@@ -660,6 +664,26 @@ def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
     return plan
 
 
+def einsum_v2(equation, *operands):
+    if _in_legacy_dygraph():
+        # dygraph mode
+        return _C_ops.einsum(operands, 'equation', equation)
+    # static graph mode
+    for inp in operands:
+        check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
+    check_type(equation, 'equation', str, 'einsum')
+    helper = LayerHelper('einsum', **locals())
+    out = helper.create_variable_for_type_inference(dtype=operands[0].dtype)
+    attrs = dict()
+    attrs['equation'] = equation
+    helper.append_op(
+        type='einsum',
+        inputs={'Operands': operands},
+        outputs={'Out': out},
+        attrs=attrs, )
+    return out
+
+
 def einsum(equation, *operands):
     r"""
     einsum(equation, *operands)
@@ -817,6 +841,9 @@ def einsum(equation, *operands):
     #         [0.50226176, 0.24512935, 0.39881429],
     #         [0.51476848, 0.23367381, 0.39229113]]])
     """
+    import os
+    if int(os.environ.get('FLAGS_new_einsum', "0")):
+        return einsum_v2(equation, *operands)
     nop = len(operands)
     assert nop > 0, "At least one operand is expected."
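With this patch applied, the new C++ kernel is opt-in via the `FLAGS_new_einsum` environment variable. A minimal sketch of exercising it, assuming a dygraph session (only the `_in_legacy_dygraph` path is wired to `_C_ops` here):

```python
import os
os.environ['FLAGS_new_einsum'] = '1'  # read at call time by paddle.einsum

import paddle

x = paddle.rand([4, 5])
y = paddle.rand([5, 6])
out = paddle.einsum('ij,jk->ik', x, y)  # dispatches to the C++ einsum op
print(out.shape)  # [4, 6]
```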