未验证 提交 664159ad 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #13998 from tensor-tang/fea/fusion_seqconv_add

Fea/fusion seqconv eltadd relu
......@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference)
......
......@@ -761,6 +761,51 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
seqconv_input->assert_is_op_input("sequence_conv", "X");
auto *seqconv_op = pattern->NewNode(seqconv_repr())
->assert_is_op("sequence_conv")
->assert_op_attr<bool>("paddingTrainable", false)
->assert_op_attr<int>("contextStride", 1);
auto *eltadd_op =
pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");
auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
// Create variables
// Filter
auto *seqconv_weight_var =
pattern->NewNode(seqconv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("sequence_conv", "Filter");
// Bias
auto *eltadd_bias_var = pattern->NewNode(eltadd_bias_repr())
->AsInput()
->assert_is_op_input("elementwise_add");
// intermediate variable, will be removed in the IR after fuse.
auto *seqconv_out_var = pattern->NewNode(seqconv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("sequence_conv")
->assert_is_op_input("elementwise_add");
auto *eltadd_out_var = pattern->NewNode(eltadd_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("elementwise_add")
->assert_is_only_input_of_op("relu");
// output
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
seqconv_op->LinksFrom({seqconv_input, seqconv_weight_var})
.LinksTo({seqconv_out_var});
eltadd_op->LinksFrom({seqconv_out_var, eltadd_bias_var})
.LinksTo({eltadd_out_var});
relu_op->LinksFrom({eltadd_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
......
......@@ -128,6 +128,15 @@ struct PDNode {
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
template <typename T>
PDNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
return x && x->IsOp() && x->Op()->HasAttr(attr_name) &&
boost::get<T>(x->Op()->GetAttr(attr_name)) == attr;
});
return this;
}
private:
PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
......@@ -434,6 +443,31 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
// seqconv_input, seqconv_weight,
// seqconv_out, seqconv,
// elementwise_add_bias, elementwise_add_out, elementwise_add
// relu_out, relu
struct SeqConvEltAddRelu : public PatternBase {
SeqConvEltAddRelu(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "seqconv_eltadd_relu") {}
PDNode* operator()(PDNode* seqconv_input);
// declare operator node's name
PATTERN_DECL_NODE(seqconv);
PATTERN_DECL_NODE(eltadd);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(seqconv_weight);
PATTERN_DECL_NODE(seqconv_out);
PATTERN_DECL_NODE(eltadd_bias);
PATTERN_DECL_NODE(eltadd_out);
PATTERN_DECL_NODE(relu_out);
};
// FC with bias
// op: mul + elementwise_add
// named nodes:
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
->assert_is_op_input("sequence_conv")
->assert_var_not_persistable();
patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
fuse_pattern(x);
// Create New OpDesc
auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
Node* eltadd_bias, Node* relu_out) {
OpDesc op_desc;
op_desc.SetType("fusion_seqconv_eltadd_relu");
op_desc.SetInput("X", {input->Name()});
op_desc.SetInput("Filter", {seqconv_weight->Name()});
op_desc.SetInput("Bias", {eltadd_bias->Name()});
op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
op_desc.SetOutput("ColMat", {ColMat});
op_desc.SetOutput("Out", {relu_out->Name()});
scope->Var(ColMat)->GetMutable<LoDTensor>();
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(seqconv_weight, op);
IR_NODE_LINK_TO(eltadd_bias, op);
IR_NODE_LINK_TO(op, relu_out);
return op;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle SeqConv EltAdd Relu fuse";
GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);
fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
relu_out);
std::unordered_set<const Node*> marked_nodes(
{seqconv, seqconv_out, eltadd, eltadd_out, relu});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
paddle::framework::ir::SeqConvEltAddReluFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConvEltAddReluFusePass : public FusePassBase {
public:
virtual ~SeqConvEltAddReluFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -67,17 +67,18 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion.
const std::vector<std::string> all_ir_passes_{{
// Manual update the passes here.
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
"embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
#ifdef PADDLE_WITH_MKLDNN
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
......
......@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
SetConfig(&cfg);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
GetFuseStatis(predictor.get(), &num_ops);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 32);
}
// Compare result of NativeConfig and AnalysisConfig
......
......@@ -86,7 +86,7 @@ function(op_library TARGET)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
"channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
"fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h"
#include <algorithm> // for min, max
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
namespace paddle {
namespace operators {
void FusionSeqConvEltAddReluOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Filter"),
"Input(Filter) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("Bias"),
"Input(Bias) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FusionSeqConvEltAddReluOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColMat"),
"Output(ColMat) of FusionSeqConvEltAddReluOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto w_dims = ctx->GetInputDim("Filter");
int context_length = ctx->Attrs().Get<int>("contextLength");
PADDLE_ENFORCE(
ctx->Attrs().Get<int>("contextStride") == 1,
"Currently, FusionSeqConvEltAddReluOp only supports contextStride=1.");
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(x_dims.size() == 2 && w_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE(w_dims[0] == context_length * x_dims[1],
"Filter's height should be context_length * "
"input_hidden_size .");
PADDLE_ENFORCE_GT(context_length + ctx->Attrs().Get<int>("contextStart"), 0,
"contextStart size should be smaller than contextLength.");
ctx->SetOutputDim("Out", {x_dims[0], w_dims[1]});
ctx->SetOutputDim("ColMat", {x_dims[0], w_dims[0]});
ctx->ShareLoD("X", "Out");
}
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionSeqConvEltAddReluOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
// PaddingData only support false yet, should be ensured at pass.
AddInput("Filter",
"(Tensor) same as the input(Filter) of sequence conv op is an "
"learnable parameter."
"This is a tensor with shape (K, N), where K is the "
"context_length * dim size of x, N is the output feature size.");
AddInput("Bias",
"(Tensor) the learnable weights. shape (1, N), where N is the "
"output feature size");
AddOutput(
"Out",
"(LoDTensor) the output(Out) is a LodTensor, which support "
"variable-time length output sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, N), where, T is the "
"total time steps in this mini-batch, N is the output feature size.");
AddOutput("ColMat",
"(Tensor) (T, K), where T is where T is the "
"total time steps in this mini-batch, K is height of Filter")
.AsIntermediate();
AddAttr<int>("contextLength",
"(int) the contextLength of FusionSeqConvEltAddReluOp is the "
"height of the convolution kernel.")
.GreaterThan(0);
AddAttr<int>("contextStart",
"(int, default:0) the contextStart of FusionSeqConvEltAddReluOp "
"represents the beginning of the convolution of the number of "
"rows of sequence, which can be negative. The negative number "
"means to pad contextStart time-steps of zeros or learnable "
"parameters at the beginning of each instance. The positive "
"number means to skip contextStart time-steps of each "
"instance.")
.SetDefault(0);
AddAttr<int>(
"contextStride",
"(int, default:1) the contextStride of FusionSeqConvEltAddReluOp "
"represents the stride length of convolution kernel. "
"Currently, FusionSeqConvEltAddReluOp only supports"
"contextStride=1.")
.SetDefault(1)
.GreaterThan(0);
AddComment(R"DOC(
Fusion Sequence Conv and ElementwiseAdd Operator.
)DOC");
}
template <typename T>
class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* w = ctx.Input<Tensor>("Filter");
auto* b = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<LoDTensor>("Out");
auto* col = ctx.Output<Tensor>("ColMat");
auto x_lod = x->lod();
auto x_dims = x->dims();
auto w_dims = w->dims();
PADDLE_ENFORCE_EQ(b->numel(), w_dims[1],
"bias size should be equal to output feature size.");
PADDLE_ENFORCE_EQ(x_lod.size(), 1UL,
"Only support one level sequence now.");
const T* x_data = x->data<T>();
const T* w_data = w->data<T>();
const T* b_data = b->data<T>();
T* y_data = y->mutable_data<T>(ctx.GetPlace());
T* col_data = col->mutable_data<T>(ctx.GetPlace());
int context_start = ctx.Attr<int>("contextStart");
int context_length = ctx.Attr<int>("contextLength");
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
// im2col
int src_mat_w = static_cast<int>(x_dims[1]);
int src_mat_w_sz = src_mat_w * sizeof(T);
int col_mat_w = static_cast<int>(w_dims[0]);
int col_mat_w_sz = col_mat_w * sizeof(T);
for (int i = 0; i < static_cast<int>(x_lod[0].size()) - 1; ++i) {
int st = x_lod[0][i];
int ed = x_lod[0][i + 1];
const T* src_data = x_data + st * src_mat_w;
T* dst_data = col_data + st * col_mat_w;
int seq_len = ed - st;
if (seq_len > up_pad + down_pad) {
// zero all up_pad and fill data
std::memset(dst_data, 0, up_pad * col_mat_w_sz);
dst_data = dst_data + up_pad * src_mat_w;
int copy_size = col_mat_w_sz - up_pad * src_mat_w_sz;
for (int j = 0; j < up_pad; ++j) {
// blas.VCOPY?
std::memcpy(dst_data, src_data, copy_size);
dst_data += (col_mat_w - src_mat_w);
copy_size += src_mat_w_sz;
}
// fill data
for (int j = 0; j < seq_len - up_pad - down_pad; ++j) {
std::memcpy(dst_data, src_data, copy_size);
dst_data += col_mat_w;
src_data += src_mat_w;
}
// zero all down_pad and fill data
std::memset(dst_data, 0, down_pad * col_mat_w_sz);
copy_size -= src_mat_w_sz;
for (int j = 0; j < down_pad; ++j) {
std::memcpy(dst_data, src_data, copy_size);
dst_data += col_mat_w;
src_data += src_mat_w;
copy_size -= src_mat_w_sz;
}
} else {
PADDLE_ENFORCE_GE(context_length, up_pad + down_pad + 1);
std::memset(dst_data, 0, seq_len * col_mat_w_sz);
dst_data = dst_data + up_pad * src_mat_w;
int zero_sz = up_pad * src_mat_w_sz;
int cur_src_sz = seq_len * src_mat_w_sz;
for (int j = 0; j < std::min(up_pad, seq_len); ++j) {
int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
std::memcpy(dst_data, src_data, copy_size);
dst_data += (col_mat_w - src_mat_w);
zero_sz -= src_mat_w_sz;
}
// from bottom
dst_data = col_data + ed * col_mat_w;
src_data = x_data + st * src_mat_w;
zero_sz = down_pad * src_mat_w_sz;
for (int j = 1; j <= std::min(down_pad, seq_len); ++j) {
int copy_size = std::min(cur_src_sz, col_mat_w_sz - zero_sz);
std::memcpy(dst_data - (zero_sz + copy_size) / sizeof(T),
src_data + std::max(seq_len - j - up_pad, 0) * src_mat_w,
copy_size);
dst_data -= col_mat_w;
zero_sz -= src_mat_w_sz;
}
}
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
col_data, w_data, y_data, b_data, true);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqconv_eltadd_relu, ops::FusionSeqConvEltAddReluOp,
ops::FusionSeqConvEltAddReluOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(fusion_seqconv_eltadd_relu,
ops::FusionSeqConvEltAddReluKernel<float>,
ops::FusionSeqConvEltAddReluKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionSeqConvEltAddReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionSeqConvEltAddReluOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
DECLARE_int32(paddle_num_threads);
......@@ -30,20 +31,25 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
if (B == NULL) {
return;
}
if (relu) {
const auto& vaddrelu = jitkernel::KernelPool::Instance()
.template Get<jitkernel::VAddReluKernel<T>>(N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
vaddrelu->Compute(B, dst, dst);
}
} else {
const auto& vadd = jitkernel::KernelPool::Instance()
.template Get<jitkernel::VAddKernel<T>>(N);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
for (int i = 0; i < M; i++) {
T* dst = Y + i * N;
vadd->Compute(B, dst, dst);
}
}
if (!relu) {
return;
}
// TODO(TJ): fuse relu
LOG(FATAL) << "Not implemented!";
}
} // namespace math
......
......@@ -86,6 +86,12 @@ class VAddBiasKernel : public Kernel {
virtual void Compute(const T a, const T *x, T *y) const = 0;
};
template <typename T>
class VAddReluKernel : public Kernel {
public:
virtual void Compute(const T *x, const T *y, T *z) const = 0;
};
template <typename T>
class VActKernel : public Kernel {
public:
......
......@@ -378,11 +378,99 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {}
};
/* VAddRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, const T* y, T* z) const override {
for (int i = 0; i < this->num_; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx = _mm256_loadu_ps(x); \
__m256 tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \
_mm256_storeu_ps(z, tmpy); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ16>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(y); \
tmp0 = _mm256_add_ps(tmp0, tmp1); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp2 = _mm256_loadu_ps(y + 8); \
tmp1 = _mm256_add_ps(tmp1, tmp2); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(z, tmp0); \
_mm256_storeu_ps(z + 8, tmp1); \
}
#define INTRI_COMMON_FLOAT(isa, block) \
template <> \
VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
: VAddReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
} \
template <> \
void VAddReluKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmpx = _mm256_loadu_ps(x + i); \
__m256 tmpy = _mm256_loadu_ps(y + i); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, zeros); \
_mm256_storeu_ps(z + i, tmpy); \
} \
for (int i = this->end_; i < this->num_; ++i) { \
z[i] = x[i] + y[i]; \
z[i] = z[i] > 0 ? z[i] : 0; \
} \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_COMMON_FLOAT(jit::avx, kGT16);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_COMMON_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
} // namespace jitkernel
......
......@@ -712,6 +712,63 @@ TEST(JitKernel, vadd) {
}
}
void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
z[i] = z[i] > 0 ? z[i] : 0;
}
}
void vaddrelu_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z) {
vadd->Compute(x, y, z);
vrelu->Compute(z, z);
}
TEST(JitKernel, vaddrelu) {
namespace jit = paddle::operators::math::jitkernel;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
RandomVec<float>(d, y.data());
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(d);
const auto& vadd =
jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
const auto& vrelu =
jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
const float* x_data = x.data();
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vadd_ref(d, x_data, y_data, zref_data);
}
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
}
auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
<< "tgt takes: " << (ttgte - ttgts) / repeat;
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4;
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import random
from op_test import OpTest
from test_seq_conv import seqconv
class TestSeqConvEltAddRelu(OpTest):
def set_conf(self):
pass
def setUp(self):
self.op_type = 'fusion_seqconv_eltadd_relu'
self.lod = [[6, 4]]
self.in_fea_size = 16
self.out_fea_size = 8
self.context_length = 4
self.context_stride = 1
self.context_start = 0
self.set_conf()
assert self.context_stride == 1
T = sum(self.lod[0])
x = np.random.uniform(-1, 1, [T, self.in_fea_size]).astype('float32')
w = np.random.uniform(
-1, 1, [self.in_fea_size * self.context_length,
self.out_fea_size]).astype('float32')
b = np.random.uniform(-2, 1, [1, self.out_fea_size]).astype('float32')
out = seqconv(x, self.lod, w, self.context_length, self.context_start)
out = np.maximum(out + b, 0)
self.inputs = {'X': (x, self.lod), 'Filter': w, 'Bias': b}
self.attrs = {
'contextStart': self.context_start,
'contextLength': self.context_length,
'contextStride': self.context_stride
}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestSeqConvEltAddReluBS1(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[10]]
class TestSeqConvEltAddReluBS1Case2(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[2]]
class TestSeqConvEltAddReluCase1(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[3, 5, 1, 6]]
self.context_length = 3
self.context_start = -2
class TestSeqConvEltAddReluCase2(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[10, 1, 2, 4, 1, 5, 6]]
self.in_fea_size = 2
self.context_length = 4
self.context_start = -1
class TestSeqConvEltAddReluCase3(TestSeqConvEltAddRelu):
def set_conf(self):
self.lod = [[10, 1, 2, 4, 1, 5, 6]]
self.context_length = 5
self.context_start = -4
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,53 @@ import random
from op_test import OpTest
def seqconv(x,
lod,
filter,
context_length,
context_start,
padding_trainable=False,
padding_data=None):
[T, M] = x.shape
col = np.zeros((T, context_length * M)).astype('float32')
offset = [0]
for seq_len in lod[0]:
offset.append(offset[-1] + seq_len)
begin_pad = np.max([0, -context_start])
for i in range(len(offset) - 1):
for j in range(context_length):
in_begin = offset[i] + context_start + j
in_end = offset[i + 1] + context_start + j
out_begin = offset[i]
out_end = offset[i + 1]
if in_begin < offset[i]:
pad_size = np.min(
[offset[i] - in_begin, offset[i + 1] - offset[i]])
if padding_trainable:
sub_w = padding_data[j:j + pad_size, :]
col[offset[i]:offset[i] + pad_size, j * M:(j + 1) *
M] = sub_w
out_begin = offset[i] + pad_size
in_begin = offset[i]
if in_end > offset[i + 1]:
pad_size = np.min(
[in_end - offset[i + 1], offset[i + 1] - offset[i]])
if padding_trainable:
sub_w = padding_data[begin_pad + context_start + j -
pad_size:begin_pad + context_start +
j, :]
col[offset[i + 1] - pad_size:offset[i + 1], j * M:(j + 1) *
M] = sub_w
in_end = offset[i + 1]
out_end = offset[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
col[out_begin:out_end, j * M:(j + 1) * M] += in_sub
return np.dot(col, filter)
class TestSeqProject(OpTest):
def setUp(self):
self.init_test_case()
......@@ -66,57 +113,9 @@ class TestSeqProject(OpTest):
'paddingTrainable': self.padding_trainable,
'contextStride': self.context_stride
}
out = np.zeros(
(self.input_size[0], self.output_represention)).astype('float32')
out = seqconv(x, self.lod, w, self.context_length, self.context_start,
self.padding_trainable, self.pad_data)
self.outputs = {'Out': out}
self.compute()
def compute(self):
x, lod = self.inputs['X']
filter = self.inputs['Filter']
pading_data = self.pad_data
out = np.zeros((self.input_size[0], self.context_length *
self.input_size[1])).astype('float32')
offset = [0]
for seq_len in lod[0]:
offset.append(offset[-1] + seq_len)
begin_pad = np.max([0, -self.context_start])
for i in range(len(offset) - 1):
for j in range(self.context_length):
in_begin = offset[i] + self.context_start + j
in_end = offset[i + 1] + self.context_start + j
out_begin = offset[i]
out_end = offset[i + 1]
if in_begin < offset[i]:
pad_size = np.min(
[offset[i] - in_begin, offset[i + 1] - offset[i]])
if self.padding_trainable:
sub_w = pading_data[j:j + pad_size, :]
out[offset[i]:offset[i] + pad_size, j * self.input_size[
1]:(j + 1) * self.input_size[1]] = sub_w
out_begin = offset[i] + pad_size
in_begin = offset[i]
if in_end > offset[i + 1]:
pad_size = np.min(
[in_end - offset[i + 1], offset[i + 1] - offset[i]])
if self.padding_trainable:
sub_w = pading_data[begin_pad + self.context_start + j -
pad_size:begin_pad +
self.context_start + j, :]
out[offset[i + 1] - pad_size:offset[i + 1], j * self.
input_size[1]:(j + 1) * self.input_size[1]] = sub_w
in_end = offset[i + 1]
out_end = offset[i + 1] - pad_size
if in_end <= in_begin:
continue
in_sub = x[in_begin:in_end, :]
out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
self.input_size[1]] += in_sub
np.dot(out, filter, out=self.outputs['Out'])
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册