提交 02909335 编写于 作者: T tensor-tang

rename fusion seq_concat_fc to fusion seqexpand_concat_fc

上级 0f0d4823
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h" #include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
...@@ -22,15 +22,20 @@ limitations under the License. */ ...@@ -22,15 +22,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { void FusionSeqExpandConcatFCOp::InferShape(
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL, framework::InferShapeContext* ctx) const {
"Inputs(X) of FusionSeqConcatFCOp should larger than 1."); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE(ctx->HasInput("FCWeight"), ctx->Inputs("X").size(), 1UL,
"Input(FCWeight) of FusionSeqConcatFC should not be null."); "Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(
"Output(Out) of FusionSeqConcatFC should not be null."); ctx->HasInput("FCWeight"),
PADDLE_ENFORCE(ctx->HasOutput("FCOut"), "Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
"Output(FCOut) of FusionSeqConcatFC should not be null."); PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
auto ins_dims = ctx->GetInputsDim("X"); auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
...@@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Out", 0); ctx->ShareLoD("X", "Out", 0);
} }
framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()), framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context()); ctx.device_context());
} }
void FusionSeqConcatFCOpMaker::Make() { void FusionSeqExpandConcatFCOpMaker::Make() {
AddInput("X", AddInput("X",
"(LoDTensor) input LodDTensors, the first one must be have ref lod " "(LoDTensor) input LodDTensors, the first one must be have ref lod "
"for sequence expand, and the rest input should have same lod.") "for sequence expand, and the rest input should have same lod.")
...@@ -100,7 +105,7 @@ The concat axis should be 1. ...@@ -100,7 +105,7 @@ The concat axis should be 1.
} }
template <typename T> template <typename T>
class FusionSeqConcatFCKernel : public framework::OpKernel<T> { class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
...@@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> { ...@@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp, REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqConcatFCOpMaker, ops::FusionSeqExpandConcatFCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc, REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqConcatFCKernel<float>, ops::FusionSeqExpandConcatFCOpKernel<float>,
ops::FusionSeqConcatFCKernel<double>); ops::FusionSeqExpandConcatFCOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class FusionSeqConcatFCOp : public framework::OperatorWithKernel { class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel { ...@@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
}; };
class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker { class FusionSeqExpandConcatFCOpMaker
: public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override; void Make() override;
}; };
......
...@@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest): ...@@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest):
pass pass
def setUp(self): def setUp(self):
self.op_type = 'fusion_seq_concat_fc' self.op_type = 'fusion_seqexpand_concat_fc'
self.lod = [[3, 5, 8, 2]] self.lod = [[3, 5, 8, 2]]
self.inputs_M = [15, 10, 10] self.inputs_M = [15, 10, 10]
self.D = 20 self.D = 20
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册