提交 02909335 编写于 作者: T tensor-tang

rename fusion seq_concat_fc to fusion seqexpand_concat_fc

上级 0f0d4823
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seq_concat_fc_op.h"
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
......@@ -22,15 +22,20 @@ limitations under the License. */
namespace paddle {
namespace operators {
void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqConcatFCOp should larger than 1.");
PADDLE_ENFORCE(ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSeqConcatFC should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqConcatFC should not be null.");
void FusionSeqExpandConcatFCOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_GT(
ctx->Inputs("X").size(), 1UL,
"Inputs(X) of FusionSeqExpandConcatFCOp should larger than 1.");
PADDLE_ENFORCE(
ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
......@@ -55,14 +60,14 @@ void FusionSeqConcatFCOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Out", 0);
}
framework::OpKernelType FusionSeqConcatFCOp::GetExpectedKernelType(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionSeqConcatFCOpMaker::Make() {
void FusionSeqExpandConcatFCOpMaker::Make() {
AddInput("X",
"(LoDTensor) input LodDTensors, the first one must be have ref lod "
"for sequence expand, and the rest input should have same lod.")
......@@ -100,7 +105,7 @@ The concat axis should be 1.
}
template <typename T>
class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
......@@ -188,10 +193,10 @@ class FusionSeqConcatFCKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seq_concat_fc, ops::FusionSeqConcatFCOp,
ops::FusionSeqConcatFCOpMaker,
REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqExpandConcatFCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seq_concat_fc,
ops::FusionSeqConcatFCKernel<float>,
ops::FusionSeqConcatFCKernel<double>);
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqExpandConcatFCOpKernel<float>,
ops::FusionSeqExpandConcatFCOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -21,7 +21,7 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqConcatFCOp : public framework::OperatorWithKernel {
class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -32,7 +32,8 @@ class FusionSeqConcatFCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqConcatFCOpMaker : public framework::OpProtoAndCheckerMaker {
class FusionSeqExpandConcatFCOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
......
......@@ -51,7 +51,7 @@ class TestFusionSeqExpandConcatFCOp(OpTest):
pass
def setUp(self):
self.op_type = 'fusion_seq_concat_fc'
self.op_type = 'fusion_seqexpand_concat_fc'
self.lod = [[3, 5, 8, 2]]
self.inputs_M = [15, 10, 10]
self.D = 20
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册