From 02909335e9208dc9c1a8835b6e25b708ea366005 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 27 Aug 2018 16:01:29 +0800 Subject: [PATCH] rename fusion seq_concat_fc to fusion seqexpand_concat_fc --- ...op.cc => fusion_seqexpand_concat_fc_op.cc} | 41 +++++++++++-------- ...c_op.h => fusion_seqexpand_concat_fc_op.h} | 7 ++-- ... => test_fusion_seqexpand_concat_fc_op.py} | 2 +- 3 files changed, 28 insertions(+), 22 deletions(-) rename paddle/fluid/operators/{fusion_seq_concat_fc_op.cc => fusion_seqexpand_concat_fc_op.cc} (85%) rename paddle/fluid/operators/{fusion_seq_concat_fc_op.h => fusion_seqexpand_concat_fc_op.h} (82%) rename python/paddle/fluid/tests/unittests/{test_fusion_seq_concat_fc_op.py => test_fusion_seqexpand_concat_fc_op.py} (98%) diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc similarity index 85% rename from paddle/fluid/operators/fusion_seq_concat_fc_op.cc rename to paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc index f61c822abf6..641851585d4 100644 --- a/paddle/fluid/operators/fusion_seq_concat_fc_op.cc +++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h" +#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h" #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/cpu_vec.h" @@ -22,15 +22,20 @@ limitations under the License. */ namespace paddle { namespace operators { -void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL, - "Inputs(X) of FusionSeqConcatFCOp should larger than 1."); - PADDLE_ENFORCE(ctx->HasInput("FCWeight"), - "Input(FCWeight) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of FusionSeqConcatFC should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("FCOut"), - "Output(FCOut) of FusionSeqConcatFC should not be null."); +void FusionSeqExpandConcatFCOp::InferShape( + framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE_GT( + ctx->Inputs("X").size(), 1UL, + "Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1."); + PADDLE_ENFORCE( + ctx->HasInput("FCWeight"), + "Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FusionSeqExpandConcatFCOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("FCOut"), + "Output(FCOut) of FusionSeqExpandConcatFCOp should not be null."); auto ins_dims = ctx->GetInputsDim("X"); auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D @@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { ctx->ShareLoD("X", "Out", 0); } -framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( +framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), ctx.device_context()); } -void FusionSeqConcatFCOpMaker::Make() { +void FusionSeqExpandConcatFCOpMaker::Make() { AddInput("X", "(LoDTensor) input LodDTensors, the first one must be have ref lod " "for sequence expand, and the rest input should have same lod.") @@ -100,7 +105,7 @@ The concat axis should be 1. } template -class FusionSeqConcatFCKernel : public framework::OpKernel { +class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using DeviceContext = paddle::platform::CPUDeviceContext; @@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp, - ops::FusionSeqConcatFCOpMaker, +REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp, + ops::FusionSeqExpandConcatFCOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc, - ops::FusionSeqConcatFCKernel, - ops::FusionSeqConcatFCKernel); +REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc, + ops::FusionSeqExpandConcatFCOpKernel, + ops::FusionSeqExpandConcatFCOpKernel); diff --git a/paddle/fluid/operators/fusion_seq_concat_fc_op.h b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h similarity index 82% rename from paddle/fluid/operators/fusion_seq_concat_fc_op.h rename to paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h index 66ac48f4c1c..f78e820f603 100644 --- a/paddle/fluid/operators/fusion_seq_concat_fc_op.h +++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; -class FusionSeqConcatFCOp : public framework::OperatorWithKernel { +class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override; }; -class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker { +class FusionSeqExpandConcatFCOpMaker + : public framework::OpProtoAndCheckerMaker { public: void Make() override; }; diff --git a/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py b/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py similarity index 98% rename from python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py rename to python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py index a389b605f0a..7baf39eb3f4 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_seq_concat_fc_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_seqexpand_concat_fc_op.py @@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest): pass def setUp(self): - self.op_type = 'fusion_seq_concat_fc' + self.op_type = 'fusion_seqexpand_concat_fc' self.lod = [[3, 5, 8, 2]] self.inputs_M = [15, 10, 10] self.D = 20 -- GitLab