diff --git a/cmake/tensorrt.cmake b/cmake/tensorrt.cmake
index ac19b1651893f18b14c62a0986df75bed25d7e80..8f65a737c43a124c05574d6eb9c3050fdab5299a 100644
--- a/cmake/tensorrt.cmake
+++ b/cmake/tensorrt.cmake
@@ -16,7 +16,9 @@ find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
   DOC "Path to TensorRT library.")
 
 if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
+  if(WITH_DSO)
   set(TENSORRT_FOUND ON)
+  endif(WITH_DSO)
 else()
   set(TENSORRT_FOUND OFF)
 endif()
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 371384dc56eec91db1f621c0ebb65113e7a5a5cc..1a8d9cefbfa570d2ac3f4fc32d50d705ddc67a75 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -429,7 +429,7 @@ struct LSTM : public PatternBase {
 struct GRU : public PatternBase {
   GRU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "lstm") {}
+      : PatternBase(pattern, name_scope, "gru") {}
 
   PDNode* operator()(PDNode* x);
diff --git a/paddle/fluid/inference/api/api.cc b/paddle/fluid/inference/api/api.cc
index 5f1e1b548c7b7daa66932571d7053701bc0bd1f6..c71769a32f604358fe68c927546591310649f116 100644
--- a/paddle/fluid/inference/api/api.cc
+++ b/paddle/fluid/inference/api/api.cc
@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cassert>
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 
@@ -64,13 +64,15 @@ PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
 
 void PaddleBuf::Resize(size_t length) {
   // Only the owned memory can be reset, the external memory can't be changed.
-  if (length_ == length) return;
+  if (length_ >= length) return;
   if (memory_owned_) {
     Free();
+    data_ = malloc(length);
+    length_ = length;
+    memory_owned_ = true;
+  } else {
+    PADDLE_THROW("The memory is allocated externally, can not be resized");
   }
-  data_ = new char[length];
-  length_ = length;
-  memory_owned_ = true;
 }
 
 void PaddleBuf::Reset(void* data, size_t length) {
@@ -82,8 +84,8 @@ void PaddleBuf::Reset(void* data, size_t length) {
 
 void PaddleBuf::Free() {
   if (memory_owned_ && data_) {
-    assert(length_ > 0);
-    delete[] static_cast<char*>(data_);
+    PADDLE_ENFORCE_GT(length_, 0);
+    free(static_cast<char*>(data_));
     data_ = nullptr;
     length_ = 0;
  }
 }
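The new PaddleBuf contract is worth spelling out: Resize() now keeps any owned block that is already large enough, reallocates owned memory with malloc/free (matching the new Free()), and refuses to touch externally attached memory. A minimal standalone sketch of those rules, with simplified names and std::runtime_error standing in for PADDLE_THROW:

    // Illustrative sketch only; the real PaddleBuf lives in
    // paddle_inference_api.h and reports errors via PADDLE_THROW.
    #include <cstdlib>
    #include <stdexcept>

    class Buf {
     public:
      void Resize(size_t length) {
        if (length_ >= length) return;  // owned block is already big enough
        if (!memory_owned_)
          throw std::runtime_error("external memory, can not be resized");
        std::free(data_);               // plays the role of Free()
        data_ = std::malloc(length);
        length_ = length;
      }
      void Reset(void* data, size_t length) {  // attach external storage
        if (memory_owned_) std::free(data_);
        data_ = data;
        length_ = length;
        memory_owned_ = false;
      }
      ~Buf() {
        if (memory_owned_) std::free(data_);
      }

     private:
      void* data_ = nullptr;
      size_t length_ = 0;
      bool memory_owned_ = true;
    };

    int main() {
      Buf b;
      b.Resize(1024);  // mallocs 1024 bytes
      b.Resize(512);   // no-op: 1024 >= 512
      char ext[64];
      b.Reset(ext, sizeof(ext));
      // b.Resize(128) would now throw: the buffer no longer owns its memory
      return 0;
    }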
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index e8c34047abaac1552f140262f5fdb3f11c792bfc..e397457061662c8afb9760ef52406c22caaeb213 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -53,7 +53,7 @@ set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classifi
 download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
 inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc
   EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta
+  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/model
        --infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt)
 
 # ocr
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 5385bcdaec3f9e86cf0d7bf33f270e48ae186fba..eae65968285703f5882d910e29bc5d8e1511cba6 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
+    bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
     int groups = ctx.Attr<int>("groups");
 
     // TODO: add support for dilation
@@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
-                               paddings, mkldnn_engine, fuse_relu);
+      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+                                     strides, paddings, mkldnn_engine,
+                                     fuse_relu, fuse_eltwise);
     } else {
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
-                                     paddings, mkldnn_engine, fuse_relu);
+      conv_pd =
+          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+                               mkldnn_engine, fuse_relu, fuse_eltwise);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
 
  private:
-  mkldnn::primitive_attr AddRelu() const {
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
+                                       bool fuse_eltwise) const {
     mkldnn::primitive_attr conv_attr;
-    constexpr float scale = 1.0f;
-    constexpr float negative_slope = 0.0f;
-    constexpr float placeholder = 0.0f;
     mkldnn::post_ops post_operations;
-    post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
-                                   negative_slope, placeholder);
+    // Fusion with an Elementwise layer relies on adding a sum post-operation
+    // with the scale parameter. It is assumed that when fuse_eltwise is true,
+    // the Output tensor already contains the data coming from the residual
+    // connection. The result of this post_op is:
+    // Output = scale * Output + Conv_Out.
+    if (fuse_eltwise) {
+      post_operations.append_sum(1.0f);
+    }
+    // Fusion with a ReLU layer is executed through the PostOps feature. Create
+    // a PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_relu) {
+      constexpr float scale = 1.0f;
+      constexpr float negative_slope = 0.0f;
+      constexpr float placeholder = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     negative_slope, placeholder);
+    }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
@@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& dst, const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
-                       const bool fuse_relu) const {
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr;
-    if (fuse_relu) {
-      conv_attr = AddRelu();
-    }
+    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const memory::desc& bias, const memory::desc& dst,
                        const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
-                       const bool fuse_relu) const {
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         bias, dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr;
-    if (fuse_relu) {
-      conv_attr = AddRelu();
-    }
+    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
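For readers new to MKL-DNN post-ops: the operations execute in the order they are appended, inside the convolution primitive itself. A condensed sketch of what CreatePostOps configures, using the same 0.x-era API as the kernel above (illustrative, not a drop-in replacement for the kernel code):

    #include <mkldnn.hpp>

    mkldnn::primitive_attr MakeConvAttr(bool fuse_eltwise, bool fuse_relu) {
      mkldnn::post_ops ops;
      if (fuse_eltwise) {
        // dst = 1.0f * dst + conv(src, weights): the convolution accumulates
        // into the residual data already present in the output buffer.
        ops.append_sum(1.0f);
      }
      if (fuse_relu) {
        // ...then dst = relu(dst); the arguments are scale, algorithm,
        // negative slope, and a placeholder unused by eltwise_relu.
        ops.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu, 0.0f, 0.0f);
      }
      mkldnn::primitive_attr attr;
      attr.set_post_ops(ops);  // both post-ops run inside the conv primitive
      return attr;
    }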
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 41d4fcf6de7c8fcb3cfbb2063b0a2ac1a2356168..8f84bf71a7f77606bed6672f0830e3fc80165a42 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -164,6 +164,11 @@ void Conv2DOpMaker::Make() {
       .SetDefault(false);
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("fuse_eltwise",
+                "(bool, default false) Only used in mkldnn kernel. Used "
+                "whenever convolution output is connected via skip connection "
+                "to a previous layer.")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
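Element-wise, the attribute combination documented above computes relu(residual + conv_result). A hypothetical reference helper, useful only to pin down the semantics:

    // Reference semantics of conv2d with fuse_eltwise and fuse_relu both set
    // (illustrative scalar model; shapes and indexing deliberately elided).
    float FusedConvOutput(float residual, float conv_result) {
      float out = residual + conv_result;  // sum post-op, scale = 1.0f
      return out > 0.0f ? out : 0.0f;      // eltwise relu post-op
    }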
Used " + "whenever convolution output is connected via skip connection " + "to a previous layer.") + .SetDefault(false); AddAttr( "data_format", "(string, default NCHW) Only used in " diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index e22bc552f85b85c75f06b4158f2abac2d3843256..13682b78f0eccf049daa315f3a26aafd22e42a41 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -125,7 +125,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope)); s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] { + framework::AsyncIO([var_name_val, s, this] { // prepare input sendrecv::VariableMessage req; req.set_varname(var_name_val); @@ -166,7 +166,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, s->Prepare(h, time_out); framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, - time_out, s, this] { + s, this] { auto* var = p_scope->FindVar(in_var_name_val); ::grpc::ByteBuffer req; diff --git a/paddle/fluid/operators/distributed/proto_encoder_helper.h b/paddle/fluid/operators/distributed/proto_encoder_helper.h index 2fab02e32fe18ee04f86a69bb5bae1cbe7c6762c..d2b0eb6ca6de1984dc7cfc2a662c88d5e56e1e05 100644 --- a/paddle/fluid/operators/distributed/proto_encoder_helper.h +++ b/paddle/fluid/operators/distributed/proto_encoder_helper.h @@ -82,8 +82,10 @@ class ProtoEncodeHelper { : base_(buf), p_(buf), limit_(base_ + max_size) {} ~ProtoEncodeHelper() { +#define REPLACE_ENFORCE_GLOG 1 // Make sure callers didn't do operations that went over max_size promised - PADDLE_ENFORCE_LE(p_, limit_); + paddle::platform::throw_on_error(p_ <= limit_); +#undef REPLACE_ENFORCE_GLOG } const char* data() const { return base_; } diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 966d78b84130c172c41e8049bf6bb1dc659d7d48..dc008d16971bc762b401ddece56f9ec56f7a47d6 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -59,17 +59,16 @@ static void ParallelExecuteBlocks( framework::ProgramDesc *program, framework::Scope *scope) { std::vector> fs; for (size_t idx : parallel_blkids) { - fs.push_back( - framework::Async([&executor, &prepared, &program, &scope, idx]() { - int run_block = idx; // thread local - try { - VLOG(3) << "running server block: " << run_block - << "pointer: " << prepared[run_block].get(); - executor->RunPreparedContext(prepared[run_block].get(), scope); - } catch (const std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); + fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() { + int run_block = idx; // thread local + try { + VLOG(3) << "running server block: " << run_block + << "pointer: " << prepared[run_block].get(); + executor->RunPreparedContext(prepared[run_block].get(), scope); + } catch (const std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); } for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); } diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index f25d3d3f1ee1f89d46b8e7c88ca68048f5203544..69318a6598c8c69eceab7216df6382537153d34f 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -103,6 +103,58 
@@ -103,6 +103,58 @@ class MaxSeqPoolGradFunctor {
   }
 };
 
+template <typename T>
+class LastSeqPoolFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& input,
+                  framework::Tensor* output) {
+    // Create pointers to the input and output data
+    auto* in_data = input.data<T>();
+    auto* out_data = output->data<T>();
+
+    // Calculate the size of each item in the sequence
+    int64_t item_size = input.numel() / input.dims()[0];
+    auto lod = input.lod()[0];
+    int seq_num = static_cast<int>(lod.size()) - 1;
+    for (int i = 0; i < seq_num; ++i) {
+      // Calculate the length of each sequence
+      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      // Point to the beginning of the next sequence
+      in_data += seq_len * item_size;
+      // Copy the last item of the sequence to the output
+      std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
+      out_data += item_size;
+    }
+  }
+};
+
+template <typename T>
+class FirstSeqPoolFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& input,
+                  framework::Tensor* output) {
+    // Create pointers to the input and output data
+    auto* in_data = input.data<T>();
+    auto* out_data = output->data<T>();
+
+    // Calculate the size of each item in the sequence
+    int64_t item_size = input.numel() / input.dims()[0];
+    auto lod = input.lod()[0];
+    int seq_num = static_cast<int>(lod.size()) - 1;
+    for (int i = 0; i < seq_num; ++i) {
+      // Calculate the length of each sequence
+      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      // Copy the first item of the sequence to the output
+      std::memcpy(out_data, in_data, item_size * sizeof(T));
+      // Point to the next sequence
+      in_data += seq_len * item_size;
+      out_data += item_size;
+    }
+  }
+};
+
 template <typename T>
 class SequencePoolFunctor {
  public:
@@ -116,6 +168,16 @@ class SequencePoolFunctor {
       max_pool(context, input, output, index);
       return;
     }
+    if (pooltype == "LAST") {
+      math::LastSeqPoolFunctor<T> last_pool;
+      last_pool(context, input, output);
+      return;
+    }
+    if (pooltype == "FIRST") {
+      math::FirstSeqPoolFunctor<T> first_pool;
+      first_pool(context, input, output);
+      return;
+    }
     auto lod = input.lod()[0];
     auto& place = *context.eigen_device();
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
@@ -133,10 +195,6 @@ class SequencePoolFunctor {
       } else if (pooltype == "SQRT") {
         out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                               std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "LAST") {
-        out_e.device(place) = in_e.chip(h - 1, 0);
-      } else if (pooltype == "FIRST") {
-        out_e.device(place) = in_e.chip(0, 0);
       } else {
         PADDLE_THROW("unsupported pooling pooltype");
       }
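The memcpy-based functors above replace per-sequence Eigen chip() expressions with straight copies. Reference semantics of LAST pooling on a flat row-major buffer (toy data; lod holds row offsets exactly as in the functor):

    // Standalone sketch, not the Paddle functor: copy the last row of each
    // sequence into the output.
    #include <cstdio>
    #include <cstring>

    void LastPool(const float* in, const size_t* lod, int seq_num,
                  int64_t item_size, float* out) {
      for (int i = 0; i < seq_num; ++i) {
        const float* last = in + lod[i + 1] * item_size - item_size;
        std::memcpy(out + i * item_size, last, item_size * sizeof(float));
      }
    }

    int main() {
      // Two sequences: rows [0, 2) and [2, 5); each row has item_size = 2.
      float in[10] = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5};
      size_t lod[3] = {0, 2, 5};
      float out[4];
      LastPool(in, lod, 2, 2, out);
      std::printf("%g %g | %g %g\n", out[0], out[1], out[2], out[3]);
      // prints: 2 2 | 5 5
      return 0;
    }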
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index 23d9ea88f6701f9f9e5e02948e996878a849ddd6..e0c4c81bdd5b5d0af3bafe632a2fa033efd08050 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -26,10 +26,13 @@ class PReluOp : public framework::OperatorWithKernel {
     std::string mode = ctx->Attrs().Get<std::string>("mode");
 
     auto x_dim = ctx->GetInputDim("X");
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of PreluOp should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Alpha"),
+                   "Input(Alpha) of PreluOp should not be null");
 
-    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of PreluOp should not be null");
     if (mode == "all") {
       PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
                      "For mode 'all', size of weight Alpha must be one.");
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index ad095b92711dccb44f26748bcfa89a0b4123c6e7..ba5065f468376d83488d7eade5dc2041d86dfd39 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -33,6 +33,7 @@ function print_usage() {
     ${BLUE}single_test${NONE}: run a single unit test
     ${BLUE}bind_test${NONE}: parallel tests bind to different GPU
     ${BLUE}doc${NONE}: generate paddle documents
+    ${BLUE}gen_doc_lib${NONE}: generate paddle documents library
     ${BLUE}html${NONE}: convert C++ source code into HTML
     ${BLUE}dockerfile${NONE}: generate paddle release dockerfile
     ${BLUE}capi${NONE}: generate paddle CAPI package
@@ -431,24 +432,60 @@ EOF
     linkchecker doc/v2/cn/html/index.html
     linkchecker doc/v2/api/en/html/index.html
 
-    if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
+#    if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
+#
+#    # Deploy to the content server if it's a "develop" or "release/version" branch
+#    # The "develop_doc" branch is reserved to test the full deploy process without impacting the real content.
+#    if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
+#        PPO_SCRIPT_BRANCH=develop
+#    elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
+#        PPO_SCRIPT_BRANCH=master
+#    else
+#        # Early exit, this branch doesn't require a documentation build
+#        return 0;
+#    fi
+#    # Fetch the paddlepaddle.org deploy_docs.sh from the appropriate branch
+#    export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
+#    export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
+#    cd ..
+#    curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
+#    cd -
+}
 
-    # Deploy to the the content server if its a "develop" or "release/version" branch
-    # The "develop_doc" branch is reserved to test full deploy process without impacting the real content.
-    if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
-        PPO_SCRIPT_BRANCH=develop
-    elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
-        PPO_SCRIPT_BRANCH=master
-    else
-        # Early exit, this branch doesn't require documentation build
-        return 0;
-    fi
-    # Fetch the paddlepaddle.org deploy_docs.sh from the appopriate branch
-    export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
-    export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
-    cd ..
-    curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
-    cd -
+
+function gen_doc_lib() {
+    mkdir -p ${PADDLE_ROOT}/build
+    cd ${PADDLE_ROOT}/build
+    cat <<EOF
 
 > %s" % (orig_path, head_path))
 
 class TestDistTransformer2x2Sync(TestDistBase):
 
diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py
index adad2428f7fdc554cf4efd652f52b5c5de0ab527..49ba2cfd55bc881ed753fcefbd41f5b8fd4ebaf7 100644
--- a/python/paddle/fluid/transpiler/inference_transpiler.py
+++ b/python/paddle/fluid/transpiler/inference_transpiler.py
@@ -65,8 +65,43 @@ class InferenceTranspiler(object):
         if use_mkldnn:
             self._fuse_conv_bias_mkldnn(program)
             self._fuse_conv_relu_mkldnn(program)
+            self._fuse_conv_eltwise_mkldnn(program)
+            self._fuse_conv_relu_mkldnn(
+                program)  # ResNet residual block merging
             self._fuse_bn_relu_mkldnn(program)
 
+    def _fuse_conv_eltwise_mkldnn(self, program):
+        '''
+        Transpile the program by fusing elementwise_add into conv for an
+        MKLDNN program. An elementwise_add that follows a convolution OP can
+        be fused by adding the 'fuse_eltwise' attribute to the convolution OP
+        and replacing its output Tensor with the second parameter of
+        elementwise_add. The result of the fuse is:
+        - before:
+          - conv->elementwise_add->any_other_op
+        - after:
+          - conv->any_other_op
+        :param program: program to transpile
+        :type program: Program
+        '''
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops):
+            current_op = self.block.ops[i]
+            if current_op.type in ['conv2d']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'elementwise_add':
+                    self._fuse_conv_eltwise(current_op, next_op)
+                    self.block._remove_op(i + 1)  # Remove elementwise_add
+            i = i + 1
+        self._adjust_input()
+        self._remove_unused_var()
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()
+
     def _fuse_conv_relu_mkldnn(self, program):
         '''
         Transpile the program by fused relu activation for MKLDNN program.
@@ -88,9 +123,9 @@ class InferenceTranspiler(object):
             if current_op.type in ['conv2d']:
                 next_op = self.block.ops[i + 1]
                 if next_op.type == 'relu':
-                    # modify bnorm OP to include relu
+                    # modify conv OP to include relu
                     current_op.set_attr("fuse_relu", True)
-                    # remove conv OP
+                    # remove relu OP
                     self.block._remove_op(i + 1)
             i = i + 1
 
@@ -409,6 +444,20 @@ class InferenceTranspiler(object):
             outputs={"Output": out_var},
             attrs=attrs)
 
+    def _fuse_conv_eltwise(self, conv_op, eltwise_op):
+        '''
+        fuse the conv op with elementwise_add
+
+        :param conv_op: convolution operator
+        :type conv_op: Operator
+        :param eltwise_op: operator adding data from skip connection
+        :type eltwise_op: Operator
+        '''
+
+        conv_op.set_attr("fuse_eltwise", True)
+        self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
+        self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
+
     def _adjust_input(self):
         for i in range(len(self.block.ops)):
             current_op = self.block.ops[i]
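The same rewrite, sketched in C++ over a toy op list to make the algorithm explicit (the real pass manipulates Program/Block descriptors and multi-output ops, which this deliberately ignores):

    #include <map>
    #include <string>
    #include <vector>

    struct Op {
      std::string type;
      std::string output;     // conv "Output" / eltwise "Out" (single-output toy)
      std::string eltwise_y;  // elementwise_add's second input "Y"
      bool fuse_eltwise = false;
    };

    void FuseConvEltwise(std::vector<Op>* ops,
                         std::map<std::string, std::string>* input_map) {
      for (size_t i = 0; i + 1 < ops->size(); ++i) {
        Op& cur = (*ops)[i];
        if (cur.type != "conv2d") continue;
        Op& next = (*ops)[i + 1];
        if (next.type != "elementwise_add") continue;
        cur.fuse_eltwise = true;
        // The conv now accumulates into the residual tensor Y; later readers
        // of the old conv output and of the add's output are redirected to Y.
        (*input_map)[cur.output] = next.eltwise_y;
        (*input_map)[next.output] = next.eltwise_y;
        ops->erase(ops->begin() + i + 1);  // drop the elementwise_add
      }
    }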