From 3e5744aa65ccda4e6a247c9627db1b1967f314dc Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 16 Jan 2020 19:42:43 +0800 Subject: [PATCH] Remove unused inputs for some operators (#22284) * remove unused inputs, test=develop * remove unused inputs, test=develop * update dtype, test=develop * remove unused inputs, test=develop * update op_use_default_grad_op_maker, tese=develop * resolve conflicts, test=develop * follow comments, test=develop * update center_loss_grad, test=develop --- paddle/fluid/framework/unused_var_check.cc | 62 +++++++++---------- paddle/fluid/operators/affine_grid_op.cc | 8 +-- paddle/fluid/operators/center_loss_op.cc | 6 +- paddle/fluid/operators/cos_sim_op.cc | 30 +++++++-- paddle/fluid/operators/crop_op.cc | 5 +- paddle/fluid/operators/cvm_op.cc | 17 +++-- paddle/fluid/operators/gru_op.cc | 56 +++++++++++++++-- paddle/fluid/operators/gru_unit_op.cc | 16 +++-- .../operators/hierarchical_sigmoid_op.cc | 6 +- paddle/fluid/operators/kldiv_loss_op.cc | 11 +++- .../fluid/operators/match_matrix_tensor_op.cc | 27 +++++++- .../operators/metrics/precision_recall_op.cc | 6 +- paddle/fluid/operators/nce_op.cc | 8 ++- paddle/fluid/operators/reshape_op.cc | 11 ++-- .../operators/tensor_array_to_tensor_op.cc | 15 +++-- paddle/fluid/operators/unsqueeze_op.cc | 35 +++++++++-- paddle/fluid/operators/warpctc_op.cc | 12 ++-- 17 files changed, 237 insertions(+), 94 deletions(-) diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc index e1df5b4ee90..a220c79a088 100644 --- a/paddle/fluid/framework/unused_var_check.cc +++ b/paddle/fluid/framework/unused_var_check.cc @@ -26,41 +26,35 @@ DEFINE_bool(enable_unused_var_check, false, "Checking whether operator contains unused inputs, " "especially for grad operator. It should be in unittest."); +// NOTE(zhiqiu): Currently, there are some operators which involves unused +// inputs and cannot be removed from the white_list below. +// They can be mainly divided into four categories: +// 0: the inputs of which are only used in if branch, or used in cuda kernel but +// not in cpu kernel; +// 1: the inputs of which are used to indicate dtype of outputs; +// 2: the inputs of which are used in fused operators. +// 3: specical operators, like ngraph_engine. +// The category number is presented in the comments after each operator. + const std::unordered_set op_has_unsed_vars_white_list = { - "auc", - "batch_norm", - "batch_norm_grad", - "fused_batch_norm_act", - "fused_batch_norm_act_grad", - "sync_batch_norm_grad", - "center_loss_grad", - "crop", - "cvm", - "cos_sim_grad", - "dgc_momentum", - "fake_quantize_range_abs_max", - "fill_zeros_like", - "fusion_seqpool_cvm_concat", - "reshape2_grad_grad", - "reshape2_grad", - "gru_grad", - "hierarchical_sigmoid_grad", - "nce_grad", - "roi_perspective_transform_grad", - "sequence_conv_grad", - "gru_unit_grad", - "affine_grid_grad", - "fill_any_like", - "precision_recall", - "unsqueeze_grad", - "kldiv_loss_grad", - "cvm_grad", - "stack_grad", - "warpctc_grad", - "sync_batch_norm", - "match_matrix_tensor_grad", - "ngraph_engine", - "rmsprop"}; + "batch_norm", // 0 + "batch_norm_grad", // 0 + "sync_batch_norm", // 0 + "sync_batch_norm_grad", // 0 + "dgc_momentum", // 0 + "fake_quantize_range_abs_max", // 0 + "rmsprop", // 0 + "sequence_conv_grad", // 0 + "roi_perspective_transform_grad", // 0 + "fill_zeros_like", // 1 + "fill_any_like", // 1 + "nce_grad", // 1 + "precision_recall", // 1 + "fusion_seqpool_cvm_concat", // 2 + "fused_batch_norm_act", // 2 + "fused_batch_norm_act_grad", // 2 + "ngraph_engine", // 3 +}; namespace paddle { namespace framework { diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 959af025c33..5915a24afc3 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -191,9 +191,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Output")), + ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; @@ -206,7 +207,6 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker { std::unique_ptr Apply() const override { auto* op = new T(); op->SetType("affine_grid_grad"); - op->SetInput("Theta", this->Input("Theta")); op->SetInput("OutputShape", this->Input("OutputShape")); op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc index f0c0a5e619f..5b617230869 100644 --- a/paddle/fluid/operators/center_loss_op.cc +++ b/paddle/fluid/operators/center_loss_op.cc @@ -141,6 +141,9 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker { return retv; } }; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CenterLossGradNoNeedBufVarsInferer, "X"); + } // namespace operators } // namespace paddle @@ -151,7 +154,8 @@ REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker, ops::CenterLossOpGradMaker, ops::CenterLossOpGradMaker); -REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp); +REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp, + ops::CenterLossGradNoNeedBufVarsInferer); REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel, ops::CenterLossKernel); diff --git a/paddle/fluid/operators/cos_sim_op.cc b/paddle/fluid/operators/cos_sim_op.cc index aaabc59e25c..faa0164a088 100644 --- a/paddle/fluid/operators/cos_sim_op.cc +++ b/paddle/fluid/operators/cos_sim_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cos_sim_op.h" +#include namespace paddle { namespace operators { @@ -165,14 +166,35 @@ class CosSimOpGrad : public framework::OperatorWithKernel { } }; +template +class CosSimGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + auto* grad_op = new T(); + grad_op->SetType("cos_sim_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Y", this->Input("Y")); + grad_op->SetInput("XNorm", this->Output("XNorm")); + grad_op->SetInput("YNorm", this->Output("YNorm")); + grad_op->SetInput("Out", this->Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - cos_sim, ops::CosSimOp, ops::CosSimOpMaker, - paddle::framework::DefaultGradOpMaker, - paddle::framework::DefaultGradOpMaker); +REGISTER_OPERATOR(cos_sim, ops::CosSimOp, ops::CosSimOpMaker, + ops::CosSimGradOpMaker, + ops::CosSimGradOpMaker); REGISTER_OPERATOR(cos_sim_grad, ops::CosSimOpGrad); REGISTER_OP_CPU_KERNEL( cos_sim, ops::CosSimKernel); diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 6fa9f87346a..fc73f938a9b 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -201,13 +201,16 @@ class CropGradOpMaker : public framework::SingleGradOpMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GropNoNeedBufferVarInference, "Y"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker, ops::CropGradOpMaker, - ops::CropGradOpMaker); + ops::CropGradOpMaker, + ops::GropNoNeedBufferVarInference); REGISTER_OPERATOR(crop_grad, ops::CropOpGrad); REGISTER_OP_CPU_KERNEL( crop, ops::CropKernel, diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 808671d96fc..62a81d7978f 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -94,9 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); } }; @@ -134,8 +134,8 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker { std::unique_ptr Apply() const override { std::unique_ptr op(new T()); op->SetType("cvm_grad"); - op->SetInput("X", this->Input("X")); op->SetInput("CVM", this->Input("CVM")); + op->SetInput("X", this->Input("X")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetAttrMap(this->Attrs()); @@ -143,15 +143,20 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CVMNoNeedBufferVarInference, "CVM"); +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CVMGradNoNeedBufferVarInference, "X"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, ops::CVMGradOpMaker, - ops::CVMGradOpMaker); + ops::CVMGradOpMaker, + ops::CVMNoNeedBufferVarInference); -REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp); +REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp, + ops::CVMGradNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel, ops::CVMOpKernel); diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 4be1c3e5861..da413dba646 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/gru_op.h" +#include #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" @@ -221,6 +222,13 @@ class GRUGradOp : public framework::OperatorWithKernel { if (ctx->HasOutput(weight_grad_name)) ctx->SetOutputDim(weight_grad_name, weight_dims); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Hidden")), + ctx.device_context()); + } }; template @@ -376,15 +384,53 @@ class GRUCPUKernel : public framework::OpKernel { } }; +template +class GRUGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + auto* grad_op = new T(); + grad_op->SetType("gru_grad"); + grad_op->SetInput("Input", this->Input("Input")); + grad_op->SetInput("H0", this->Input("H0")); + grad_op->SetInput("Bias", this->Input("Bias")); + grad_op->SetInput("Weight", this->Input("Weight")); + + grad_op->SetInput("BatchGate", this->Output("BatchGate")); + grad_op->SetInput("BatchResetHiddenPrev", + this->Output("BatchResetHiddenPrev")); + grad_op->SetInput("BatchHidden", this->Output("BatchHidden")); + grad_op->SetInput("Hidden", this->Output("Hidden")); + + grad_op->SetInput(framework::GradVarName("Hidden"), + this->OutputGrad("Hidden")); + + grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0")); + grad_op->SetOutput(framework::GradVarName("Input"), + this->InputGrad("Input")); + grad_op->SetOutput(framework::GradVarName("Weight"), + this->InputGrad("Weight")); + grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); + + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUGradOpNoNeedBufferVarInference, + "Input", "Bias"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - gru, ops::GRUOp, ops::GRUOpMaker, - paddle::framework::DefaultGradOpMaker, - paddle::framework::DefaultGradOpMaker) -REGISTER_OPERATOR(gru_grad, ops::GRUGradOp); +REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker, + ops::GRUGradOpMaker, + ops::GRUGradOpMaker); +REGISTER_OPERATOR(gru_grad, ops::GRUGradOp, + ops::GRUGradOpNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel, ops::GRUCPUKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/gru_unit_op.cc b/paddle/fluid/operators/gru_unit_op.cc index ddb6588ab99..c5f7f7b3ff4 100644 --- a/paddle/fluid/operators/gru_unit_op.cc +++ b/paddle/fluid/operators/gru_unit_op.cc @@ -155,8 +155,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"), "Input(%s) of GRUUnitGradOp should not be null.", "ResetHiddenPrev"); - PADDLE_ENFORCE(ctx->HasInput("Hidden"), - "Input(%s) of GRUUnitGradOp should not be null.", "Hidden"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), "Input(%s@GRAD) of GRUUnitGradOp should not be null.", "Hidden"); @@ -199,6 +197,13 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { if (ctx->HasOutput(weight_grad_name)) ctx->SetOutputDim(weight_grad_name, weight_dims); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Hidden")), + ctx.device_context()); + } }; template @@ -216,7 +221,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Weight", this->Input("Weight")); op->SetInput("Bias", this->Input("Bias")); - op->SetInput("Hidden", this->Output("Hidden")); op->SetInput("Gate", this->Output("Gate")); op->SetInput("ResetHiddenPrev", this->Output("ResetHiddenPrev")); op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden")); @@ -232,6 +236,9 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUUnitGradOpNoNeedBufferVarInference, + "Bias"); + } // namespace operators } // namespace paddle @@ -240,7 +247,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, ops::GRUUnitGradOpMaker, ops::GRUUnitGradOpMaker); -REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp); +REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp, + ops::GRUUnitGradOpNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL( gru_unit, ops::GRUUnitKernel, diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 58e380183f1..8028b20e06d 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -290,6 +290,9 @@ class HierarchicalSigmoidGradOpGradVarTypeInference } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE( + HierarchicalSigmoidGradOpNoNeedBufferVarInference, "Bias"); + } // namespace operators } // namespace paddle @@ -300,7 +303,8 @@ REGISTER_OPERATOR( ops::HierarchicalSigmoidGradMaker, ops::HierarchicalSigmoidGradMaker); REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, - ops::HierarchicalSigmoidGradOpGradVarTypeInference); + ops::HierarchicalSigmoidGradOpGradVarTypeInference, + ops::HierarchicalSigmoidGradOpNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL( hierarchical_sigmoid, ops::HierarchicalSigmoidOpKernel, diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index 8492ac915b8..99f19b408b5 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -136,8 +136,9 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.GetPlace()); } }; @@ -161,6 +162,9 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(KLDivLossGradNoNeedBufferVarInference, + "X"); + } // namespace operators } // namespace paddle @@ -168,7 +172,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker, ops::KLDivLossOpGradMaker, ops::KLDivLossOpGradMaker); -REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad); +REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad, + ops::KLDivLossGradNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL( kldiv_loss, ops::KLDivLossKernel, ops::KLDivLossKernel); diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc index 4ecc0be7c44..e7382f17282 100644 --- a/paddle/fluid/operators/match_matrix_tensor_op.cc +++ b/paddle/fluid/operators/match_matrix_tensor_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/operators/match_matrix_tensor_op.h" @@ -313,6 +314,28 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel { } }; +template +class MatchMatrixTensorGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + auto* grad_op = new T(); + grad_op->SetType("match_matrix_tensor_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Y", this->Input("Y")); + grad_op->SetInput("W", this->Input("W")); + grad_op->SetInput("Tmp", this->Output("Tmp")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + grad_op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } +}; + } // namespace operators } // namespace paddle @@ -320,8 +343,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR( match_matrix_tensor, ops::MatchMatrixTensorOP, ops::MatchMatrixTensorOpMaker, - paddle::framework::DefaultGradOpMaker, - paddle::framework::DefaultGradOpMaker) + ops::MatchMatrixTensorGradOpMaker, + ops::MatchMatrixTensorGradOpMaker); REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad); REGISTER_OP_CPU_KERNEL(match_matrix_tensor, diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc index 58b948b5a43..054f8c70cc2 100644 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -183,8 +183,10 @@ Output(AccumStatesInfo) is metrics of accumulation data. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp, - ops::PrecisionRecallOpMaker); +REGISTER_OPERATOR( + precision_recall, ops::PrecisionRecallOp, ops::PrecisionRecallOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( precision_recall, ops::PrecisionRecallKernel, diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index b6f68e3bee7..4fafe439edc 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -224,7 +224,6 @@ class NCEGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("Label", this->Input("Label")); op->SetInput("Bias", this->Input("Bias")); op->SetInput("Weight", this->Input("Weight")); - op->SetInput("Cost", this->Output("Cost")); op->SetInput("SampleLogits", this->Output("SampleLogits")); op->SetInput("SampleLabels", this->Output("SampleLabels")); op->SetInput("SampleWeight", this->Input("SampleWeight")); @@ -247,7 +246,6 @@ class NCEOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Weight")); - PADDLE_ENFORCE(ctx->HasInput("Cost")); PADDLE_ENFORCE(ctx->HasInput("SampleLogits")); PADDLE_ENFORCE(ctx->HasInput("SampleLabels")); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")), @@ -301,6 +299,9 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(NCEGradOpNoNeedBufferVarInference, + "Bias"); + } // namespace operators } // namespace paddle @@ -308,7 +309,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker, ops::NCEGradOpMaker, ops::NCEGradOpMaker); -REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference); +REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference, + ops::NCEGradOpNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel, ops::NCEKernel); REGISTER_OP_CPU_KERNEL(nce_grad, diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 5b453e98fb2..df2f1ec7493 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -438,9 +438,6 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker { auto *grad_op = new T(); grad_op->SetType("reshape2_grad"); grad_op->SetInput("XShape", this->Output("XShape")); - if (this->HasInput("ShapeTensor")) { - grad_op->SetInput("ShapeTensor", this->Input("ShapeTensor")); - } grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); @@ -456,11 +453,8 @@ class Reshape2DoubleGradMaker : public framework::SingleGradOpMaker { std::unique_ptr Apply() const override { auto *grad_op = new T(); grad_op->SetType("reshape2_grad_grad"); - - grad_op->SetInput("ShapeTensor", this->Input("ShapeTensor")); grad_op->SetInput("DOut", this->Input(framework::GradVarName("Out"))); grad_op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); - grad_op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); grad_op->SetAttrMap(this->Attrs()); return std::unique_ptr(grad_op); @@ -546,6 +540,8 @@ DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut, {framework::GradVarName("Out"), framework::GradVarName("X")}); DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"}); +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE( + ReshapeDoubleGradOpNoNeedBufferVarInference, "DOut"); } // namespace operators } // namespace paddle @@ -576,7 +572,8 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, ops::Reshape2DoubleGradMaker, ops::ReshapeGradInplaceInToOut); REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, - ops::ReshapeDoubleGradInplaceInToOut); + ops::ReshapeDoubleGradInplaceInToOut, + ops::ReshapeDoubleGradOpNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t, ops::ReshapeKernel, diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc index 0c3cffa1a3e..bae493996cc 100644 --- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc +++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc @@ -250,14 +250,13 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase { auto use_stack = Attr("use_stack"); - auto grad_op = - use_stack - ? framework::OpRegistry::CreateOp( - "stack_grad", {{"X", names}, {"Y@GRAD", {dout_name}}}, - {{"X@GRAD", grad_names}}, attrs) - : framework::OpRegistry::CreateOp( - "concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}}, - {{"X@GRAD", grad_names}}, attrs); + auto grad_op = use_stack ? framework::OpRegistry::CreateOp( + "stack_grad", {{"Y@GRAD", {dout_name}}}, + {{"X@GRAD", grad_names}}, attrs) + : framework::OpRegistry::CreateOp( + "concat_grad", + {{"X", names}, {"Out@GRAD", {dout_name}}}, + {{"X@GRAD", grad_names}}, attrs); grad_op->Run(scope, place); diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index b0a458cbb87..0a0f7af6d9e 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -182,6 +182,29 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->ShareLoD("X", framework::GradVarName("X")); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class UnsqueezeGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + std::unique_ptr Apply() const override { + auto *grad_op = new T(); + grad_op->SetType("unsqueeze_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } }; // FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on @@ -263,15 +286,17 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer, {framework::GradVarName("Out"), framework::GradVarName("X")}); +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(UnsqueezeGradOpNoNeedBufferVarInference, + "X"); } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, - paddle::framework::DefaultGradOpMaker, - paddle::framework::DefaultGradOpMaker); -REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); +REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + ops::UnsqueezeGradOpMaker, + ops::UnsqueezeGradOpMaker); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, + ops::UnsqueezeGradOpNoNeedBufferVarInference); REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, ops::Unsqueeze2GradOpMaker, diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index c6b7a6f2d35..04217e0ff20 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -175,12 +175,15 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.device_context()); } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(WarpCTCGradOpNoNeedBufferVarInference, + "Logits"); + } // namespace operators } // namespace paddle @@ -188,7 +191,8 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, ops::WarpCTCGradOpMaker, ops::WarpCTCGradOpMaker); -REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp); +REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp, + ops::WarpCTCGradOpNoNeedBufferVarInference); REGISTER_OP_CPU_KERNEL( warpctc, ops::WarpCTCKernel); REGISTER_OP_CPU_KERNEL( -- GitLab