Unverified commit 320451d6, authored by tensor-tang, committed by GitHub

Merge pull request #12990 from tensor-tang/feature/op/fusion_expand_concat_fc

Feature: fused seq_expand + concat + fc operator

paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc:
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
void FusionSeqExpandConcatFCOp::InferShape(
framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE_GT(
      ctx->Inputs("X").size(), 1UL,
      "Inputs(X) of FusionSeqExpandConcatFCOp should be larger than 1.");
PADDLE_ENFORCE(
ctx->HasInput("FCWeight"),
"Input(FCWeight) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqExpandConcatFCOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("FCOut"),
"Output(FCOut) of FusionSeqExpandConcatFCOp should not be null.");
auto ins_dims = ctx->GetInputsDim("X");
auto w_dims = ctx->GetInputDim("FCWeight"); // (M0+M1+M2+..) x D
PADDLE_ENFORCE_EQ(w_dims.size(), 2UL, "Input(FCWeight)'s rank must be 2.");
const int D = w_dims[1];
int sum = ins_dims[0][1];
for (size_t i = 1; i < ins_dims.size(); ++i) {
sum += ins_dims[i][1];
}
  PADDLE_ENFORCE_EQ(sum, w_dims[0],
                    "The height of FCWeight should be the sum of the "
                    "widths of all inputs.");
if (ctx->HasInput("FCBias")) {
auto b_dims = ctx->GetInputDim("FCBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2.");
    PADDLE_ENFORCE_EQ(b_dims[0], 1, "The shape of FCBias must be 1 x %d.", D);
    PADDLE_ENFORCE_EQ(b_dims[1], D, "The shape of FCBias must be 1 x %d.", D);
}
ctx->SetOutputDim("Out", {ins_dims[0][0], D});
  // FCOut can only be resized at run time, since the LoD is not available
  // in InferShape.
  // Explicitly share the reference LoD with Out.
ctx->ShareLoD("X", "Out", 0);
}
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
ctx.device_context());
}
void FusionSeqExpandConcatFCOpMaker::Make() {
AddInput("X",
"(LoDTensor) input LodDTensors, the first one must be have ref lod "
"for sequence expand, and the rest input should have same lod.")
.AsDuplicable();
AddInput("FCWeight", "(Tensor) the weights of fc.");
AddInput("FCBias", "(Tensor, optional) the bias of fc.").AsDispensable();
AddOutput("Out", "(LoDTensor) Output LodTensor.");
  AddOutput(
      "FCOut",
      "(Tensor) the intermediate tensor holding the result of the fc. "
      "Shape is (N x D), where N is the batch size and D is the output "
      "dimension of the fc.")
      .AsIntermediate();
AddAttr<std::string>("fc_activation",
"(string, default: identity)"
"The activation for the result of fc."
"`identity` by default.")
.SetDefault("identity")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Fusion Sequence expand + concat + fc Operator.
All below conditions should be meet:
The ref_level of seq_expand should be 0.
The ref lod of seq_expand level is the first input of concat.
The other inputs should have same lod and same batch size of ref lod.
The seq len of other inputs should be 1.
The concat axis should be 1.
)DOC");
}
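
// A minimal sketch of the unfused computation this operator replaces, in
// NumPy-style pseudocode (names are illustrative only):
//
//   expanded = [np.repeat(x, ref_seq_lens, axis=0) for x in xs[1:]]
//   cat = np.concatenate([xs[0]] + expanded, axis=1)  # T x (M0+M1+...)
//   Out = act(np.dot(cat, FCWeight) + FCBias)         # T x D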
template <typename T>
class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto ins = ctx.MultiInput<LoDTensor>("X");
auto* w = ctx.Input<Tensor>("FCWeight");
auto* b = ctx.Input<Tensor>("FCBias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* fc_out = ctx.Output<Tensor>("FCOut");
auto* ref_in = ins[0];
auto ref_lod = ref_in->lod();
auto in1_lod = ins[1]->lod();
auto ref_dims = ref_in->dims(); // T x M0
auto in1_dims = ins[1]->dims(); // N x M1
auto w_dims = w->dims();
const int N = ref_lod[0].size() - 1;
const int total_T = ref_dims[0];
const int M0 = ref_dims[1];
const int M1 = in1_dims[1];
const int D = w_dims[1];
    // Some checks, and FCOut must be resized here,
    // since InferShape cannot get the LoD info.
    PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL,
                      "Only an input LoD size of 1 is supported.");
    PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL,
                      "Only an input LoD size of 1 is supported.");
    PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
                      "The batch size of all inputs should be equal.");
    PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
                      "The sequence length of the other inputs should be 1.");
    PADDLE_ENFORCE_EQ(in1_dims[0], N,
                      "The input height should equal the batch size.");
for (size_t i = 2; i < ins.size(); ++i) {
      PADDLE_ENFORCE_EQ(ins[i]->dims()[0], N,
                        "The height of all other inputs should equal the "
                        "batch size.");
      PADDLE_ENFORCE_EQ(ins[i]->lod(), in1_lod,
                        "All other inputs should have the same LoD.");
}
fc_out->Resize({N, D});
std::function<void(const int, const T*, T*)> fc_act;
auto& fc_act_str = ctx.Attr<std::string>("fc_activation");
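    // Use the AVX-vectorized activation when the CPU supports it,
    // otherwise fall back to the generic (isa_any) implementation.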
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
fc_act = act_functor(fc_act_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
fc_act = act_functor(fc_act_str);
}
const T* ref_in_data = ref_in->data<T>();
const T* in1_data = ins[1]->data<T>();
const T* w_data = w->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
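    // Split the big fc by input instead of materializing the concat:
    // the first M0 rows of FCWeight multiply the ref input (T x M0)
    // directly into Out, while the remaining rows multiply the
    // batch-level inputs (N x Mi) into FCOut (one row per sequence).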
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
                                      out_data, b ? b->data<T>() : nullptr);
w_data = w_data + M0 * D;
    // The second input writes FCOut for the first time (overwrite).
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
w_data = w_data + M1 * D;
for (size_t i = 2; i < ins.size(); ++i) {
      // The remaining inputs accumulate into FCOut (GEMM with beta = 1).
const T* in_data = ins[i]->data<T>();
const int K = ins[i]->dims()[1];
blas.GEMM(CblasNoTrans, CblasNoTrans, N, D, K, static_cast<T>(1), in_data,
K, w_data, D, static_cast<T>(1), fc_out_data, D);
w_data = w_data + K * D;
}
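    // Broadcast each sequence's FCOut row over all of that sequence's time
    // steps in Out; this replaces an explicit seq_expand of the batch-level
    // inputs.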
T* cur_out_data = out_data;
for (int i = 0; i < N; ++i) {
int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
T* src = fc_out_data + i * D;
for (int step = 0; step < seq_len; ++step) {
blas.VADD(D, cur_out_data, src, cur_out_data);
cur_out_data = cur_out_data + D;
}
}
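    // Finally, apply the fc activation over all total_T x D output
    // elements in one vectorized pass.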
fc_act(total_T * D, out_data, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqexpand_concat_fc, ops::FusionSeqExpandConcatFCOp,
ops::FusionSeqExpandConcatFCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seqexpand_concat_fc,
ops::FusionSeqExpandConcatFCOpKernel<float>,
ops::FusionSeqExpandConcatFCOpKernel<double>);

paddle/fluid/operators/fusion_seqexpand_concat_fc_op.h:

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqExpandConcatFCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqExpandConcatFCOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

test_fusion_seqexpand_concat_fc_op.py (unit test):

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_fusion_lstm_op import fc, ACTIVATION
def fusion_seqexpand_concat_fc(xs, lod, w, b, fc_act):
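    # Reference implementation: expand each batch-level input along the ref
    # LoD (np.repeat), concat along axis 1, then fc + activation.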
T = sum(lod[0])
N = len(lod[0])
num_inputs = len(xs)
D = w.shape[1]
expanded_inputs = [xs[0]]
for i in range(num_inputs - 1):
x = xs[i + 1]
assert x.shape[0] == N
expanded = np.repeat(x, lod[0], axis=0)
assert expanded.shape[0] == T
assert expanded.shape[1] == x.shape[1]
expanded_inputs.append(expanded)
fc_input = np.concatenate(expanded_inputs, axis=1)
assert fc_input.shape[0] == T
assert fc_input.shape[1] == w.shape[0]
fc_out = fc(fc_input, w, b)
fc_out = fc_act(fc_out)
assert fc_out.shape[0] == T
assert fc_out.shape[1] == D
return fc_out
class TestFusionSeqExpandConcatFCOp(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'fusion_seqexpand_concat_fc'
self.lod = [[3, 5, 8, 2]]
self.inputs_M = [15, 10, 10]
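        # 4 sequences of lengths 3, 5, 8 and 2 (T = 18); the ref input has
        # width 15 and the two batch-level inputs have width 10 each.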
self.D = 20
self.with_bias = True
self.fc_act = 'relu'
self.set_conf()
T = sum(self.lod[0])
bs = len(self.lod[0])
num_inputs = len(self.inputs_M)
x0 = np.random.normal(size=(T, self.inputs_M[0])).astype('float32')
xs = [x0]
for i in range(num_inputs - 1):
xi = np.random.normal(size=(bs,
self.inputs_M[i + 1])).astype('float32')
xs.append(xi)
# fc weight and bias
w = np.random.normal(size=(sum(self.inputs_M),
self.D)).astype('float32')
b = np.random.normal(size=(
1, self.D)).astype('float32') if self.with_bias else np.zeros(
(1, self.D)).astype('float32')
out = fusion_seqexpand_concat_fc(xs, self.lod, w, b,
ACTIVATION[self.fc_act])
self.inputs = {'X': [('x0', (x0, self.lod))], 'FCWeight': w}
normal_lod = [[1] * bs]
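        # Batch-level inputs carry a trivial LoD of all ones (seq_len == 1).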
for i in range(num_inputs - 1):
self.inputs['X'].append(('x%d' % (i + 1), (xs[i + 1], normal_lod)))
if self.with_bias:
self.inputs['FCBias'] = b
self.outputs = {'Out': (out, self.lod)}
self.attrs = {'fc_activation': self.fc_act}
def test_check_output(self):
self.check_output()
class TestFusionSECFCOpNonBias(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.with_bias = False
class TestFusionSECFCOpNonAct(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.fc_act = 'identity'
class TestFusionSECFCOpMD1(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.inputs_M = [3, 4, 2, 1, 5]
self.D = 8
class TestFusionSECFCOpMD2(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.lod = [[5, 6]]
self.inputs_M = [1, 1]
class TestFusionSECFCOpBS1_1(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.lod = [[1]]
self.inputs_M = [3, 4, 2]
class TestFusionSECFCOpBS1_2(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.lod = [[1]]
self.inputs_M = [3, 4]
class TestFusionSECFCOpBS1_3(TestFusionSeqExpandConcatFCOp):
def set_conf(self):
self.lod = [[5]]
self.inputs_M = [6, 3]
if __name__ == '__main__':
unittest.main()