diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 3378d210cdf6a625e11b1dd5fe348aa04cdb9361..da835b33051c310064a906612ae7f9362f95c7d5 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -128,6 +128,7 @@ paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates',
 paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.selu ArgSpec(args=['x', 'scale', 'alpha', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.log ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 504f7e6d6c13d6c40d72a53e52fec920457f2dae..883575e41db2d883e9b969978419a10ffc58b97e 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -41,6 +41,7 @@ pass_library(seq_concat_fc_fuse_pass inference)
 pass_library(multi_batch_merge_pass base)
 pass_library(conv_bn_fuse_pass inference)
 pass_library(seqconv_eltadd_relu_fuse_pass inference)
+pass_library(is_test_pass base)
 if(WITH_MKLDNN)
     pass_library(mkldnn_placement_pass base)
     pass_library(depthwise_conv_mkldnn_pass base)
@@ -62,6 +63,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
 cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
 cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
+cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
 if (WITH_MKLDNN)
     cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
     cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..292f232ffce48593e1827fe2dfe1b8472360054e
--- /dev/null
+++ b/paddle/fluid/framework/ir/is_test_pass.cc
@@ -0,0 +1,57 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/is_test_pass.h"
+#include <string>
+#include <utility>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  VLOG(3) << "Sets is_test attribute to true and, if it is missing, inserts "
+             "it for activations and pooling.";
+  auto op_list = {"pool2d",      "sigmoid",      "logsigmoid",
+                  "softshrink",  "exp",          "brelu",
+                  "pow",         "leaky_relu",   "stanh",
+                  "relu",        "tanh",         "tanh_shrink",
+                  "sqrt",        "abs",          "ceil",
+                  "elu",         "floor",        "cos",
+                  "sin",         "round",        "reciprocal",
+                  "hard_shrink", "hard_sigmoid", "relu6",
+                  "soft_relu",   "swish",        "thresholded_relu",
+                  "log",         "square",       "softplus",
+                  "softsign"};
+  for (const Node* n : graph->Nodes()) {
+    if (n->IsOp()) {
+      auto* op = n->Op();
+      if (op->HasAttr("is_test")) {
+        op->SetAttr("is_test", true);
+      } else if (std::find(begin(op_list), end(op_list), op->Type()) !=
+                 end(op_list)) {
+        op->MutableAttrMap()->insert(
+            std::pair<std::string, Attribute>("is_test", true));
+      }
+    }
+  }
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(is_test_pass, paddle::framework::ir::IsTestPass);
diff --git a/paddle/fluid/framework/ir/is_test_pass.h b/paddle/fluid/framework/ir/is_test_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..99e76ca4a3de21e350e68e05e0f241937a743b9e
--- /dev/null
+++ b/paddle/fluid/framework/ir/is_test_pass.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class IsTestPass : public Pass {
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/is_test_pass_tester.cc b/paddle/fluid/framework/ir/is_test_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cd2cb0c9f8a8ecc41a878cd3f711713cb5c23eb3
--- /dev/null
+++ b/paddle/fluid/framework/ir/is_test_pass_tester.cc
@@ -0,0 +1,117 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/is_test_pass.h"
+
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+enum class ISTEST_STATE { FALSE, TRUE, UNSET };
+
+void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
+           const std::vector<std::string>& inputs,
+           const std::vector<std::string>& outputs, bool use_mkldnn = false,
+           ISTEST_STATE is_test = ISTEST_STATE::UNSET) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(type);
+  op->SetAttr("name", name);
+  op->SetInput("X", inputs);
+  op->SetOutput("Out", outputs);
+  op->SetAttr("use_mkldnn", use_mkldnn);
+  if (is_test == ISTEST_STATE::UNSET)
+    op->MutableAttrMap()->erase("is_test");
+  else if (is_test == ISTEST_STATE::FALSE)
+    op->SetAttr("is_test", false);
+  else
+    op->SetAttr("is_test", true);
+}
+
+// a->pool2d->b
+// b->relu->c
+// (c, weights1)->conv2d->d
+//
+// d->pool2d->e
+// e->hard_sigmoid->f
+// (f, weights2)->conv2d->g
+//
+// g->pool2d->h
+// h->tanh->i
+// (i, weights3)->conv2d->j
+ProgramDesc BuildProgramDesc() {
+  ProgramDesc prog;
+  for (auto& v :
+       std::vector<std::string>({"a", "b", "c", "d", "e", "f", "g", "h", "i",
+                                 "j", "weights1", "weights2", "weights3"})) {
+    auto* var = prog.MutableBlock(0)->Var(v);
+    var->SetType(proto::VarType::SELECTED_ROWS);
+    if (v == "weights1" || v == "weights2" || v == "weights3") {
+      var->SetPersistable(true);
+    }
+  }
+
+  SetOp(&prog, "pool2d", "pooling1", std::vector<std::string>({"a"}),
+        std::vector<std::string>({"b"}), true, ISTEST_STATE::TRUE);
+  SetOp(&prog, "relu", "activation1", std::vector<std::string>({"b"}),
+        std::vector<std::string>({"c"}), true, ISTEST_STATE::TRUE);
+  SetOp(&prog, "conv2d", "conv1", std::vector<std::string>({"c", "weights1"}),
+        std::vector<std::string>({"d"}), true, ISTEST_STATE::TRUE);
+
+  SetOp(&prog, "pool2d", "pooling2", std::vector<std::string>({"d"}),
+        std::vector<std::string>({"e"}), false, ISTEST_STATE::FALSE);
+  SetOp(&prog, "hard_sigmoid", "activation2", std::vector<std::string>({"e"}),
+        std::vector<std::string>({"f"}), false, ISTEST_STATE::FALSE);
+  SetOp(&prog, "conv2d", "conv2", std::vector<std::string>({"f", "weights2"}),
+        std::vector<std::string>({"g"}), false, ISTEST_STATE::FALSE);
+
+  SetOp(&prog, "pool2d", "pooling3", std::vector<std::string>({"g"}),
+        std::vector<std::string>({"h"}), false, ISTEST_STATE::UNSET);
+  SetOp(&prog, "tanh", "activation3", std::vector<std::string>({"h"}),
+        std::vector<std::string>({"i"}), true, ISTEST_STATE::UNSET);
+  SetOp(&prog, "conv2d", "conv3", std::vector<std::string>({"i", "weights3"}),
+        std::vector<std::string>({"j"}), false, ISTEST_STATE::UNSET);
+
+  return prog;
+}
+
+TEST(IsTestPass, basic) {
+  auto prog = BuildProgramDesc();
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+  auto pass = PassRegistry::Instance().Get("is_test_pass");
+
+  graph = pass->Apply(std::move(graph));
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp()) {
+      auto* op = node->Op();
+      auto op_name = boost::get<std::string>(op->GetAttr("name"));
+      if (op_name == "conv3") {
+        ASSERT_FALSE(op->HasAttr("is_test"));
+      } else {
+        ASSERT_TRUE(op->HasAttr("is_test"));
+        EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test")));
+      }
+    }
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(is_test_pass);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 80658d30850aaa7212903828c5c963da5f37ca65..825bee833bf918067497f56adebbbcaf55f892a2 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -86,6 +86,7 @@ class CpuPassStrategy : public PassStrategy {
           "fc_fuse_pass",                  //
           "conv_bn_fuse_pass",             //
           "conv_eltwiseadd_bn_fuse_pass",  //
+          "is_test_pass",                  //
       });
   }
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 74569057913d1db9a7184ab61ba655b3a66e49bd..16a9b50e6fb174374d23cd021e47e52921871a8a 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -78,6 +78,10 @@ inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_te
 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
     "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
 
+# mobilenet with depthwise_conv op
+inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet
+    "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
+
 # anakin
 if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
     # anakin rnn1
diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc
index 137bca5e2b8e2754aed274970e08b03ee816a7f2..64649b1a5e471a30f435e2b1c1a9db03d35dbd8a 100644
--- a/paddle/fluid/operators/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/activation_mkldnn_op.cc
@@ -71,6 +71,10 @@ class MKLDNNActivationGradKernel
                    diff_y->format() != memory::format::format_undef,
                    "Wrong layout/format set for Input OutGrad tensor");
 
+    PADDLE_ENFORCE(
+        !ctx.Attr<bool>("is_test"),
+        "is_test attribute should be set to False in training phase.");
+
     Functor functor;
 
     auto attrs = functor.GetAttrs();
@@ -115,11 +119,15 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const std::string key_fwd = key_with_layout + "@eltwise_fwd";
   const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
 
+  bool is_test = ctx.Attr<bool>("is_test");
+
   // save input data and layout to be referred in backward path
   auto p_src_data = std::make_shared<const T *>(x_data);
-  dev_ctx.SetBlob(key_src_data, p_src_data);
   auto p_src_layout = std::make_shared<memory::format>(src_format);
-  dev_ctx.SetBlob(key_src_layout, p_src_layout);
+  if (!is_test) {
+    dev_ctx.SetBlob(key_src_data, p_src_data);
+    dev_ctx.SetBlob(key_src_layout, p_src_layout);
+  }
 
   auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
       dev_ctx.GetBlob(key_fwd));
@@ -136,14 +144,17 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
     dev_ctx.SetBlob(key_src_mem, src_memory);
 
     // create primitive descriptor for activation forward and save it
+    auto mkldnn_forward_prop_kind = is_test
+                                        ? mkldnn::prop_kind::forward_inference
+                                        : mkldnn::prop_kind::forward_training;
     auto forward_desc = mkldnn::eltwise_forward::desc(
-        mkldnn::prop_kind::forward_training, algorithm,
+        mkldnn_forward_prop_kind, algorithm,
         src_memory->get_primitive_desc().desc(), alpha, beta);
     auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
         forward_desc, mkldnn_engine);
 
     // save prim desc into global device context to be referred in backward path
-    dev_ctx.SetBlob(key_fwd_pd, forward_pd);
+    if (!is_test) dev_ctx.SetBlob(key_fwd_pd, forward_pd);
 
     // create mkldnn memory for output y
     dst_memory =
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index ea260a3e92b775023085fd02eec33e6ecfaf2e81..bb9ea3f3ba08753dd23b2b2a776b7d2960e5e00e 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -22,18 +22,23 @@ namespace operators {
 
 using paddle::framework::Tensor;
 
-#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
-  class OP_NAME##OpMaker                                                \
-      : public ::paddle::framework::OpProtoAndCheckerMaker {            \
-   public:                                                              \
-    void Make() override {                                              \
-      AddInput("X", "Input of " #OP_NAME " operator");                  \
-      AddOutput("Out", "Output of " #OP_NAME " operator");              \
-      AddAttr<bool>("use_mkldnn",                                       \
-                    "(bool, default false) Only used in mkldnn kernel") \
-          .SetDefault(false);                                           \
-      AddComment(#OP_COMMENT);                                          \
-    }                                                                   \
+#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                \
+  class OP_NAME##OpMaker                                                 \
+      : public ::paddle::framework::OpProtoAndCheckerMaker {             \
+   public:                                                               \
+    void Make() override {                                               \
+      AddInput("X", "Input of " #OP_NAME " operator");                   \
+      AddOutput("Out", "Output of " #OP_NAME " operator");               \
+      AddAttr<bool>("use_mkldnn",                                        \
+                    "(bool, default false) Only used in mkldnn kernel")  \
+          .SetDefault(false);                                            \
+      AddAttr<bool>(                                                     \
+          "is_test",                                                     \
+          "(bool, default false) Set to true for inference only, false " \
+          "for training. Some layers may run faster when this is true.") \
+          .SetDefault(false);                                            \
+      AddComment(#OP_COMMENT);                                           \
+    }                                                                    \
   }
 
 #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)        \
@@ -269,7 +274,7 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 :strong:`Softshrink Activation Operator`
 
 .. math::
-  out = \begin{cases} 
+  out = \begin{cases}
     x - \lambda, \text{if } x > \lambda \\
     x + \lambda, \text{if } x < -\lambda \\
     0,  \text{otherwise}
@@ -435,7 +440,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 HardSigmoid Activation Operator.
 
-Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391), 
+Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
 which is much faster than sigmoid.
 
 $out = \max(0, \min(1, slope * x + shift))$
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index cf245f5038f5f5ad1b623542aa14686eff8aad32..2463c939bc5d19500ba36ba3c73db176bb82c62a 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -113,7 +113,10 @@ class BatchNormOp : public framework::OperatorWithKernel {
 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddAttr<bool>("is_test", "").SetDefault(false);
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddAttr<float>("momentum", "").SetDefault(0.9);
     AddAttr<float>("epsilon", "")
         .SetDefault(1e-5)
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index f2cc6642ee6c45cfd95fa3b5ccc58a4832fb8db4..c3c7c90f150a472e0f19626d71bc1c25643d0ca6 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -383,20 +383,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     // create a conv primitive descriptor and save it for usage in backward
     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
+    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
+                                 : mkldnn::prop_kind::forward_training;
     if (bias) {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                     strides, paddings, mkldnn_engine,
-                                     fuse_relu, fuse_residual_conn);
+      conv_pd = ConvFwdPrimitiveDesc(
+          src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
+          fuse_relu, fuse_residual_conn, fwd_prop_kind);
     } else {
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                               mkldnn_engine, fuse_relu, fuse_residual_conn);
+      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
+                                     paddings, mkldnn_engine, fuse_relu,
+                                     fuse_residual_conn, fwd_prop_kind);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
-    dev_ctx.SetBlob(key_conv_pd, conv_pd);
+    if (!is_test) dev_ctx.SetBlob(key_conv_pd, conv_pd);
 
     ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
 
@@ -510,14 +512,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const memory::desc& dst, const std::vector<int>& strides,
       const std::vector<int>& paddings, const mkldnn::engine& engine,
       const bool fuse_relu,
-      const bool fuse_residual_conn) const {
+      const bool fuse_residual_conn,
+      mkldnn::prop_kind fwd_prop_kind) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
     auto conv_desc = mkldnn::convolution_forward::desc(
-        mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
-        dst, stride_dims, padding_dims, padding_dims,
-        mkldnn::padding_kind::zero);
+        fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
+        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
 
     mkldnn::primitive_attr conv_attr =
         CreatePostOps(fuse_relu, fuse_residual_conn);
@@ -535,14 +537,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const std::vector<int>& strides, const std::vector<int>& paddings,
       const mkldnn::engine& engine,
       const bool fuse_relu,
-      const bool fuse_residual_conn) const {
+      const bool fuse_residual_conn,
+      mkldnn::prop_kind fwd_prop_kind) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
     auto conv_desc = mkldnn::convolution_forward::desc(
-        mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
-        bias, dst, stride_dims, padding_dims, padding_dims,
-        mkldnn::padding_kind::zero);
+        fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
+        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
 
     mkldnn::primitive_attr conv_attr =
         CreatePostOps(fuse_relu, fuse_residual_conn);
@@ -587,6 +589,10 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                    output_grad->format() != memory::format::format_undef,
                    "Wrong layout/format set for output_grad tensor");
 
+    PADDLE_ENFORCE(
+        !ctx.Attr<bool>("is_test"),
+        "is_test attribute should be set to False in training phase.");
+
     if (!input_grad && !filter_grad) return;
 
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 4d370746382a4247f51aafa189e86eece941c320..1ac4bef615a546ec97aefd83653c649f187caba0 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -109,7 +109,10 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
 }
 
 void Conv2DOpMaker::Make() {
-  AddAttr<bool>("is_test", "").SetDefault(false);
+  AddAttr<bool>("is_test",
+                "(bool, default false) Set to true for inference only, false "
+                "for training. Some layers may run faster when this is true.")
+      .SetDefault(false);
   AddInput(
       "Input",
       "(Tensor) The input tensor of convolution operator. "
diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 3c28ef30922e6d6ba09b96282619eef15867631e..dd3474dd2529b5e2cb2cd32aec41fb6357b5d537 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -49,7 +49,10 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
           PADDLE_ENFORCE(drop_p >= 0.0f && drop_p <= 1.0f,
                          "'dropout_prob' must be between 0.0 and 1.0.");
         });
-    AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddAttr<bool>("fix_seed",
                   "A flag indicating whether to use a fixed seed to generate "
                   "random mask. NOTE: DO NOT set this flag to true in "
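
Reviewer note: dropout is the easiest place to see what `is_test` changes at runtime. A minimal sketch, assuming the fluid 1.x Python API (`fluid.layers.dropout` exposes `is_test` directly):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    # Training graph: elements are randomly zeroed with probability 0.5.
    y_train = fluid.layers.dropout(x, dropout_prob=0.5, is_test=False)
    # Inference graph: no elements are zeroed (only deterministic scaling,
    # depending on the chosen dropout implementation).
    y_infer = fluid.layers.dropout(x, dropout_prob=0.5, is_test=True)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    data = np.ones((8, 4)).astype('float32')
    train_out, infer_out = exe.run(
        feed={'x': data}, fetch_list=[y_train, y_infer])
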
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index e608eba05d5680254835f7b25f53d6a59e310e2a..43af83fd693b433337bdc80188bd0568f76b3e66 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -138,7 +138,7 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 FakeQuantize operator
 
-$$scale = max(abs(X))$$ 
+$$scale = max(abs(X))$$
 $$range = 2^{bit_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
 
@@ -199,11 +199,14 @@ class FakeQuantizeRangeAbsMaxOpMaker
           PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                          "'bit_length' should be between 1 and 16.");
         });
-    AddAttr<bool>("is_test", "").SetDefault(false);
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddComment(R"DOC(
 FakeQuantize operator is used in static quantization.
 
-$$scale = max(max(abs(x)), history_abs_max)$$ 
+$$scale = max(max(abs(x)), history_abs_max)$$
 $$range = 2^{bit_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 61c3cb34a2472c0ba7d2a7ea5abf8e826a793951..a3bb2be5c7af5b85fa9785c5e64ac314feda8b78 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -46,7 +46,7 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
     int pre_pad = (n - 1) / 2;
     // compute batches one by one
     for (int i = 0; i < N; ++i) {
-      blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
+      blas.VSQUARE(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
       // init the first channel of mid
       for (int c = 0; c < n; ++c) {
         blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
@@ -229,8 +229,8 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
         "the input will be transformed automatically. ")
         .SetDefault("AnyLayout");
     AddAttr<bool>("is_test",
-                  "Turns on memory optimization that optimizes away "
-                  "unnecessary memory allocations. Used by MKLDNN.")
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
         .SetDefault(false);
 
     AddComment(R"DOC(
diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index 5d0d562030d2a20e4a1cefd3c36c6533fd35dc96..6734df1530893777fca3ccf66b1e8aab40e41cfc 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -153,7 +153,7 @@ class Blas {
   void VEXP(int n, const T* x, T* y) const;
 
   template <typename T>
-  void VSQR(int n, const T* x, T* y) const;
+  void VSQUARE(int n, const T* x, T* y) const;
 
   template <typename T>
   void VPOW(int n, const T* x, T alpha, T* y) const;
@@ -245,8 +245,8 @@ class BlasT : private Blas<DeviceContext> {
   }
 
   template <typename... ARGS>
-  void VSQR(ARGS... args) const {
-    Base()->template VSQR<T>(args...);
+  void VSQUARE(ARGS... args) const {
+    Base()->template VSQUARE<T>(args...);
   }
 
   template <typename... ARGS>
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 59454669be9e0f92a6fc0db52445307d88e1c7d8..93bf7c7c88db36807143b136ea800d6e5e49dd43 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -105,7 +105,7 @@ struct CBlas<float> {
   }
 
   template <typename... ARGS>
-  static void VSQR(ARGS... args) {
+  static void VSQUARE(ARGS... args) {
     platform::dynload::vsSqr(args...);
   }
 
@@ -195,7 +195,7 @@ struct CBlas<double> {
   }
 
   template <typename... ARGS>
-  static void VSQR(ARGS... args) {
+  static void VSQUARE(ARGS... args) {
     platform::dynload::vdSqr(args...);
   }
 
@@ -262,7 +262,9 @@ struct CBlas<platform::float16> {
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
   static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
-  static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
+  static void VSQUARE(...) {
+    PADDLE_THROW("float16 VSQUARE not supported on CPU");
+  }
   static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
@@ -423,12 +425,12 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
 
 template <>
 template <typename T>
-void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
+void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const {
 #ifdef PADDLE_WITH_MKLML
-  CBlas<T>::VSQR(n, x, y);
+  CBlas<T>::VSQUARE(n, x, y);
 #else
   for (int i = 0; i < n; ++i) {
-    y[i] = std::sqrt(x[i]);
+    y[i] = x[i] * x[i];
   }
 #endif
 }
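
Reviewer note: the VSQR -> VSQUARE rename also fixes real behavior — the non-MKL fallback computed sqrt(x) while MKL's vsSqr/vdSqr square elementwise. The intended semantics, stated in NumPy:

    import numpy as np

    x = np.array([1.0, -2.0, 3.0], dtype=np.float32)
    # v?Sqr / VSQUARE: y[i] = x[i] * x[i]
    assert np.allclose(x * x, [1.0, 4.0, 9.0])
    # The old fallback wrongly computed y[i] = sqrt(x[i]) instead.
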
diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc
index 56cef91e29cc7da27384c27a7ec63e90cfadfc3b..0a9a29956affedb8605ab9949070943fbbb54145 100644
--- a/paddle/fluid/operators/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/pool_mkldnn_op.cc
@@ -87,6 +87,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    bool is_test = ctx.Attr<bool>("is_test");
     if (ctx.Attr<bool>("global_pooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         paddings[i] = 0;
@@ -142,16 +143,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
           CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top,
                               padding_right_bottom, ksize, pooling_type,
-                              mkldnn_engine, ceil_mode);
+                              mkldnn_engine, ceil_mode, is_test);
 
       // save pool_pd into global device context to be referred in backward path
-      dev_ctx.SetBlob(key_pool_pd, pool_pd);
-
-      std::shared_ptr<mkldnn::memory> workspace_memory =
-          CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
-
-      // save pool_workspace_memory to be referred in backward path
-      dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
+      if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
 
       auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
                                                  to_void_cast<T>(input_data));
@@ -161,9 +156,19 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
       dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
 
-      pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
-                                                 *(dst_memory.get()),
-                                                 *workspace_memory);
+      if (is_test) {
+        pool_p = std::make_shared<pooling_forward>(*pool_pd, *src_memory,
+                                                   *dst_memory);
+      } else {
+        std::shared_ptr<mkldnn::memory> workspace_memory =
+            CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
+
+        // save pool_workspace_memory to be referred in backward path
+        dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
+
+        pool_p = std::make_shared<pooling_forward>(
+            *pool_pd, *src_memory, *dst_memory, *workspace_memory);
+      }
 
       dev_ctx.SetBlob(key_pool_p, pool_p);
 
@@ -201,9 +206,12 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const std::vector<int>& stride, const std::vector<int>& padding_left_top,
      const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
       const std::string& pooling_type, const mkldnn::engine& engine,
-      bool ceil_mode) const {
+      bool ceil_mode, bool is_test) const {
+    auto mkldnn_forward_prop_kind = is_test
+                                        ? mkldnn::prop_kind::forward_inference
+                                        : mkldnn::prop_kind::forward_training;
     auto pool_desc = mkldnn::pooling_forward::desc(
-        mkldnn::prop_kind::forward,
+        mkldnn_forward_prop_kind,
         pooling_type == "max" ? mkldnn::algorithm::pooling_max
                               : mkldnn::algorithm::pooling_avg,
         src, dst, stride, kernel, padding_left_top, padding_right_bot,
@@ -248,6 +256,10 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
                    out_grad->format() != memory::format::format_undef,
                    "Wrong layout/format set for Input output_grad tensor");
 
+    PADDLE_ENFORCE(
+        !ctx.Attr<bool>("is_test"),
+        "is_test attribute should be set to False in training phase.");
+
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index 46a95350a7293c18313811ba9b367fd65955145a..52b607df74446866c535751f3faa11765cb6f247 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -206,6 +206,11 @@ void Pool2dOpMaker::Make() {
       "Defaults to \"NHWC\". Specify the data format of the output data, "
       "the input will be transformed automatically. ")
       .SetDefault("AnyLayout");
+  AddAttr<bool>("is_test",
+                "(bool, default false) Set to true for inference only, false "
+                "for training. Some layers may run faster when this is true.")
+      .SetDefault(false);
+
   // TODO(dzhwinter): need to registered layout transform function
 
   AddComment(R"DOC(
diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..67fca18000a4fac1e2ca39fc26ebe67649a51bc3
--- /dev/null
+++ b/paddle/fluid/operators/selu_op.cc
@@ -0,0 +1,135 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/selu_op.h"
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+class SeluOp : public framework::OperatorWithKernel {
+ public:
+  SeluOp(const std::string &type, const framework::VariableNameMap &inputs,
+         const framework::VariableNameMap &outputs,
+         const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SeluOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SeluOp should not be null.");
+
+    ctx->ShareDim("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.GetPlace());
+  }
+};
+
+class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
+class SeluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input tensor of selu operator.");
+    AddOutput("Out", "The output tensor of selu operator.");
+    AddAttr<float>("scale",
+                   "(float) the default value is 1.0507~. For more "
+                   "information about this value, please refer to: "
+                   "https://arxiv.org/abs/1706.02515.")
+        .SetDefault(1.0507009873554804934193349852946);
+    AddAttr<float>("alpha",
+                   "(float) the default value is 1.6732~. For more "
+                   "information about this value, please refer to: "
+                   "https://arxiv.org/abs/1706.02515.")
+        .SetDefault(1.6732632423543772848170429916717);
+    AddComment(R"DOC(
+Selu Operator.
+
+The equation is:
+$$
+f(x) =\lambda*
+\begin{cases}
+ \quad \quad   x,  \quad \quad \quad \text{if} \ x > 0 \\
+ \alpha * e^x - \alpha,  \qquad  \text{if} \ x <= 0
+\end{cases}
+$$
+
+The input `X` can carry the LoD (Level of Details) information,
+or not. And the output shares the LoD information with input `X`.
+)DOC");
+  }
+};
+
+class SeluGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
+    grad_op->SetType("selu_grad");
+    grad_op->SetInput("Out", Output("Out"));
+    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<framework::OpDesc>(grad_op);
+  }
+};
+
+class SeluGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
+    auto x_grad_name = framework::GradVarName("X");
+    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Out"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::GetDataTypeOfVar(ctx.InputVar("Out")), ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType,
+                  ops::SeluGradMaker);
+REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);
+REGISTER_OP_CPU_KERNEL(
+    selu, ops::SeluKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SeluKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    selu_grad, ops::SeluGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SeluGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/selu_op.cu b/paddle/fluid/operators/selu_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fb3245ab7609ea9067709134a3713e9871dbb4d4
--- /dev/null
+++ b/paddle/fluid/operators/selu_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/selu_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    selu, ops::SeluKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SeluKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    selu_grad, ops::SeluGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SeluGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/selu_op.h b/paddle/fluid/operators/selu_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bdb506885c932708803fe8d84ee705aee0fe02b4
--- /dev/null
+++ b/paddle/fluid/operators/selu_op.h
@@ -0,0 +1,124 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/for_range.h"
+namespace paddle {
+namespace operators {
+
+static HOSTDEVICE float real_exp(float x) { return expf(x); }
+static HOSTDEVICE double real_exp(double x) { return exp(x); }
+
+template <typename T>
+struct SeluFunctor {
+  SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr)
+      : x_data_ptr_(x_data_ptr),
+        alpha_(alpha),
+        scale_(scale),
+        y_data_ptr_(y_data_ptr) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    T x_ele = x_data_ptr_[idx];
+    if (x_ele <= 0) {
+      x_ele = alpha_ * real_exp(x_ele) - alpha_;
+    }
+    y_data_ptr_[idx] = scale_ * x_ele;
+  }
+  const T* x_data_ptr_;
+  const float alpha_;
+  const float scale_;
+  T* y_data_ptr_;
+};
+
+template <typename T>
+struct SeluGradFunctor {
+  SeluGradFunctor(const T* y_data_ptr, const T* dy_data_ptr, float alpha,
+                  float scale, T* dx_data_ptr)
+      : y_data_ptr_(y_data_ptr),
+        dy_data_ptr_(dy_data_ptr),
+        alpha_(alpha),
+        scale_(scale),
+        la_(alpha * scale),
+        dx_data_ptr_(dx_data_ptr) {}
+
+  HOSTDEVICE void operator()(size_t idx) const {
+    T y_ele = y_data_ptr_[idx];
+    T dy_ele = dy_data_ptr_[idx];
+
+    float tmp = scale_;
+    if (y_ele <= 0) {
+      tmp = y_ele + la_;
+    }
+    dx_data_ptr_[idx] = dy_ele * tmp;
+  }
+  const T* y_data_ptr_;
+  const T* dy_data_ptr_;
+  const float alpha_;
+  const float scale_;
+  const float la_;
+  T* dx_data_ptr_;
+};
+
+template <typename DeviceContext, typename T>
+class SeluKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using Tensor = framework::Tensor;
+
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+
+    float alpha = context.Attr<float>("alpha");
+    float scale = context.Attr<float>("scale");
+
+    auto out_ptr = out->mutable_data<T>(context.GetPlace());
+
+    SeluFunctor<T> functor(x->data<T>(), alpha, scale, out_ptr);
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    size_t limit = static_cast<size_t>(x->numel());
+    platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+    for_range(functor);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SeluGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using Tensor = framework::Tensor;
+
+    auto* out = context.Input<Tensor>("Out");
+    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
+
+    float alpha = context.Attr<float>("alpha");
+    float scale = context.Attr<float>("scale");
+
+    auto dx_ptr = dx->mutable_data<T>(context.GetPlace());
+
+    SeluGradFunctor<T> functor(out->data<T>(), dout->data<T>(), alpha, scale,
+                               dx_ptr);
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    size_t limit = static_cast<size_t>(out->numel());
+    platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
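
Reviewer note: SeluGradFunctor avoids recomputing exp by expressing the gradient through the forward output: for x <= 0, y = scale*alpha*(e^x - 1), so dy/dx = scale*alpha*e^x = y + alpha*scale. A NumPy restatement of both functors (same defaults as the operator):

    import numpy as np

    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    def selu(x):
        # Forward: scale * (x if x > 0 else alpha*exp(x) - alpha)
        return scale * np.where(x > 0, x, alpha * np.exp(x) - alpha)

    def selu_grad(y, dy):
        # Backward in terms of the forward output y (y > 0 iff x > 0):
        #   dx = dy * scale               for x > 0
        #   dx = dy * (y + alpha*scale)   for x <= 0
        return dy * np.where(y > 0, scale, y + alpha * scale)

    y = selu(np.linspace(-2.0, 2.0, 5))
    dx = selu_grad(y, np.ones_like(y))
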
false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); AddAttr( "pooltype", "(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.") diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 9e21b6c824bfd7d1c1090e5ba3ba2f6aa9bdb230..091ce4e6e8e2c3c6e2f064c1cfcae222af8299e0 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -96,20 +96,21 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); AddAttr("is_test", - "Disable epsilon adding to softmax results. Used by MKLDNN.") + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") .SetDefault(false); AddComment(R"DOC( Softmax Operator. -The input of the softmax operator is a tensor of any rank. The output tensor +The input of the softmax operator is a tensor of any rank. The output tensor has the same shape as the input. -The input tensor will first be logically flattened to a 2-D matrix. The matrix's -second dimension(row length) is as same as the last dimension of the input -tensor, and the first dimension(column length) is the product of all other -dimensions of the input tensor. For each row of the matrix, the softmax operator -squashes the K-dimensional(K is the width of the matrix, which is also the size -of the input tensor's last dimension) vector of arbitrary real values to a +The input tensor will first be logically flattened to a 2-D matrix. The matrix's +second dimension(row length) is as same as the last dimension of the input +tensor, and the first dimension(column length) is the product of all other +dimensions of the input tensor. For each row of the matrix, the softmax operator +squashes the K-dimensional(K is the width of the matrix, which is also the size +of the input tensor's last dimension) vector of arbitrary real values to a K-dimensional vector of real values in the range [0, 1] that add up to 1. It computes the exponential of the given dimension and the sum of exponential values of all the other dimensions in the K-dimensional vector input. diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index aa6af055decc4856fcf2036d324af6b1ff3a5de0..2b56514fe086dd411fcf842e7e7acba4edf98990 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -92,7 +92,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { "variables generated in the i'th step."); AddAttr(kStepBlock, "The step block inside WhileOp"); - AddAttr("is_test", "True if in test phase.").SetDefault(false); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); AddComment(R"DOC( )DOC"); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1b5009e76126e8b0fbe2805cc85f4f133a70c1ae..f60f3731636d805be06319e3f082cbed35be4940 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -110,6 +110,7 @@ __all__ = [ 'random_crop', 'mean_iou', 'relu', + 'selu', 'log', 'crop', 'rank_loss', @@ -6182,6 +6183,47 @@ def relu(x, name=None): return out +@templatedoc() +def selu(x, scale=None, alpha=None, name=None): + """ + ${comment} + + Args: + x (Variable): The input tensor. 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 1b5009e76126e8b0fbe2805cc85f4f133a70c1ae..f60f3731636d805be06319e3f082cbed35be4940 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -110,6 +110,7 @@ __all__ = [
     'random_crop',
     'mean_iou',
     'relu',
+    'selu',
     'log',
     'crop',
     'rank_loss',
@@ -6182,6 +6183,47 @@ def relu(x, name=None):
     return out
 
 
+@templatedoc()
+def selu(x, scale=None, alpha=None, name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Variable): The input tensor.
+        scale (float|None): If the scale is not set,
+            the default value is 1.0507009873554804934193349852946.
+            For more information about this value, please refer
+            to: https://arxiv.org/abs/1706.02515.
+        alpha (float|None): If the alpha is not set,
+            the default value is 1.6732632423543772848170429916717.
+            For more information about this value, please refer
+            to: https://arxiv.org/abs/1706.02515.
+        name (str|None, default None): A name for this layer. If set to
+            None, the layer will be named automatically.
+
+    Returns:
+        Variable: The output tensor with the same shape as the input.
+
+    Examples:
+
+        .. code-block:: python
+
+            output = fluid.layers.selu(x)
+    """
+    helper = LayerHelper('selu', **locals())
+    dtype = helper.input_dtype(input_param_name='x')
+    out = helper.create_variable_for_type_inference(dtype)
+    attrs = {}
+    if scale is not None:
+        attrs["scale"] = scale
+    if alpha is not None:
+        attrs["alpha"] = alpha
+
+    helper.append_op(
+        type="selu", inputs={"X": x}, outputs={"Out": out}, attrs=attrs)
+    return out
+
+
 def mean_iou(input, label, num_classes):
     """
     Mean Intersection-Over-Union is a common evaluation metric for
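
Reviewer note: an end-to-end smoke test of the new Python wrapper (standard fluid executor boilerplate; CPU assumed):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[10], dtype='float32')
    out = fluid.layers.selu(x)  # defaults follow arXiv:1706.02515

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    res, = exe.run(feed={'x': np.random.randn(3, 10).astype('float32')},
                   fetch_list=[out])
    print(res.shape)  # (3, 10)
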
diff --git a/python/paddle/fluid/tests/unittests/test_selu_op.py b/python/paddle/fluid/tests/unittests/test_selu_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcba0511da747990b1e99026048c7ce95140a422
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_selu_op.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import six
+from op_test import OpTest
+
+
+class SeluTest(OpTest):
+    def setUp(self):
+        self.op_type = "selu"
+        self.x_shape = [3, 5, 5, 10]
+        self.dtype = np.float32
+        self.init_x_shape()
+        self.init_dtype()
+
+        alpha = 1.6732632423543772848170429916717
+        scale = 1.0507009873554804934193349852946
+
+        x = np.random.normal(size=self.x_shape).astype(self.dtype)
+
+        # selu is not differentiable at zero, so avoid sampling
+        # values too close to zero.
+        x[np.abs(x) < 0.005] = 0.02
+
+        x_flat = x.flatten()
+
+        for i in range(x_flat.size):
+            if x_flat[i] < 0:
+                x_flat[i] = alpha * np.exp(x_flat[i]) - alpha
+            x_flat[i] = scale * x_flat[i]
+
+        out_np = x_flat.reshape(self.x_shape)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out_np}
+
+        self.attrs = {
+            'alpha': alpha,
+            'scale': scale,
+        }
+
+    def init_x_shape(self):
+        pass
+
+    def init_dtype(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py
index 9a13cecc646e8534a157fad882fd97836348deb4..ccf7af334d091dff5cf6a94e5875cdf582fa0c5c 100644
--- a/python/paddle/fluid/transpiler/inference_transpiler.py
+++ b/python/paddle/fluid/transpiler/inference_transpiler.py
@@ -73,6 +73,38 @@ class InferenceTranspiler(object):
                 program)  # ResNet residual block merging
         self._fuse_bn_relu_mkldnn(program)
 
+        self._is_test_pass(program)
+
+    def _is_test_pass(self, program):
+        '''
+        Transpile the program by setting is_test = True for all layers that
+        already have the attribute, and by inserting it for pooling and
+        activation layers. As a result some operators may run faster.
+
+        :param program: program to transpile
+        :type program: Program
+        '''
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops):
+            current_op = self.block.ops[i]
+            if current_op.has_attr("is_test"):
+                current_op._set_attr("is_test", True)
+            elif current_op.type in [
+                    "pool2d", "sigmoid", "logsigmoid", "softshrink", "exp",
+                    "brelu", "pow", "leaky_relu", "stanh", "relu", "tanh",
+                    "tanh_shrink", "sqrt", "abs", "ceil", "elu", "floor", "cos",
+                    "sin", "round", "reciprocal", "hard_shrink", "hard_sigmoid",
+                    "relu6", "soft_relu", "swish", "thresholded_relu", "log",
+                    "square", "softplus", "softsign"
+            ]:
+                current_op._set_attr("is_test", True)
+            i = i + 1
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()
+
     def _depthwise_conv_mkldnn(self, program):
         '''
         Transpile the program by replacing depthwise_conv2d to conv2d for MKLDNN program.
diff --git a/python/requirements.txt b/python/requirements.txt
index 7a24dd519afb0d0bc84e89da9f6b2ab2fa8718ce..84cf440397b994ba12fa70d9e316e788f34e2415 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -1,7 +1,7 @@
 requests==2.9.2
 numpy>=1.12,<=1.14  #TODO:change to ">=1.12" when numpy fix bug in 1.15 and higher version
 protobuf==3.1
-recordio>=0.1.0; sys_platform != 'win32'
+recordio>=0.1.0
 matplotlib==2.2.3  # TODO: let python3 paddlepaddle package use latest matplotlib
 rarfile
 scipy>=0.19.0
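
Reviewer note: the transpiler entry point for the `_is_test_pass` change above is the usual fluid 1.x inference flow; `model_dir` below is a placeholder for a saved inference model:

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        'model_dir', exe)

    t = fluid.transpiler.InferenceTranspiler()
    t.transpile(program, place)  # runs _is_test_pass() among other rewrites
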