diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index 2728aa8a4ee21a9e1fe3deddcdba4c35a6aba7bc..f31c177c92d0a9e4cc731c478ea8339b450f318a 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/lstmp_op.h" +#include #include namespace paddle { @@ -45,6 +46,7 @@ class LSTMPOp : public framework::OperatorWithKernel { "Output(BatchHidden) of LSTMP operator should not be null."); auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank of LSTMP operator must be 2."); @@ -269,13 +271,47 @@ Users can choose to use fully-connected operator before LSTMP operator. } }; +class LSTMPGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* grad_op = new framework::OpDesc(); + grad_op->SetType("lstmp_grad"); + grad_op->SetInput("Weight", Input("Weight")); + grad_op->SetInput("ProjWeight", Input("ProjWeight")); + grad_op->SetInput("Bias", Input("Bias")); + + grad_op->SetInput("Projection", Output("Projection")); + grad_op->SetInput("Cell", Output("Cell")); + grad_op->SetInput("BatchGate", Output("BatchGate")); + grad_op->SetInput("BatchCellPreAct", Output("BatchCellPreAct")); + grad_op->SetInput("BatchHidden", Output("BatchHidden")); + grad_op->SetInput("H0", Input("H0")); + grad_op->SetInput("C0", Input("C0")); + + grad_op->SetInput(framework::GradVarName("Projection"), + OutputGrad("Projection")); + + grad_op->SetOutput(framework::GradVarName("Input"), InputGrad("Input")); + grad_op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight")); + grad_op->SetOutput(framework::GradVarName("ProjWeight"), + InputGrad("ProjWeight")); + grad_op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); + grad_op->SetOutput(framework::GradVarName("H0"), InputGrad("H0")); + grad_op->SetOutput(framework::GradVarName("C0"), InputGrad("C0")); + + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + class LSTMPGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Projection"), "Input(Projection) of LSTMP operator should not be null."); PADDLE_ENFORCE(ctx->HasInput("Cell"), @@ -298,7 +334,8 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); }; - SetOutGradDim("Input"); + ctx->SetOutputDim(framework::GradVarName("Input"), + ctx->GetInputDim("BatchGate")); SetOutGradDim("Weight"); SetOutGradDim("ProjWeight"); SetOutGradDim("Bias"); @@ -310,7 +347,8 @@ class LSTMPGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + ctx.Input("BatchGate")->type(), + ctx.device_context()); } }; @@ -318,8 +356,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(lstmp, ops::LSTMPOp, ops::LSTMPOpMaker, - paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(lstmp, ops::LSTMPOp, ops::LSTMPOpMaker, ops::LSTMPGradMaker); REGISTER_OPERATOR(lstmp_grad, ops::LSTMPGradOp); REGISTER_OP_CPU_KERNEL( lstmp, ops::LSTMPKernel, diff --git a/paddle/fluid/operators/lstmp_op.h b/paddle/fluid/operators/lstmp_op.h index c7d6e4205f8862526904e4fa767a2f4c4a2d8481..36da882639a235f27b4e5a9e77bf0813ea9c0ee3 100644 --- a/paddle/fluid/operators/lstmp_op.h +++ b/paddle/fluid/operators/lstmp_op.h @@ -267,7 +267,6 @@ class LSTMPGradKernel : public framework::OpKernel { } void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("Input"); auto* weight = ctx.Input("Weight"); auto* proj_weight = ctx.Input("ProjWeight"); auto* bias = ctx.Input("Bias"); @@ -323,7 +322,8 @@ class LSTMPGradKernel : public framework::OpKernel { ordered_c0_g.mutable_data(c0_g->dims(), ctx.GetPlace()); } - auto in_dims = input->dims(); + // batch_gate dims equal to input dims + auto in_dims = batch_gate->dims(); auto out_dims = cell_out->dims(); framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]}); int frame_size = static_cast(in_dims[1] / 4); diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc index a7f7fb26b17c77e6fe87646d3cac20c02c49b52c..bc8fcf26ce461e6d57c756c6f8bc326ce2939814 100644 --- a/paddle/fluid/operators/sample_logits_op.cc +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/sample_logits_op.h" +#include #include "paddle/fluid/operators/math/sample_prob.h" namespace paddle { @@ -60,6 +61,10 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor, default: Tensor), A 2-D tensor with shape [N, NT + S]." "The probabilites of sampled positive and negtive labels.") .AsIntermediate(); + AddOutput("LogitsDim", "Store dim information of Logits for gradient op") + .AsIntermediate(); + AddOutput("LabelsDim", "Store dim information of Logits for gradient op") + .AsIntermediate(); AddOutput("SampledLogits", "(Tensor, default: Tensor), A 2-D tensor with shape" "[N, NT + S]. The outputs value of sampled logits, which will be" @@ -121,6 +126,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel { "Output(SampledLogits) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"), "Output(SampledLabels) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("LogitsDim"), + "Output(LogitsDim) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("LabelsDim"), + "Output(LabelsDim) should be not null."); auto logits_dims = ctx->GetInputDim("Logits"); auto labels_dims = ctx->GetInputDim("Labels"); @@ -137,6 +146,15 @@ class SampleLogitsOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]}); + + // append 0 to shape variable to avoid optimized by memory optimize pass + auto logits_dim_vec = framework::vectorize(logits_dims); + logits_dim_vec.push_back(0); + ctx->SetOutputDim("LogitsDim", framework::make_ddim(logits_dim_vec)); + + auto labels_dim_vec = framework::vectorize(labels_dims); + labels_dim_vec.push_back(0); + ctx->SetOutputDim("LabelsDim", framework::make_ddim(labels_dim_vec)); } protected: @@ -155,28 +173,27 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Logits"), - "Input(Logits) should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Labels"), - "Input(Labels) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("LogitsDim"), + "Input(LogitsDim) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LabelsDim"), + "Input(LabelsDim) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Samples"), "Input(Samples) should be not null."); - PADDLE_ENFORCE(ctx->HasInput("SampledLogits"), - "Input(SampledLogits) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")), "Input(SampledLogits@Grad) should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), "Output(Logits@Grad) should be not null."); - auto logit_dims = ctx->GetInputDim("Logits"); - auto label_dims = ctx->GetInputDim("Labels"); - PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + auto logits_dims = ctx->GetInputDim("LogitsDim"); + logits_dims = framework::DDim(logits_dims.Get(), logits_dims.size() - 1); + auto labels_dims = ctx->GetInputDim("LabelsDim"); + labels_dims = framework::DDim(labels_dims.Get(), labels_dims.size() - 1); + PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL, "The label should be a 2-D tensor."); - PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL, + PADDLE_ENFORCE_EQ(logits_dims.size(), 2UL, "The logits should be a 2-D tensor."); - ctx->SetOutputDim(framework::GradVarName("Logits"), - ctx->GetInputDim("Logits")); + ctx->SetOutputDim(framework::GradVarName("Logits"), logits_dims); } protected: @@ -199,10 +216,9 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker { std::unique_ptr Apply() const override { auto* grad_op = new framework::OpDesc(); grad_op->SetType("sample_logits_grad"); - grad_op->SetInput("Logits", Input("Logits")); - grad_op->SetInput("Labels", Input("Labels")); + grad_op->SetInput("LogitsDim", Output("LogitsDim")); + grad_op->SetInput("LabelsDim", Output("LabelsDim")); grad_op->SetInput("Samples", Output("Samples")); - grad_op->SetInput("SampledLogits", Output("SampledLogits")); grad_op->SetInput(framework::GradVarName("SampledLogits"), OutputGrad("SampledLogits")); grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1b99aa05254d32405ecdac6308ad29fbca566951..282c8564deae203b05110e79b144fe3712b70cbf 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6144,6 +6144,8 @@ def sampled_softmax_with_cross_entropy(logits, sampled_label = helper.create_variable_for_type_inference(dtype='int64') sampled_softlabel = helper.create_variable_for_type_inference( dtype=logits.dtype) + logits_dim = helper.create_variable_for_type_inference(dtype=logits.dtype) + labels_dim = helper.create_variable_for_type_inference(dtype=label.type) helper.append_op( type='sample_logits', @@ -6157,7 +6159,9 @@ def sampled_softmax_with_cross_entropy(logits, 'Samples': samples, 'Probabilities': probabilities, 'SampledLabels': sampled_label, - 'SampledLogits': sampled_logits + 'SampledLogits': sampled_logits, + 'LogitsDim': logits_dim, + 'LabelsDim': labels_dim }, attrs={ 'use_customized_samples': use_customized_samples,