未验证 提交 48029ab0 编写于 作者: Z Zeng Jinle 提交者: GitHub

Remove some DefaultGradOpDescMaker (#20185)

* remove fc_grad, test=develop

* remove fsp op since no unittests, test=develop
上级 729f5846
conv_shift
cos_sim
fc
flatten
fsp
fused_embedding_seq_pool
gru
lrn
lstm_unit
......@@ -11,12 +8,10 @@ match_matrix_tensor
max_pool2d_with_index
max_pool3d_with_index
maxout
modified_huber_loss
nce
pool2d
pool3d
prelu
rank_loss
reduce_max
reduce_min
reduce_prod
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_shift_op.h"
#include <memory>
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
......@@ -191,12 +192,31 @@ class ConvShiftGradKernel<platform::CPUPlace, T>
}
}
};
class ConvShiftGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("conv_shift_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ConvShiftGradOpDescMaker);
REGISTER_OPERATOR(conv_shift_grad, ops::ConvShiftGradOp);
REGISTER_OP_CPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
......
......@@ -85,37 +85,6 @@ class FCOp : public framework::OperatorWithKernel {
}
};
void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("W"))) {
ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
"Should have bias grad");
auto bias_dims = ctx->GetInputDim("Bias");
ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
}
}
framework::OpKernelType FCOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
}
class FCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -154,8 +123,7 @@ The size of each dimension of the parameters checked in the infer-shape.
namespace ops = paddle::operators;
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fc, ops::FCOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::FCOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -24,17 +24,6 @@ namespace operators {
using Tensor = framework::Tensor;
class FCOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
inline void FCOutputSize(const framework::DDim& in_dims,
const framework::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
......@@ -150,12 +151,30 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
}
};
class FusedEmbeddingSeqPoolGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("fused_embedding_seq_pool_grad");
op->SetInput("Ids", Input("Ids"));
op->SetInput("W", Input("W"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("W"), InputGrad("W"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_embedding_seq_pool, ops::FusedEmbeddingSeqPoolOp,
paddle::framework::DefaultGradOpDescMaker<true>,
ops::FusedEmbeddingSeqPoolGradOpDescMaker,
ops::FusedEmbeddingSeqPoolOpMaker);
REGISTER_OPERATOR(fused_embedding_seq_pool_grad,
ops::FusedEmbeddingSeqPoolOpGrad,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/modified_huber_loss_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -86,38 +87,55 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"),
"Intermediate value must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) must not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
intermediate_dims, x_dims,
intermediate_dims, y_dims,
"The shape of X and intermediate value must be the same.");
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims,
PADDLE_ENFORCE_EQ(out_grad_dims, y_dims,
"The shape of Input(Out@Grad) and X must be the same.");
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
}
}
};
class ModifiedHuberLossGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("modified_huber_loss_grad");
op->SetInput("Y", Input("Y"));
op->SetInput("IntermediateVal", Output("IntermediateVal"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(modified_huber_loss, ops::ModifiedHuberLossOp,
ops::ModifiedHuberLossOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ModifiedHuberLossGradOpDescMaker);
REGISTER_OPERATOR(modified_huber_loss_grad, ops::ModifiedHuberLossGradOp);
REGISTER_OP_CPU_KERNEL(
......
......@@ -180,7 +180,7 @@ class RankLossGradDescMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(rank_loss, ops::RankLossOp, ops::RankLossOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::RankLossGradDescMaker);
REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp);
REGISTER_OP_CPU_KERNEL(
rank_loss, ops::RankLossKernel<paddle::platform::CPUDeviceContext, float>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册