From 721c2c00ef7d056a67d0be5364eda4435d02d166 Mon Sep 17 00:00:00 2001
From: luotao1
Date: Fri, 15 Mar 2019 15:06:20 +0800
Subject: [PATCH] refine fc_infershape

test=develop
---
 paddle/fluid/operators/fc_op.cc               | 34 +++++++++++--------
 paddle/fluid/operators/fc_op.h                | 16 +++++++++
 .../fused/fused_embedding_seq_pool_op.cc      |  6 +++-
 paddle/fluid/operators/hash_op.cc             |  6 +++-
 paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 28 +++++++++------
 .../sequence_ops/sequence_enumerate_op.cc     |  6 +++-
 6 files changed, 68 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
index eb4617a9359..033eca967ac 100644
--- a/paddle/fluid/operators/fc_op.cc
+++ b/paddle/fluid/operators/fc_op.cc
@@ -55,17 +55,8 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                  "The input tensor Input's rank of FCOp should be larger than "
                  "in_num_col_dims.");
 
-  auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
-  PADDLE_ENFORCE_EQ(
-      in_mat_dims[1], w_dims[0],
-      "Fully Connected input and weigth size do not match. %s, %s");
-
   std::vector<int64_t> output_dims;
-  output_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
-  for (int i = 0; i < in_num_col_dims; ++i) {
-    output_dims.push_back(in_dims[i]);
-  }
-  output_dims.push_back(w_dims[1]);
+  FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
 
   ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
   ctx->ShareLoD("Input", "Out");
@@ -128,6 +119,12 @@ void FCOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>(
+      framework::kAllKernelsMustComputeRuntimeShape,
+      "If an Op has this attribute, all its kernels should calculate output"
+      "variable's shape in the corresponding Compute() function. Note that "
+      "this temporal attribute would be deleted after all ops contain it.")
+      .SetDefault(true);
   AddComment(R"DOC(
 Fully Connected Operator.
 
@@ -142,13 +139,20 @@ class FCOpKernel : public framework::OpKernel<T> {
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
-    auto input = ctx.Input<Tensor>("Input");
-    auto w = ctx.Input<Tensor>("W");
-    auto bias = ctx.Input<Tensor>("Bias");
-    auto output = ctx.Output<Tensor>("Out");
+    auto input = ctx.Input<framework::LoDTensor>("Input");
+    auto w = ctx.Input<framework::LoDTensor>("W");
+    auto bias = ctx.Input<framework::LoDTensor>("Bias");
+    auto output = ctx.Output<framework::LoDTensor>("Out");
+    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
     auto w_dims = w->dims();
+
+    std::vector<int64_t> output_dims;
+    FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
+    output->Resize(framework::make_ddim(output_dims));
+    output->set_lod(input->lod());
+
     auto out_dims = output->dims();
-    int M = framework::product(out_dims) / out_dims[out_dims.size() - 1];
+    int M = framework::product(out_dims) / w_dims[1];
 
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h
index e1b780fc0c4..b82a63cd830 100644
--- a/paddle/fluid/operators/fc_op.h
+++ b/paddle/fluid/operators/fc_op.h
@@ -48,5 +48,21 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override;
 };
 
+inline void FCOutputSize(const framework::DDim& in_dims,
+                         const framework::DDim& w_dims,
+                         std::vector<int64_t>& out_dims,  // NOLINT
+                         int in_num_col_dims) {
+  auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
+  PADDLE_ENFORCE_EQ(
+      in_mat_dims[1], w_dims[0],
+      "Fully Connected input and weigth size do not match. %s, %s");
%s, %s"); + + out_dims.reserve(static_cast(in_num_col_dims + 1)); + for (int i = 0; i < in_num_col_dims; ++i) { + out_dims.push_back(in_dims[i]); + } + out_dims.push_back(w_dims[1]); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index a0026427e25..40a411985c1 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -88,7 +88,11 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker { "(boolean, default false) " "Sparse update.") .SetDefault(false); - AddAttr(framework::kAllKernelsMustComputeRuntimeShape, "") + AddAttr( + framework::kAllKernelsMustComputeRuntimeShape, + "If an Op has this attribute, all its kernels should calculate output" + "variable's shape in the corresponding Compute() function. Note that " + "this temporal attribute would be deleted after all ops contain it.") .SetDefault(true); AddComment(R"DOC( FusedEmbeddingSeqPool Operator. diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc index f6395fb32fe..4deee8b4336 100644 --- a/paddle/fluid/operators/hash_op.cc +++ b/paddle/fluid/operators/hash_op.cc @@ -54,7 +54,11 @@ $$Out = scale * X$$ )DOC"); AddAttr("num_hash", "").SetDefault(1); AddAttr("mod_by", "").SetDefault(100000); - AddAttr(framework::kAllKernelsMustComputeRuntimeShape, "") + AddAttr( + framework::kAllKernelsMustComputeRuntimeShape, + "If an Op has this attribute, all its kernels should calculate output" + "variable's shape in the corresponding Compute() function. Note that " + "this temporal attribute would be deleted after all ops contain it.") .SetDefault(true); } }; diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 3a926a716f5..2bdf146f4dd 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -123,9 +123,9 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); - auto input = ctx.Input("Input"); - auto w = ctx.Input("W"); - auto bias = ctx.Input("Bias"); + auto input = ctx.Input("Input"); + auto w = ctx.Input("W"); + auto bias = ctx.Input("Bias"); PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4, "Input must be with 2 or 4 dimensions, i.e. 
NCHW"); @@ -151,7 +151,13 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel { const T* input_data = input->data(); const T* w_data = w->data(); - auto output = ctx.Output("Out"); + auto output = ctx.Output("Out"); + int in_num_col_dims = ctx.Attr("in_num_col_dims"); + std::vector output_dims; + FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims); + output->Resize(framework::make_ddim(output_dims)); + output->set_lod(input->lod()); + T* output_data = output->mutable_data(ctx.GetPlace()); auto dst_memory = mem.dst(output_data); @@ -204,19 +210,21 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel { Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); Tensor* w_grad = ctx.Output(framework::GradVarName("W")); + const Tensor* input = ctx.Input("Input"); + const T* input_data = input->data(); + + const Tensor* w = ctx.Input("W"); + const T* w_data = w->data(); + if (input_grad) { + input_grad->Resize(input->dims()); input_grad_data = input_grad->mutable_data(ctx.GetPlace()); } if (w_grad) { + w_grad->Resize(w->dims()); w_grad_data = w_grad->mutable_data(ctx.GetPlace()); } - const Tensor* input = ctx.Input("Input"); - const T* input_data = input->data(); - - const Tensor* w = ctx.Input("W"); - const T* w_data = w->data(); - const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); const T* out_grad_data = out_grad->data(); diff --git a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc index f357c9c08d0..75bcd3c47f4 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc @@ -59,7 +59,11 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker { }); AddAttr("pad_value", "(int) The enumerate sequence padding value.") .SetDefault(0); - AddAttr(framework::kAllKernelsMustComputeRuntimeShape, "") + AddAttr( + framework::kAllKernelsMustComputeRuntimeShape, + "If an Op has this attribute, all its kernels should calculate output" + "variable's shape in the corresponding Compute() function. Note that " + "this temporal attribute would be deleted after all ops contain it.") .SetDefault(true); AddComment(R"DOC( Sequence Enumerate Operator. -- GitLab