未验证 提交 3e5744aa 编写于 作者: L Leo Chen 提交者: GitHub

Remove unused inputs for some operators (#22284)

* remove unused inputs, test=develop

* remove unused inputs, test=develop

* update dtype, test=develop

* remove unused inputs, test=develop

* update op_use_default_grad_op_maker, tese=develop

* resolve conflicts, test=develop

* follow comments, test=develop

* update center_loss_grad, test=develop
上级 51e147a1
...@@ -26,41 +26,35 @@ DEFINE_bool(enable_unused_var_check, false, ...@@ -26,41 +26,35 @@ DEFINE_bool(enable_unused_var_check, false,
"Checking whether operator contains unused inputs, " "Checking whether operator contains unused inputs, "
"especially for grad operator. It should be in unittest."); "especially for grad operator. It should be in unittest.");
// NOTE(zhiqiu): Currently, there are some operators which involves unused
// inputs and cannot be removed from the white_list below.
// They can be mainly divided into four categories:
// 0: the inputs of which are only used in if branch, or used in cuda kernel but
// not in cpu kernel;
// 1: the inputs of which are used to indicate dtype of outputs;
// 2: the inputs of which are used in fused operators.
// 3: specical operators, like ngraph_engine.
// The category number is presented in the comments after each operator.
const std::unordered_set<std::string> op_has_unsed_vars_white_list = { const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
"auc", "batch_norm", // 0
"batch_norm", "batch_norm_grad", // 0
"batch_norm_grad", "sync_batch_norm", // 0
"fused_batch_norm_act", "sync_batch_norm_grad", // 0
"fused_batch_norm_act_grad", "dgc_momentum", // 0
"sync_batch_norm_grad", "fake_quantize_range_abs_max", // 0
"center_loss_grad", "rmsprop", // 0
"crop", "sequence_conv_grad", // 0
"cvm", "roi_perspective_transform_grad", // 0
"cos_sim_grad", "fill_zeros_like", // 1
"dgc_momentum", "fill_any_like", // 1
"fake_quantize_range_abs_max", "nce_grad", // 1
"fill_zeros_like", "precision_recall", // 1
"fusion_seqpool_cvm_concat", "fusion_seqpool_cvm_concat", // 2
"reshape2_grad_grad", "fused_batch_norm_act", // 2
"reshape2_grad", "fused_batch_norm_act_grad", // 2
"gru_grad", "ngraph_engine", // 3
"hierarchical_sigmoid_grad", };
"nce_grad",
"roi_perspective_transform_grad",
"sequence_conv_grad",
"gru_unit_grad",
"affine_grid_grad",
"fill_any_like",
"precision_recall",
"unsqueeze_grad",
"kldiv_loss_grad",
"cvm_grad",
"stack_grad",
"warpctc_grad",
"sync_batch_norm",
"match_matrix_tensor_grad",
"ngraph_engine",
"rmsprop"};
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -191,9 +191,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -191,9 +191,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(), ctx, framework::GradVarName("Output")),
framework::DataLayout::kAnyLayout, library_); ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
} }
}; };
...@@ -206,7 +207,6 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -206,7 +207,6 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override { std::unique_ptr<T> Apply() const override {
auto* op = new T(); auto* op = new T();
op->SetType("affine_grid_grad"); op->SetType("affine_grid_grad");
op->SetInput("Theta", this->Input("Theta"));
op->SetInput("OutputShape", this->Input("OutputShape")); op->SetInput("OutputShape", this->Input("OutputShape"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
......
...@@ -141,6 +141,9 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -141,6 +141,9 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
return retv; return retv;
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CenterLossGradNoNeedBufVarsInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -151,7 +154,8 @@ REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker, ...@@ -151,7 +154,8 @@ REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
ops::CenterLossOpGradMaker<paddle::framework::OpDesc>, ops::CenterLossOpGradMaker<paddle::framework::OpDesc>,
ops::CenterLossOpGradMaker<paddle::imperative::OpBase>); ops::CenterLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp); REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp,
ops::CenterLossGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>, REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>,
ops::CenterLossKernel<CPUCtx, double>); ops::CenterLossKernel<CPUCtx, double>);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cos_sim_op.h" #include "paddle/fluid/operators/cos_sim_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -165,14 +166,35 @@ class CosSimOpGrad : public framework::OperatorWithKernel { ...@@ -165,14 +166,35 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
} }
}; };
template <typename T>
class CosSimGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
grad_op->SetType("cos_sim_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Y", this->Input("Y"));
grad_op->SetInput("XNorm", this->Output("XNorm"));
grad_op->SetInput("YNorm", this->Output("YNorm"));
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(cos_sim, ops::CosSimOp, ops::CosSimOpMaker,
cos_sim, ops::CosSimOp, ops::CosSimOpMaker, ops::CosSimGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>, ops::CosSimGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(cos_sim_grad, ops::CosSimOpGrad); REGISTER_OPERATOR(cos_sim_grad, ops::CosSimOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
cos_sim, ops::CosSimKernel<paddle::platform::CPUDeviceContext, float>); cos_sim, ops::CosSimKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -201,13 +201,16 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -201,13 +201,16 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GropNoNeedBufferVarInference, "Y");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker, REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
ops::CropGradOpMaker<paddle::framework::OpDesc>, ops::CropGradOpMaker<paddle::framework::OpDesc>,
ops::CropGradOpMaker<paddle::imperative::OpBase>); ops::CropGradOpMaker<paddle::imperative::OpBase>,
ops::GropNoNeedBufferVarInference);
REGISTER_OPERATOR(crop_grad, ops::CropOpGrad); REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>, crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -94,9 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel { ...@@ -94,9 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx, framework::GradVarName("Y")),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -134,8 +134,8 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -134,8 +134,8 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override { std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T()); std::unique_ptr<T> op(new T());
op->SetType("cvm_grad"); op->SetType("cvm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("CVM", this->Input("CVM")); op->SetInput("CVM", this->Input("CVM"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
...@@ -143,15 +143,20 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -143,15 +143,20 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CVMNoNeedBufferVarInference, "CVM");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CVMGradNoNeedBufferVarInference, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker,
ops::CVMGradOpMaker<paddle::framework::OpDesc>, ops::CVMGradOpMaker<paddle::framework::OpDesc>,
ops::CVMGradOpMaker<paddle::imperative::OpBase>); ops::CVMGradOpMaker<paddle::imperative::OpBase>,
ops::CVMNoNeedBufferVarInference);
REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp); REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp,
ops::CVMGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>); REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gru_op.h" #include "paddle/fluid/operators/gru_op.h"
#include <memory>
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
...@@ -221,6 +222,13 @@ class GRUGradOp : public framework::OperatorWithKernel { ...@@ -221,6 +222,13 @@ class GRUGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(weight_grad_name)) if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims); ctx->SetOutputDim(weight_grad_name, weight_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Hidden")),
ctx.device_context());
}
}; };
template <typename T> template <typename T>
...@@ -376,15 +384,53 @@ class GRUCPUKernel : public framework::OpKernel<T> { ...@@ -376,15 +384,53 @@ class GRUCPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T>
class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
grad_op->SetType("gru_grad");
grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput("H0", this->Input("H0"));
grad_op->SetInput("Bias", this->Input("Bias"));
grad_op->SetInput("Weight", this->Input("Weight"));
grad_op->SetInput("BatchGate", this->Output("BatchGate"));
grad_op->SetInput("BatchResetHiddenPrev",
this->Output("BatchResetHiddenPrev"));
grad_op->SetInput("BatchHidden", this->Output("BatchHidden"));
grad_op->SetInput("Hidden", this->Output("Hidden"));
grad_op->SetInput(framework::GradVarName("Hidden"),
this->OutputGrad("Hidden"));
grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0"));
grad_op->SetOutput(framework::GradVarName("Input"),
this->InputGrad("Input"));
grad_op->SetOutput(framework::GradVarName("Weight"),
this->InputGrad("Weight"));
grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUGradOpNoNeedBufferVarInference,
"Input", "Bias");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
gru, ops::GRUOp, ops::GRUOpMaker, ops::GRUGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>, ops::GRUGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>) REGISTER_OPERATOR(gru_grad, ops::GRUGradOp,
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp); ops::GRUGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>, REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
ops::GRUCPUKernel<double>); ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -155,8 +155,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -155,8 +155,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"), PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.", "Input(%s) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev"); "ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.", "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden"); "Hidden");
...@@ -199,6 +197,13 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -199,6 +197,13 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(weight_grad_name)) if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims); ctx->SetOutputDim(weight_grad_name, weight_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Hidden")),
ctx.device_context());
}
}; };
template <typename T> template <typename T>
...@@ -216,7 +221,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -216,7 +221,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Weight", this->Input("Weight")); op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Bias", this->Input("Bias")); op->SetInput("Bias", this->Input("Bias"));
op->SetInput("Hidden", this->Output("Hidden"));
op->SetInput("Gate", this->Output("Gate")); op->SetInput("Gate", this->Output("Gate"));
op->SetInput("ResetHiddenPrev", this->Output("ResetHiddenPrev")); op->SetInput("ResetHiddenPrev", this->Output("ResetHiddenPrev"));
op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden")); op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden"));
...@@ -232,6 +236,9 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -232,6 +236,9 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUUnitGradOpNoNeedBufferVarInference,
"Bias");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -240,7 +247,8 @@ namespace ops = paddle::operators; ...@@ -240,7 +247,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>, ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,
ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>); ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp); REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp,
ops::GRUUnitGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>, gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -290,6 +290,9 @@ class HierarchicalSigmoidGradOpGradVarTypeInference ...@@ -290,6 +290,9 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
HierarchicalSigmoidGradOpNoNeedBufferVarInference, "Bias");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -300,7 +303,8 @@ REGISTER_OPERATOR( ...@@ -300,7 +303,8 @@ REGISTER_OPERATOR(
ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>, ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>,
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>); ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp, REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
ops::HierarchicalSigmoidGradOpGradVarTypeInference); ops::HierarchicalSigmoidGradOpGradVarTypeInference,
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid, hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>, ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -136,8 +136,9 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { ...@@ -136,8 +136,9 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); ctx, framework::GradVarName("Loss")),
ctx.GetPlace());
} }
}; };
...@@ -161,6 +162,9 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -161,6 +162,9 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(KLDivLossGradNoNeedBufferVarInference,
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -168,7 +172,8 @@ namespace ops = paddle::operators; ...@@ -168,7 +172,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker, REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>, ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>); ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad); REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
ops::KLDivLossGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>, kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>); ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <iomanip> #include <iomanip>
#include <iostream> #include <iostream>
#include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/operators/match_matrix_tensor_op.h" #include "paddle/fluid/operators/match_matrix_tensor_op.h"
...@@ -313,6 +314,28 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> { ...@@ -313,6 +314,28 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T>
class MatchMatrixTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
grad_op->SetType("match_matrix_tensor_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Y", this->Input("Y"));
grad_op->SetInput("W", this->Input("W"));
grad_op->SetInput("Tmp", this->Output("Tmp"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad_op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -320,8 +343,8 @@ namespace ops = paddle::operators; ...@@ -320,8 +343,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
match_matrix_tensor, ops::MatchMatrixTensorOP, match_matrix_tensor, ops::MatchMatrixTensorOP,
ops::MatchMatrixTensorOpMaker, ops::MatchMatrixTensorOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>, ops::MatchMatrixTensorGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>) ops::MatchMatrixTensorGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad); REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad);
REGISTER_OP_CPU_KERNEL(match_matrix_tensor, REGISTER_OP_CPU_KERNEL(match_matrix_tensor,
......
...@@ -183,8 +183,10 @@ Output(AccumStatesInfo) is metrics of accumulation data. ...@@ -183,8 +183,10 @@ Output(AccumStatesInfo) is metrics of accumulation data.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp, REGISTER_OPERATOR(
ops::PrecisionRecallOpMaker); precision_recall, ops::PrecisionRecallOp, ops::PrecisionRecallOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
precision_recall, precision_recall,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>, ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
......
...@@ -224,7 +224,6 @@ class NCEGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -224,7 +224,6 @@ class NCEGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Label", this->Input("Label")); op->SetInput("Label", this->Input("Label"));
op->SetInput("Bias", this->Input("Bias")); op->SetInput("Bias", this->Input("Bias"));
op->SetInput("Weight", this->Input("Weight")); op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Cost", this->Output("Cost"));
op->SetInput("SampleLogits", this->Output("SampleLogits")); op->SetInput("SampleLogits", this->Output("SampleLogits"));
op->SetInput("SampleLabels", this->Output("SampleLabels")); op->SetInput("SampleLabels", this->Output("SampleLabels"));
op->SetInput("SampleWeight", this->Input("SampleWeight")); op->SetInput("SampleWeight", this->Input("SampleWeight"));
...@@ -247,7 +246,6 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -247,7 +246,6 @@ class NCEOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Weight")); PADDLE_ENFORCE(ctx->HasInput("Weight"));
PADDLE_ENFORCE(ctx->HasInput("Cost"));
PADDLE_ENFORCE(ctx->HasInput("SampleLogits")); PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
PADDLE_ENFORCE(ctx->HasInput("SampleLabels")); PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
...@@ -301,6 +299,9 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference { ...@@ -301,6 +299,9 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(NCEGradOpNoNeedBufferVarInference,
"Bias");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -308,7 +309,8 @@ namespace ops = paddle::operators; ...@@ -308,7 +309,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker, REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
ops::NCEGradOpMaker<paddle::framework::OpDesc>, ops::NCEGradOpMaker<paddle::framework::OpDesc>,
ops::NCEGradOpMaker<paddle::imperative::OpBase>); ops::NCEGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference); REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference,
ops::NCEGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>, REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>); ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad, REGISTER_OP_CPU_KERNEL(nce_grad,
......
...@@ -438,9 +438,6 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> { ...@@ -438,9 +438,6 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> {
auto *grad_op = new T(); auto *grad_op = new T();
grad_op->SetType("reshape2_grad"); grad_op->SetType("reshape2_grad");
grad_op->SetInput("XShape", this->Output("XShape")); grad_op->SetInput("XShape", this->Output("XShape"));
if (this->HasInput("ShapeTensor")) {
grad_op->SetInput("ShapeTensor", this->Input("ShapeTensor"));
}
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
...@@ -456,11 +453,8 @@ class Reshape2DoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -456,11 +453,8 @@ class Reshape2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override { std::unique_ptr<T> Apply() const override {
auto *grad_op = new T(); auto *grad_op = new T();
grad_op->SetType("reshape2_grad_grad"); grad_op->SetType("reshape2_grad_grad");
grad_op->SetInput("ShapeTensor", this->Input("ShapeTensor"));
grad_op->SetInput("DOut", this->Input(framework::GradVarName("Out"))); grad_op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
grad_op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); grad_op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); grad_op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op); return std::unique_ptr<T>(grad_op);
...@@ -546,6 +540,8 @@ DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut, ...@@ -546,6 +540,8 @@ DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"}); DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
ReshapeDoubleGradOpNoNeedBufferVarInference, "DOut");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -576,7 +572,8 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, ...@@ -576,7 +572,8 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>, ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
ops::ReshapeGradInplaceInToOut); ops::ReshapeGradInplaceInToOut);
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInToOut); ops::ReshapeDoubleGradInplaceInToOut,
ops::ReshapeDoubleGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int8_t, ops::ReshapeKernel, ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
......
...@@ -250,14 +250,13 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase { ...@@ -250,14 +250,13 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
auto use_stack = Attr<bool>("use_stack"); auto use_stack = Attr<bool>("use_stack");
auto grad_op = auto grad_op = use_stack ? framework::OpRegistry::CreateOp(
use_stack "stack_grad", {{"Y@GRAD", {dout_name}}},
? framework::OpRegistry::CreateOp( {{"X@GRAD", grad_names}}, attrs)
"stack_grad", {{"X", names}, {"Y@GRAD", {dout_name}}}, : framework::OpRegistry::CreateOp(
{{"X@GRAD", grad_names}}, attrs) "concat_grad",
: framework::OpRegistry::CreateOp( {{"X", names}, {"Out@GRAD", {dout_name}}},
"concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}}, {{"X@GRAD", grad_names}}, attrs);
{{"X@GRAD", grad_names}}, attrs);
grad_op->Run(scope, place); grad_op->Run(scope, place);
......
...@@ -182,6 +182,29 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel { ...@@ -182,6 +182,29 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X")); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class UnsqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
grad_op->SetType("unsqueeze_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
}; };
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on // FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
...@@ -263,15 +286,17 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"}); ...@@ -263,15 +286,17 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer, DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(UnsqueezeGradOpNoNeedBufferVarInference,
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>, ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>); REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); ops::UnsqueezeGradOpNoNeedBufferVarInference);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>, ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
......
...@@ -175,12 +175,15 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -175,12 +175,15 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), ctx, framework::GradVarName("Loss")),
ctx.device_context()); ctx.device_context());
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(WarpCTCGradOpNoNeedBufferVarInference,
"Logits");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -188,7 +191,8 @@ namespace ops = paddle::operators; ...@@ -188,7 +191,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>, ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>,
ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>); ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp); REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp,
ops::WarpCTCGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>); warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册