Unverified commit 3e5744aa, authored by Leo Chen, committed by GitHub

Remove unused inputs for some operators (#22284)

* remove unused inputs, test=develop

* remove unused inputs, test=develop

* update dtype, test=develop

* remove unused inputs, test=develop

* update op_use_default_grad_op_maker, test=develop

* resolve conflicts, test=develop

* follow comments, test=develop

* update center_loss_grad, test=develop
Parent 51e147a1
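The change applies one recurring pattern across the operators below: drop forward-pass variables that the grad kernel never reads from the grad op's input list, make the grad op infer its kernel dtype from a gradient variable that is still present (typically Out@GRAD) instead of a removed forward input, and mark inputs whose shape is still needed but whose data is not with DECLARE_NO_NEED_BUFFER_VARS_INFERENCE. A minimal sketch of the pattern on a hypothetical foo/foo_grad pair (FooGradOpMaker, FooGradOp, and "foo_grad" are illustrative names, not part of this commit):

template <typename T>
class FooGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    auto* op = new T();
    op->SetType("foo_grad");
    // Forward only what the grad kernel reads; X is kept for its dims alone.
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
    return std::unique_ptr<T>(op);
  }
};

class FooGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Take the kernel dtype from Out@GRAD, so no forward input has to be
    // kept around merely to indicate the dtype (category 1 in the NOTE below).
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};

// X's data buffer is never dereferenced by foo_grad, only its dims.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(FooGradNoNeedBufferVarInference, "X");

REGISTER_OPERATOR(foo_grad, ops::FooGradOp,
                  ops::FooGradNoNeedBufferVarInference);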
@@ -26,41 +26,35 @@ DEFINE_bool(enable_unused_var_check, false,
"Checking whether operator contains unused inputs, "
"especially for grad operator. It should be in unittest.");
// NOTE(zhiqiu): Currently, there are some operators which involve unused
// inputs and cannot be removed from the white_list below.
// They can be mainly divided into four categories:
// 0: the inputs of which are only used in if branch, or used in cuda kernel but
// not in cpu kernel;
// 1: the inputs of which are used to indicate dtype of outputs;
// 2: the inputs of which are used in fused operators;
// 3: special operators, like ngraph_engine.
// The category number is presented in the comments after each operator.
const std::unordered_set<std::string> op_has_unsed_vars_white_list = {
"auc",
"batch_norm",
"batch_norm_grad",
"fused_batch_norm_act",
"fused_batch_norm_act_grad",
"sync_batch_norm_grad",
"center_loss_grad",
"crop",
"cvm",
"cos_sim_grad",
"dgc_momentum",
"fake_quantize_range_abs_max",
"fill_zeros_like",
"fusion_seqpool_cvm_concat",
"reshape2_grad_grad",
"reshape2_grad",
"gru_grad",
"hierarchical_sigmoid_grad",
"nce_grad",
"roi_perspective_transform_grad",
"sequence_conv_grad",
"gru_unit_grad",
"affine_grid_grad",
"fill_any_like",
"precision_recall",
"unsqueeze_grad",
"kldiv_loss_grad",
"cvm_grad",
"stack_grad",
"warpctc_grad",
"sync_batch_norm",
"match_matrix_tensor_grad",
"ngraph_engine",
"rmsprop"};
"batch_norm", // 0
"batch_norm_grad", // 0
"sync_batch_norm", // 0
"sync_batch_norm_grad", // 0
"dgc_momentum", // 0
"fake_quantize_range_abs_max", // 0
"rmsprop", // 0
"sequence_conv_grad", // 0
"roi_perspective_transform_grad", // 0
"fill_zeros_like", // 1
"fill_any_like", // 1
"nce_grad", // 1
"precision_recall", // 1
"fusion_seqpool_cvm_concat", // 2
"fused_batch_norm_act", // 2
"fused_batch_norm_act_grad", // 2
"ngraph_engine", // 3
};
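For reference, this list feeds the check that FLAGS_enable_unused_var_check turns on; the flag's description above says it is meant for unittests. The intended control flow is roughly the following (the helper below is an illustration of that intent, not code from this file; only the flag and op_has_unsed_vars_white_list are real):

// Illustration only: report unused inputs unless the op is white-listed.
void ReportUnusedInputsSketch(const std::string& op_type,
                              const std::vector<std::string>& unused_inputs) {
  if (!FLAGS_enable_unused_var_check) return;  // off by default
  if (op_has_unsed_vars_white_list.count(op_type) > 0) return;  // known exception
  if (!unused_inputs.empty()) {
    LOG(FATAL) << "Operator " << op_type
               << " has unused inputs, please check its kernel or grad maker.";
  }
}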
namespace paddle {
namespace framework {
@@ -191,9 +191,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output")),
ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
@@ -206,7 +207,6 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override {
auto* op = new T();
op->SetType("affine_grid_grad");
op->SetInput("Theta", this->Input("Theta"));
op->SetInput("OutputShape", this->Input("OutputShape"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
@@ -141,6 +141,9 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
return retv;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CenterLossGradNoNeedBufVarsInferer, "X");
} // namespace operators
} // namespace paddle
@@ -151,7 +154,8 @@ REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
ops::CenterLossOpGradMaker<paddle::framework::OpDesc>,
ops::CenterLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp);
REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp,
ops::CenterLossGradNoNeedBufVarsInferer);
REGISTER_OP_CPU_KERNEL(center_loss, ops::CenterLossKernel<CPUCtx, float>,
ops::CenterLossKernel<CPUCtx, double>);
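DECLARE_NO_NEED_BUFFER_VARS_INFERENCE, introduced here for center_loss_grad and repeated for several operators below, declares that the grad op reads only the metadata (dims/LoD) of the listed inputs and never their data buffers, so the framework may free or skip allocating those buffers. A sketch of the kind of kernel that makes "X" eligible, assuming a hypothetical grad kernel that needs X only for its shape:

// Hypothetical grad kernel: "X" is consumed for its dims only, so it can be
// declared a no-need-buffer variable.
template <typename DeviceContext, typename T>
class FooGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::LoDTensor>("X");  // shape is read below
    auto* dout =
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    dx->Resize(x->dims());               // uses x->dims(), never x->data<T>()
    dx->mutable_data<T>(ctx.GetPlace());
    // ... compute dx from dout ...
  }
};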
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cos_sim_op.h"
#include <memory>
namespace paddle {
namespace operators {
@@ -165,14 +166,35 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
}
};
template <typename T>
class CosSimGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
grad_op->SetType("cos_sim_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Y", this->Input("Y"));
grad_op->SetInput("XNorm", this->Output("XNorm"));
grad_op->SetInput("YNorm", this->Output("YNorm"));
grad_op->SetInput("Out", this->Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
cos_sim, ops::CosSimOp, ops::CosSimOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(cos_sim, ops::CosSimOp, ops::CosSimOpMaker,
ops::CosSimGradOpMaker<paddle::framework::OpDesc>,
ops::CosSimGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cos_sim_grad, ops::CosSimOpGrad);
REGISTER_OP_CPU_KERNEL(
cos_sim, ops::CosSimKernel<paddle::platform::CPUDeviceContext, float>);
@@ -201,13 +201,16 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GropNoNeedBufferVarInference, "Y");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
ops::CropGradOpMaker<paddle::framework::OpDesc>,
ops::CropGradOpMaker<paddle::imperative::OpBase>);
ops::CropGradOpMaker<paddle::imperative::OpBase>,
ops::GropNoNeedBufferVarInference);
REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(
crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>,
@@ -94,9 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel {
// is determined by its input "Y@GRAD".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Y")),
ctx.device_context());
}
};
@@ -134,8 +134,8 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("cvm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("CVM", this->Input("CVM"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
@@ -143,15 +143,20 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CVMNoNeedBufferVarInference, "CVM");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(CVMGradNoNeedBufferVarInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker,
ops::CVMGradOpMaker<paddle::framework::OpDesc>,
ops::CVMGradOpMaker<paddle::imperative::OpBase>);
ops::CVMGradOpMaker<paddle::imperative::OpBase>,
ops::CVMNoNeedBufferVarInference);
REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp);
REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp,
ops::CVMGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>);
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
@@ -221,6 +222,13 @@ class GRUGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Hidden")),
ctx.device_context());
}
};
template <typename T>
@@ -376,15 +384,53 @@ class GRUCPUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
grad_op->SetType("gru_grad");
grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput("H0", this->Input("H0"));
grad_op->SetInput("Bias", this->Input("Bias"));
grad_op->SetInput("Weight", this->Input("Weight"));
grad_op->SetInput("BatchGate", this->Output("BatchGate"));
grad_op->SetInput("BatchResetHiddenPrev",
this->Output("BatchResetHiddenPrev"));
grad_op->SetInput("BatchHidden", this->Output("BatchHidden"));
grad_op->SetInput("Hidden", this->Output("Hidden"));
grad_op->SetInput(framework::GradVarName("Hidden"),
this->OutputGrad("Hidden"));
grad_op->SetOutput(framework::GradVarName("H0"), this->InputGrad("H0"));
grad_op->SetOutput(framework::GradVarName("Input"),
this->InputGrad("Input"));
grad_op->SetOutput(framework::GradVarName("Weight"),
this->InputGrad("Weight"));
grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUGradOpNoNeedBufferVarInference,
"Input", "Bias");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
gru, ops::GRUOp, ops::GRUOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>)
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
ops::GRUGradOpMaker<paddle::framework::OpDesc>,
ops::GRUGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_grad, ops::GRUGradOp,
ops::GRUGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
ops::GRUCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
@@ -155,8 +155,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden");
@@ -199,6 +197,13 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Hidden")),
ctx.device_context());
}
};
template <typename T>
@@ -216,7 +221,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("Hidden", this->Output("Hidden"));
op->SetInput("Gate", this->Output("Gate"));
op->SetInput("ResetHiddenPrev", this->Output("ResetHiddenPrev"));
op->SetInput(framework::GradVarName("Hidden"), this->OutputGrad("Hidden"));
@@ -232,6 +236,9 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GRUUnitGradOpNoNeedBufferVarInference,
"Bias");
} // namespace operators
} // namespace paddle
@@ -240,7 +247,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
ops::GRUUnitGradOpMaker<paddle::framework::OpDesc>,
ops::GRUUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp,
ops::GRUUnitGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
@@ -290,6 +290,9 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
HierarchicalSigmoidGradOpNoNeedBufferVarInference, "Bias");
} // namespace operators
} // namespace paddle
@@ -300,7 +303,8 @@ REGISTER_OPERATOR(
ops::HierarchicalSigmoidGradMaker<paddle::framework::OpDesc>,
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
ops::HierarchicalSigmoidGradOpGradVarTypeInference);
ops::HierarchicalSigmoidGradOpGradVarTypeInference,
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
@@ -136,8 +136,9 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.GetPlace());
}
};
@@ -161,6 +162,9 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(KLDivLossGradNoNeedBufferVarInference,
"X");
} // namespace operators
} // namespace paddle
@@ -168,7 +172,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
ops::KLDivLossGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
@@ -15,6 +15,7 @@ limitations under the License. */
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>
#include "paddle/fluid/operators/match_matrix_tensor_op.h"
@@ -313,6 +314,28 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class MatchMatrixTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
grad_op->SetType("match_matrix_tensor_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Y", this->Input("Y"));
grad_op->SetInput("W", this->Input("W"));
grad_op->SetInput("Tmp", this->Output("Tmp"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad_op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
} // namespace operators
} // namespace paddle
@@ -320,8 +343,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(
match_matrix_tensor, ops::MatchMatrixTensorOP,
ops::MatchMatrixTensorOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>)
ops::MatchMatrixTensorGradOpMaker<paddle::framework::OpDesc>,
ops::MatchMatrixTensorGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad);
REGISTER_OP_CPU_KERNEL(match_matrix_tensor,
@@ -183,8 +183,10 @@ Output(AccumStatesInfo) is metrics of accumulation data.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(precision_recall, ops::PrecisionRecallOp,
ops::PrecisionRecallOpMaker);
REGISTER_OPERATOR(
precision_recall, ops::PrecisionRecallOp, ops::PrecisionRecallOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
precision_recall,
ops::PrecisionRecallKernel<paddle::platform::CPUPlace, float>,
@@ -224,7 +224,6 @@ class NCEGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Label", this->Input("Label"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("Weight", this->Input("Weight"));
op->SetInput("Cost", this->Output("Cost"));
op->SetInput("SampleLogits", this->Output("SampleLogits"));
op->SetInput("SampleLabels", this->Output("SampleLabels"));
op->SetInput("SampleWeight", this->Input("SampleWeight"));
@@ -247,7 +246,6 @@ class NCEOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Weight"));
PADDLE_ENFORCE(ctx->HasInput("Cost"));
PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
@@ -301,6 +299,9 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(NCEGradOpNoNeedBufferVarInference,
"Bias");
} // namespace operators
} // namespace paddle
@@ -308,7 +309,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
ops::NCEGradOpMaker<paddle::framework::OpDesc>,
ops::NCEGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference,
ops::NCEGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
@@ -438,9 +438,6 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> {
auto *grad_op = new T();
grad_op->SetType("reshape2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
if (this->HasInput("ShapeTensor")) {
grad_op->SetInput("ShapeTensor", this->Input("ShapeTensor"));
}
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
@@ -456,11 +453,8 @@ class Reshape2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
grad_op->SetType("reshape2_grad_grad");
grad_op->SetInput("ShapeTensor", this->Input("ShapeTensor"));
grad_op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
grad_op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
@@ -546,6 +540,8 @@ DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
ReshapeDoubleGradOpNoNeedBufferVarInference, "DOut");
} // namespace operators
} // namespace paddle
@@ -576,7 +572,8 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
ops::ReshapeGradInplaceInToOut);
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInToOut);
ops::ReshapeDoubleGradInplaceInToOut,
ops::ReshapeDoubleGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
@@ -250,14 +250,13 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
auto use_stack = Attr<bool>("use_stack");
auto grad_op =
use_stack
? framework::OpRegistry::CreateOp(
"stack_grad", {{"X", names}, {"Y@GRAD", {dout_name}}},
{{"X@GRAD", grad_names}}, attrs)
: framework::OpRegistry::CreateOp(
"concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}},
{{"X@GRAD", grad_names}}, attrs);
auto grad_op = use_stack ? framework::OpRegistry::CreateOp(
"stack_grad", {{"Y@GRAD", {dout_name}}},
{{"X@GRAD", grad_names}}, attrs)
: framework::OpRegistry::CreateOp(
"concat_grad",
{{"X", names}, {"Out@GRAD", {dout_name}}},
{{"X@GRAD", grad_names}}, attrs);
grad_op->Run(scope, place);
@@ -182,6 +182,29 @@ class UnsqueezeGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class UnsqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
grad_op->SetType("unsqueeze_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
@@ -263,15 +286,17 @@ DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(UnsqueezeGradOpNoNeedBufferVarInference,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp);
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeGradOpNoNeedBufferVarInference);
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
@@ -175,12 +175,15 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.device_context());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(WarpCTCGradOpNoNeedBufferVarInference,
"Logits");
} // namespace operators
} // namespace paddle
@@ -188,7 +191,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>,
ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp,
ops::WarpCTCGradOpNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(