diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc deleted file mode 100644 index 0c628a46518b5704d065b598334852ec4251756d..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/mean_op.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class MeanOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class MeanOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor) The input of mean op"); - AddOutput("Out", "(Tensor) The output of mean op"); - AddComment(R"DOC( -Mean Operator calculates the mean of all elements in X. - -)DOC"); - } -}; - -class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map& GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Out"}}; - return m; - } -}; - -class MeanGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", framework::GradVarName("X")); - } - - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - return phi::KernelKey(input_data_type, ctx.GetPlace()); - } -}; - -template -class MeanGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("mean_grad"); - grad_op->SetInput("X", this->Input("X")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(mean, - MeanInferShapeFunctor, - PD_INFER_META(phi::MeanAllInferMeta)); -REGISTER_OPERATOR(mean, - ops::MeanOp, - ops::MeanOpMaker, - ops::MeanOpInferVarType, - ops::MeanGradMaker, - ops::MeanGradMaker, - MeanInferShapeFunctor); - -REGISTER_OPERATOR(mean_grad, - ops::MeanGradOp, - ops::MeanGradNoNeedBufferVarsInferer); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 6ffb733bd7a6abf18c547efda6d4cf77695ad129..0a3384d13fc684030c8c1862cab6aaa248c576cb 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1102,6 +1102,18 @@ kernel : func : maxout_grad +- backward_op : mean_all_grad + forward : mean_all(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedExceptLayoutInferMeta + param: [x] + kernel : + func : mean_all_grad + data_type: out_grad + no_need_buffer : x + - backward_op : memory_efficient_attention_grad forward : memory_efficient_attention (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test) -> Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset) args : (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor output, Tensor logsumexp, Tensor seed_and_offset, Tensor output_grad, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 2d4fcfb83ca1ba1cdfb2b93868a04f33ccf317ce..d9533f898955d80ee451dca9c657c640e66e0428 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -586,16 +586,6 @@ func : maximum_grad composite : maximum_grad(x, y, out_grad, axis, x_grad, y_grad) -- backward_op : mean_all_grad - forward : mean_all(Tensor x) -> Tensor(out) - args : (Tensor x, Tensor out_grad) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param: [x] - kernel : - func : mean_all_grad - - backward_op : mean_double_grad forward: mean_grad (Tensor x, Tensor grad_out, IntArray axis={}, bool keepdim=false, bool reduce_all = false) -> Tensor(grad_x) args : (Tensor grad_x_grad, IntArray axis={}, bool keepdim=false) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 2be1552e4134e6097866dacc55805f48175cfc90..51647c8f096ab3c204a0169e8fd6797061863005 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -816,15 +816,6 @@ func : mean backward : mean_grad -- op : mean_all - args : (Tensor x) - output : Tensor - infer_meta : - func : MeanAllInferMeta - kernel : - func : mean_all - backward : mean_all_grad - - op : merged_adam_ args : (Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1, Scalar beta2, Scalar epsilon, bool multi_precision, bool use_global_beta_pow) output : Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()} diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 2fad03cf671e60be12382d0e105c1bc86d7ffc4d..acf449c1ed388c34546a204296e1cdb52d0c3240 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1441,6 +1441,13 @@ extra : attrs : [bool use_mkldnn = false] +- op : mean_all (mean) + backward : mean_all_grad (mean_grad) + inputs : + x : X + outputs : + out : Out + - op : merge_selected_rows inputs : x : X diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 45b2be32a582c7f8a5ae988bdf66d4d6c3101440..3c344060b9bcbf879c2d84f52c583e38c04128fb 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1227,6 +1227,15 @@ func : maxout backward : maxout_grad +- op : mean_all + args : (Tensor x) + output : Tensor + infer_meta : + func : MeanAllInferMeta + kernel : + func : mean_all + backward : mean_all_grad + - op : memory_efficient_attention args : (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test) output : Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 722a349a53d41e729af7b1e74c88a3a51c138cb5..7d3d89434f3ef533b435d2f932be10fff64cb95a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4436,6 +4436,13 @@ void TriuInferMeta(const MetaTensor& x, int diagonal, MetaTensor* out) { TrilTriuInferMeta(x, diagonal, false, out); } +// Some operator having oneDnn kernel will be set layout in kernel. +void UnchangedExceptLayoutInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(x.dtype()); + out->share_lod(x); +} + void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) { out->share_meta(x); } diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 7cf9056f84d9ce35c799d1bb44c07445b71ec1a1..297e6d5648d0aa323850969ccde7b80162d135d1 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -628,6 +628,8 @@ void UnbindInferMeta(const MetaTensor& x, int axis, std::vector outs); +void UnchangedExceptLayoutInferMeta(const MetaTensor& x, MetaTensor* out); + void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out); // meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1] diff --git a/paddle/phi/ops/compat/mean_sig.cc b/paddle/phi/ops/compat/mean_sig.cc deleted file mode 100644 index 461d6ab32cec4cb3580a37e1b86ef557d31a1b72..0000000000000000000000000000000000000000 --- a/paddle/phi/ops/compat/mean_sig.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature MeanOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("mean_all", {"X"}, {}, {"Out"}); -} - -KernelSignature MeanGradOpGradArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("mean_all_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_BASE_KERNEL_NAME(mean, mean_all); -PD_REGISTER_BASE_KERNEL_NAME(mean_grad, mean_all_grad); - -PD_REGISTER_ARG_MAPPING_FN(mean, phi::MeanOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(mean_grad, phi::MeanGradOpGradArgumentMapping);