提交 055a5de2 编写于 作者: L liuwei1031 提交者: liuwei1031

test=release/1.4 cherry-pick (#16845)

* optimize lstmp and sample_logits op, test=develop

* update op_use_default_grad_op_maker.spec, test=develop

* delete useless file,test=develop

* append 0 to dim variable to avoid memory reusage, test=develop
上级 feac69aa
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/lstmp_op.h" #include "paddle/fluid/operators/lstmp_op.h"
#include <memory>
#include <string> #include <string>
namespace paddle { namespace paddle {
...@@ -45,6 +46,7 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -45,6 +46,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
"Output(BatchHidden) of LSTMP operator should not be null."); "Output(BatchHidden) of LSTMP operator should not be null.");
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, PADDLE_ENFORCE_EQ(in_dims.size(), 2,
"Input(X)'s rank of LSTMP operator must be 2."); "Input(X)'s rank of LSTMP operator must be 2.");
...@@ -269,13 +271,47 @@ Users can choose to use fully-connected operator before LSTMP operator. ...@@ -269,13 +271,47 @@ Users can choose to use fully-connected operator before LSTMP operator.
} }
}; };
class LSTMPGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* grad_op = new framework::OpDesc();
grad_op->SetType("lstmp_grad");
grad_op->SetInput("Weight", Input("Weight"));
grad_op->SetInput("ProjWeight", Input("ProjWeight"));
grad_op->SetInput("Bias", Input("Bias"));
grad_op->SetInput("Projection", Output("Projection"));
grad_op->SetInput("Cell", Output("Cell"));
grad_op->SetInput("BatchGate", Output("BatchGate"));
grad_op->SetInput("BatchCellPreAct", Output("BatchCellPreAct"));
grad_op->SetInput("BatchHidden", Output("BatchHidden"));
grad_op->SetInput("H0", Input("H0"));
grad_op->SetInput("C0", Input("C0"));
grad_op->SetInput(framework::GradVarName("Projection"),
OutputGrad("Projection"));
grad_op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
grad_op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
grad_op->SetOutput(framework::GradVarName("ProjWeight"),
InputGrad("ProjWeight"));
grad_op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
grad_op->SetOutput(framework::GradVarName("H0"), InputGrad("H0"));
grad_op->SetOutput(framework::GradVarName("C0"), InputGrad("C0"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
}
};
class LSTMPGradOp : public framework::OperatorWithKernel { class LSTMPGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Projection"), PADDLE_ENFORCE(ctx->HasInput("Projection"),
"Input(Projection) of LSTMP operator should not be null."); "Input(Projection) of LSTMP operator should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"), PADDLE_ENFORCE(ctx->HasInput("Cell"),
...@@ -298,7 +334,8 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -298,7 +334,8 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
}; };
SetOutGradDim("Input"); ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("BatchGate"));
SetOutGradDim("Weight"); SetOutGradDim("Weight");
SetOutGradDim("ProjWeight"); SetOutGradDim("ProjWeight");
SetOutGradDim("Bias"); SetOutGradDim("Bias");
...@@ -310,7 +347,8 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -310,7 +347,8 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context()); ctx.Input<framework::LoDTensor>("BatchGate")->type(),
ctx.device_context());
} }
}; };
...@@ -318,8 +356,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -318,8 +356,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(lstmp, ops::LSTMPOp, ops::LSTMPOpMaker, REGISTER_OPERATOR(lstmp, ops::LSTMPOp, ops::LSTMPOpMaker, ops::LSTMPGradMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(lstmp_grad, ops::LSTMPGradOp); REGISTER_OPERATOR(lstmp_grad, ops::LSTMPGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
lstmp, ops::LSTMPKernel<paddle::platform::CPUDeviceContext, float>, lstmp, ops::LSTMPKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -267,7 +267,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -267,7 +267,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
} }
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight"); auto* weight = ctx.Input<Tensor>("Weight");
auto* proj_weight = ctx.Input<Tensor>("ProjWeight"); auto* proj_weight = ctx.Input<Tensor>("ProjWeight");
auto* bias = ctx.Input<Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
...@@ -323,7 +322,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -323,7 +322,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace()); ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
} }
auto in_dims = input->dims(); // batch_gate dims equal to input dims
auto in_dims = batch_gate->dims();
auto out_dims = cell_out->dims(); auto out_dims = cell_out->dims();
framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]}); framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]});
int frame_size = static_cast<int>(in_dims[1] / 4); int frame_size = static_cast<int>(in_dims[1] / 4);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sample_logits_op.h" #include "paddle/fluid/operators/sample_logits_op.h"
#include <memory>
#include "paddle/fluid/operators/math/sample_prob.h" #include "paddle/fluid/operators/math/sample_prob.h"
namespace paddle { namespace paddle {
...@@ -60,6 +61,10 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -60,6 +61,10 @@ class SampleLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]." "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N, NT + S]."
"The probabilites of sampled positive and negtive labels.") "The probabilites of sampled positive and negtive labels.")
.AsIntermediate(); .AsIntermediate();
AddOutput("LogitsDim", "Store dim information of Logits for gradient op")
.AsIntermediate();
AddOutput("LabelsDim", "Store dim information of Logits for gradient op")
.AsIntermediate();
AddOutput("SampledLogits", AddOutput("SampledLogits",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape" "(Tensor, default: Tensor<float>), A 2-D tensor with shape"
"[N, NT + S]. The outputs value of sampled logits, which will be" "[N, NT + S]. The outputs value of sampled logits, which will be"
...@@ -121,6 +126,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel { ...@@ -121,6 +126,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
"Output(SampledLogits) should be not null."); "Output(SampledLogits) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"), PADDLE_ENFORCE(ctx->HasOutput("SampledLabels"),
"Output(SampledLabels) should be not null."); "Output(SampledLabels) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("LogitsDim"),
"Output(LogitsDim) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("LabelsDim"),
"Output(LabelsDim) should be not null.");
auto logits_dims = ctx->GetInputDim("Logits"); auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Labels"); auto labels_dims = ctx->GetInputDim("Labels");
...@@ -137,6 +146,15 @@ class SampleLogitsOp : public framework::OperatorWithKernel { ...@@ -137,6 +146,15 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]}); ctx->SetOutputDim("SampledLabels", {logits_dims[0], labels_dims[1]});
// append 0 to shape variable to avoid optimized by memory optimize pass
auto logits_dim_vec = framework::vectorize(logits_dims);
logits_dim_vec.push_back(0);
ctx->SetOutputDim("LogitsDim", framework::make_ddim(logits_dim_vec));
auto labels_dim_vec = framework::vectorize(labels_dims);
labels_dim_vec.push_back(0);
ctx->SetOutputDim("LabelsDim", framework::make_ddim(labels_dim_vec));
} }
protected: protected:
...@@ -155,28 +173,27 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel { ...@@ -155,28 +173,27 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"), PADDLE_ENFORCE(ctx->HasInput("LogitsDim"),
"Input(Logits) should not be null."); "Input(LogitsDim) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"), PADDLE_ENFORCE(ctx->HasInput("LabelsDim"),
"Input(Labels) should be not null."); "Input(LabelsDim) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Samples"), PADDLE_ENFORCE(ctx->HasInput("Samples"),
"Input(Samples) should be not null."); "Input(Samples) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("SampledLogits"),
"Input(SampledLogits) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("SampledLogits")),
"Input(SampledLogits@Grad) should not be null."); "Input(SampledLogits@Grad) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null."); "Output(Logits@Grad) should be not null.");
auto logit_dims = ctx->GetInputDim("Logits"); auto logits_dims = ctx->GetInputDim("LogitsDim");
auto label_dims = ctx->GetInputDim("Labels"); logits_dims = framework::DDim(logits_dims.Get(), logits_dims.size() - 1);
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, auto labels_dims = ctx->GetInputDim("LabelsDim");
labels_dims = framework::DDim(labels_dims.Get(), labels_dims.size() - 1);
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The label should be a 2-D tensor."); "The label should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(logit_dims.size(), 2UL, PADDLE_ENFORCE_EQ(logits_dims.size(), 2UL,
"The logits should be a 2-D tensor."); "The logits should be a 2-D tensor.");
ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->SetOutputDim(framework::GradVarName("Logits"), logits_dims);
ctx->GetInputDim("Logits"));
} }
protected: protected:
...@@ -199,10 +216,9 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker { ...@@ -199,10 +216,9 @@ class SampleLogitsGradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto* grad_op = new framework::OpDesc(); auto* grad_op = new framework::OpDesc();
grad_op->SetType("sample_logits_grad"); grad_op->SetType("sample_logits_grad");
grad_op->SetInput("Logits", Input("Logits")); grad_op->SetInput("LogitsDim", Output("LogitsDim"));
grad_op->SetInput("Labels", Input("Labels")); grad_op->SetInput("LabelsDim", Output("LabelsDim"));
grad_op->SetInput("Samples", Output("Samples")); grad_op->SetInput("Samples", Output("Samples"));
grad_op->SetInput("SampledLogits", Output("SampledLogits"));
grad_op->SetInput(framework::GradVarName("SampledLogits"), grad_op->SetInput(framework::GradVarName("SampledLogits"),
OutputGrad("SampledLogits")); OutputGrad("SampledLogits"));
grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
......
...@@ -6144,6 +6144,8 @@ def sampled_softmax_with_cross_entropy(logits, ...@@ -6144,6 +6144,8 @@ def sampled_softmax_with_cross_entropy(logits,
sampled_label = helper.create_variable_for_type_inference(dtype='int64') sampled_label = helper.create_variable_for_type_inference(dtype='int64')
sampled_softlabel = helper.create_variable_for_type_inference( sampled_softlabel = helper.create_variable_for_type_inference(
dtype=logits.dtype) dtype=logits.dtype)
logits_dim = helper.create_variable_for_type_inference(dtype=logits.dtype)
labels_dim = helper.create_variable_for_type_inference(dtype=label.type)
helper.append_op( helper.append_op(
type='sample_logits', type='sample_logits',
...@@ -6157,7 +6159,9 @@ def sampled_softmax_with_cross_entropy(logits, ...@@ -6157,7 +6159,9 @@ def sampled_softmax_with_cross_entropy(logits,
'Samples': samples, 'Samples': samples,
'Probabilities': probabilities, 'Probabilities': probabilities,
'SampledLabels': sampled_label, 'SampledLabels': sampled_label,
'SampledLogits': sampled_logits 'SampledLogits': sampled_logits,
'LogitsDim': logits_dim,
'LabelsDim': labels_dim
}, },
attrs={ attrs={
'use_customized_samples': use_customized_samples, 'use_customized_samples': use_customized_samples,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册