diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc deleted file mode 100644 index 458fc7afb9de2205cf036e17badec6c465309762..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/einsum_op.cc +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/core/ddim.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { -class EinsumOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class EinsumOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Operands", "(TensorList), The input tensor of einsum op.") - .AsDuplicable(); - AddOutput("Out", "(Tensor), The output tensor of einsum op."); - AddOutput( - "InnerCache", - "(Tensor), The cache of the forward transpose tensors: tA and tB.") - .AsDuplicable() - .AsExtra() - .AsIntermediate(); - - AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.") - .AsDuplicable() - .AsExtra() - .AsIntermediate(); - AddAttr("equation", - "(string) A einsum equation. such as `ij,jk->ik`" - "There must have `->` and the number of operands in " - "equation must equals the `Operands` length."); - AddComment(R"DOC( -Einsum Operator. - -This operator is used to perform einsum operation for given operands and equation. -)DOC"); - } -}; - -class EinsumGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - auto x_name = "Operands"; - auto x_grad_name = framework::GradVarName(x_name); - ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands")); - ctx->ShareAllLoD("Operands", x_grad_name); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto dtype = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return phi::KernelKey(dtype, ctx.GetPlace()); - } -}; - -template -class EinsumGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr retv) const override { - retv->SetType("einsum_grad"); - if (this->HasOutput("InnerCache")) { - retv->SetInput("InnerCache", this->Output("InnerCache")); - } - if (this->HasOutput("XShape")) { - // add if for compatibility. - retv->SetInput("Operands", this->Output("XShape")); // for memory save. - } else { - retv->SetInput("Operands", this->Input("Operands")); - } - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetAttrMap(this->Attrs()); - retv->SetOutput(framework::GradVarName("Operands"), - this->InputGrad("Operands", false)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(einsum, - EinsumInferShapeFunctor, - PD_INFER_META(phi::EinsumRawInferMeta)); - -REGISTER_OPERATOR(einsum, - ops::EinsumOp, - ops::EinsumOpMaker, - EinsumInferShapeFunctor, - ops::EinsumGradMaker, - ops::EinsumGradMaker); - -REGISTER_OPERATOR(einsum_grad, ops::EinsumGradOp); diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 9b564fecba1fbaf805dd7de30a27cacf3e777afb..4b1503d2356d86ee69c52d4a15cda4a6d8bca907 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -844,6 +844,16 @@ attrs : uplo : UPLO +- op : einsum + backward : einsum_grad + inputs : + x : Operands + outputs: + {out : Out, inner_cache: InnerCache, xshape : XShape} + drop_empty_grad: [x_grad] + extra: + outputs: [inner_cache, xshape] + - op : elementwise_pow backward : elementwise_pow_grad extra : diff --git a/paddle/phi/api/yaml/static_backward.yaml b/paddle/phi/api/yaml/static_backward.yaml index aa0efa9323ca73602d95053d849406dd4266744f..1b8c5482248528cb02dde1fc55524f1678f4b963 100755 --- a/paddle/phi/api/yaml/static_backward.yaml +++ b/paddle/phi/api/yaml/static_backward.yaml @@ -74,6 +74,17 @@ data_type : x optional : bias +- backward_op : einsum_grad + forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape) + args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation) + output : Tensor[](x_grad){x_shape.size()} + infer_meta : + func : UnchangedMultiInferMeta + param : [x_shape] + kernel : + func : einsum_grad + data_type : out_grad + - backward_op : embedding_grad forward : embedding (Tensor x, Tensor weight, int64_t padding_idx=-1) -> Tensor(out) args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1) diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index d7542a8bf671ce1b2fc43642a888f51832baa21d..df77a3f3aa9b4a106cf1f7ce450cffdeaf9dbd56 100755 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -156,6 +156,17 @@ optional : bias backward : depthwise_conv2d_transpose_grad +- op : einsum + args : (Tensor[] x, str equation) + output : Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()} + infer_meta : + func : EinsumRawInferMeta + param : [x, equation] + kernel : + func : einsum + backward : einsum_grad + intermediate : inner_cache, xshape + - op : embedding args : (Tensor x, Tensor weight, int64_t padding_idx=-1) output : Tensor diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc deleted file mode 100644 index 3876a9b7c5766e2758df4363e7cadcc5557c728e..0000000000000000000000000000000000000000 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature EinsumOpArgumentMapping( - const ArgumentMappingContext& ctx UNUSED) { - return KernelSignature( - "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"}); -} - -KernelSignature EinsumGradOpArgumentMapping( - const ArgumentMappingContext& ctx UNUSED) { - return KernelSignature("einsum_grad", - {"Operands", "InnerCache", "Out@GRAD"}, - {"equation"}, - {"Operands@GRAD"}); -} -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(einsum, phi::EinsumOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(einsum_grad, phi::EinsumGradOpArgumentMapping);