From 4466f0bec8c23558536959d06b45a1b4c2daab70 Mon Sep 17 00:00:00 2001
From: Krzysztof Binias
Date: Wed, 14 Mar 2018 16:10:54 +0100
Subject: [PATCH] MKLDNN Relu Tanh Sqrt Abs activations added

---
 paddle/fluid/framework/operator.h            |   8 +
 paddle/fluid/operators/CMakeLists.txt        |   5 +
 .../fluid/operators/activation_mkldnn_op.cc  | 192 ++++++++++++++++++
 paddle/fluid/operators/activation_op.cc      |  52 ++++-
 paddle/fluid/operators/activation_op.h       |  65 +++++-
 paddle/fluid/platform/mkldnn_helper.h        |   1 +
 python/paddle/fluid/layer_helper.py          |   2 +
 .../paddle/fluid/tests/unittests/op_test.py  |  12 +-
 .../tests/unittests/test_activation_op.py    |  67 ++++++
 9 files changed, 401 insertions(+), 3 deletions(-)
 create mode 100644 paddle/fluid/operators/activation_mkldnn_op.cc

diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 41214b41cb..d354714d0e 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -84,6 +84,10 @@ class OperatorBase {
     return boost::get<T>(attrs_.at(name));
   }
 
+  inline bool HasAttr(const std::string& name) const {
+    return attrs_.count(name) != 0;
+  }
+
   /// if scope is not null, also show dimensions of arguments
   virtual std::string DebugStringEx(const Scope* scope) const;
 
@@ -195,6 +199,10 @@ class ExecutionContext {
     return op_.Attr<T>(name);
   }
 
+  inline bool HasAttr(const std::string& name) const {
+    return op_.HasAttr(name);
+  }
+
   size_t InputSize(const std::string& name) const {
     return op_.Inputs(name).size();
   }
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c0245379ac..9c367dd145 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -153,7 +153,12 @@ function(op_library TARGET)
   # pybind USE_OP_DEVICE_KERNEL for MKLDNN
   if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
+    # Append first implemented MKLDNN activation operator
+    if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
+    else()
       file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
+    endif()
   endif()
 
   # pybind USE_OP
diff --git a/paddle/fluid/operators/activation_mkldnn_op.cc b/paddle/fluid/operators/activation_mkldnn_op.cc
new file mode 100644
index 0000000000..65cf2fceb7
--- /dev/null
+++ b/paddle/fluid/operators/activation_mkldnn_op.cc
@@ -0,0 +1,192 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/operators/activation_op.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+using paddle::platform::MKLDNNDeviceContext;
+
+namespace {
+template <typename T, typename ExecContext>
+void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
+                     const T alpha = 0, const T beta = 0) {
+  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                 "It must use CPUPlace.");
+
+  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const auto &mkldnn_engine = dev_ctx.GetEngine();
+
+  // get buffers
+  const auto *src = ctx.template Input<Tensor>("X");
+  const auto *src_data = src->template data<T>();
+
+  auto *dst = ctx.template Output<Tensor>("Out");
+  const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
+
+  // get memory dim
+  PADDLE_ENFORCE(src->dims().size() == 4,
+                 "Input dim must be with 4, i.e. NCHW");
+  std::vector<int> src_tz = framework::vectorize2int(src->dims());
+
+  // create memory description
+  // TODO(kbinias-intel): support more formats
+  auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                         mkldnn::memory::format::nchw);
+
+  // create memory primitives
+  auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
+  auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);
+
+  auto forward_desc = mkldnn::eltwise_forward::desc(
+      mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
+
+  // save prim desc into global device context to be referred in backward path
+  const std::string key = ctx.op().Output("Out");
+  const std::string key_eltwise_pd = key + "@eltwise_pd";
+  auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
+      forward_desc, mkldnn_engine);
+  dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
+
+  auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
+
+  // push primitive to stream and wait until it's executed
+  std::vector<mkldnn::primitive> pipeline = {eltwise};
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+}
+
+template <typename T, typename ExecContext>
+void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
+                  const T alpha = 0, const T beta = 0) {
+  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const auto &mkldnn_engine = dev_ctx.GetEngine();
+
+  // get buffers
+  const auto *x = ctx.template Input<Tensor>("X");
+  const auto *src = x->template data<T>();
+
+  auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
+  const auto *diff_dst = dout->template data<T>();
+
+  auto *dx =
+      ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
+  const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
+
+  // get memory dim
+  std::vector<int> src_tz = framework::vectorize2int(x->dims());
+
+  // create memory description
+  auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                         mkldnn::memory::format::nchw);
+
+  // create memory primitives
+  auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
+  auto diff_src_memory =
+      mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
+  auto diff_dst_memory =
+      mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);
+
+  auto backward_desc =
+      mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
+
+  // retrieve eltwise primitive desc from device context
+  const std::string key = ctx.op().Input("Out");
+  const std::string key_eltwise_pd = key + "@eltwise_pd";
+  const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd);
+  PADDLE_ENFORCE(forward_pd != nullptr,
+                 "Fail to find eltwise_pd in device context");
+  auto *p_forward_pd =
+      static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
+
+  auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
+      backward_desc, mkldnn_engine, *p_forward_pd);
+
+  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
+                                              diff_dst_memory, diff_src_memory);
+
+  // push primitive to stream and wait until it's executed
+  std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+}
+}  // anonymous namespace
+
+template <typename T, mkldnn::algorithm algorithm>
+struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
+  template <typename ExecContext>
+  void operator()(const ExecContext &ctx) const {
+    eltwise_forward<T>(ctx, algorithm);
+  }
+};
+
+template <typename T, mkldnn::algorithm algorithm>
+struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
+  template <typename ExecContext>
+  void operator()(const ExecContext &ctx) const {
+    eltwise_grad<T>(ctx, algorithm);
+  }
+};
+
+template <typename T>
+using ReluMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
+
+template <typename T>
+using TanhMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
+
+template <typename T>
+using SqrtMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;
+
+template <typename T>
+using AbsMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
+
+template <typename T>
+using ReluMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
+
+template <typename T>
+using TanhMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
+
+template <typename T>
+using SqrtMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;
+
+template <typename T>
+using AbsMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
+  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
+                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
+  REGISTER_OP_KERNEL(                                                      \
+      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
+      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
+
+#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)           \
+  __macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor) \
+  __macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor) \
+  __macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor) \
+  __macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
+
+FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index ec637658c0..ae9ca9d4ff 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -25,6 +25,11 @@ class ActivationOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return ActivationHelper().GetKernelType(ctx, *this);
+  }
 };
 
 class ActivationOpGrad : public framework::OperatorWithKernel {
@@ -34,6 +39,11 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return ActivationHelper().GetKernelType(ctx, *this);
+  }
 };
 
 class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -87,6 +97,16 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu operator");
     AddOutput("Out", "Output of Relu operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCHW) Only used in "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Defaults to \"NHWC\". Specify the data format of the output data, "
+        "the input will be transformed automatically. ")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Relu Activation Operator.
 
@@ -140,6 +160,16 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Tanh operator");
     AddOutput("Out", "Output of Tanh operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCHW) Only used in "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Defaults to \"NHWC\". Specify the data format of the output data, "
+        "the input will be transformed automatically. ")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Tanh Activation Operator.
 
@@ -193,6 +223,16 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Sqrt operator");
     AddOutput("Out", "Output of Sqrt operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCHW) Only used in "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Defaults to \"NHWC\". Specify the data format of the output data, "
+        "the input will be transformed automatically. ")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Sqrt Activation Operator.
 
@@ -208,6 +248,16 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Abs operator");
     AddOutput("Out", "Output of Abs operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCHW) Only used in "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Defaults to \"NHWC\". Specify the data format of the output data, "
+        "the input will be transformed automatically. ")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Abs Activation Operator.
 
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index b95e793586..084b6bace7 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,9 +17,36 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
+class ActivationHelper {
+ public:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx,
+      const framework::OperatorWithKernel& oper) const {
+    framework::LibraryType library{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    if (ctx.HasAttr("data_format")) {
+      std::string data_format = ctx.Attr<std::string>("data_format");
+      layout = framework::StringToDataLayout(data_format);
+    }
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace(), layout, library);
+  }
+};
+
 template <typename DeviceContext, typename Functor>
 class ActivationKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -49,6 +76,27 @@ class ActivationKernel
   }
 };
 
+template <typename Functor>
+class MKLDNNActivationKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(!context.HasAttr("X"),
+                   "Cannot find input tensor X, variable name = %s",
+                   context.op().Input("X"));
+    PADDLE_ENFORCE(!context.HasAttr("Out"),
+                   "Cannot find output tensor Out, variable name = %s",
+                   context.op().Output("Out"));
+    Functor functor;
+
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = context.Attr<float>(attr.first);
+    }
+    functor(context);
+  }
+};
+
 template <typename DeviceContext, typename Functor>
 class ActivationGradKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -77,6 +125,21 @@ class ActivationGradKernel
   }
 };
 
+template <typename Functor>
+class MKLDNNActivationGradKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    Functor functor;
+
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = context.Attr<float>(attr.first);
+    }
+    functor(context);
+  }
+};
+
 template <typename T>
 struct BaseActivationFunctor {
   using ELEMENT_TYPE = T;
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 90b78142b8..281d38cb8a 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -42,6 +42,7 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
 }
 
 inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
+  if (!ctx.HasAttr("use_mkldnn")) return false;
   bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
   return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
 }
diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py
index 58b6682271..d771837fc5 100644
--- a/python/paddle/fluid/layer_helper.py
+++ b/python/paddle/fluid/layer_helper.py
@@ -403,6 +403,8 @@ class LayerHelper(object):
         if 'use_mkldnn' in self.kwargs:
             act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
         act_type = act.pop('type')
+        if 'use_mkldnn' in self.kwargs:
+            act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
         self.append_op(
             type=act_type,
             inputs={"X": [input_var]},
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 8393f7827b..2b10f16688 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -215,7 +215,8 @@ class OpTest(unittest.TestCase):
         '''Fix random seeds to remove randomness from tests'''
         cls._np_rand_state = np.random.get_state()
         cls._py_rand_state = random.getstate()
-
+        cls.use_mkldnn = False
+        cls.data_format = 'AnyLayout'
         np.random.seed(123)
         random.seed(124)
 
@@ -340,7 +341,14 @@ class OpTest(unittest.TestCase):
                             "Output (" + out_name + ") has different lod at " +
                             str(place))
 
+    def fill_attrs(self):
+        attrs = self.attrs if hasattr(self, "attrs") else dict()
+        attrs["use_mkldnn"] = self.use_mkldnn
+        attrs["data_format"] = self.data_format
+        return attrs
+
     def check_output(self, atol=1e-5):
+        self.attrs = self.fill_attrs()
         places = [core.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
             places.append(core.CUDAPlace(0))
@@ -348,6 +356,7 @@ class OpTest(unittest.TestCase):
             self.check_output_with_place(place, atol)
 
     def check_output_customized(self, checker):
+        self.attrs = self.fill_attrs()
         places = [core.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
             places.append(core.CUDAPlace(0))
@@ -383,6 +392,7 @@ class OpTest(unittest.TestCase):
                    in_place=False,
                    max_relative_error=0.005,
                    user_defined_grads=None):
+        self.attrs = self.fill_attrs()
         places = [core.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
             places.append(core.CUDAPlace(0))
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 1e3decfbaf..c6c86a5969 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -506,5 +506,72 @@ class TestSwish(OpTest):
         self.check_grad(['X'], 'Out', max_relative_error=0.008)
 
 
+#--------------------test MKLDNN--------------------
+class TestMKLDNNRelu(OpTest):
+    def setUp(self):
+        self.op_type = "relu"
+        x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
+        # The same reason with TestAbs
+        x[np.abs(x) < 0.005] = 0.02
+        self.inputs = {'X': x}
+        self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
+        self.use_mkldnn = True
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
+
+
+class TestMKLDNNTanh(OpTest):
+    def setUp(self):
+        self.op_type = "tanh"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
+        }
+        self.outputs = {'Out': np.tanh(self.inputs['X'])}
+        self.use_mkldnn = True
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
+
+
+class TestMKLDNNSqrt(OpTest):
+    def setUp(self):
+        self.op_type = "sqrt"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
+        }
+        self.outputs = {'Out': np.sqrt(self.inputs['X'])}
+        self.use_mkldnn = True
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
+
+
+class TestMKLDNNAbs(OpTest):
+    def setUp(self):
self.op_type = "abs" + x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32") + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + self.inputs = {'X': x} + self.outputs = {'Out': np.abs(self.inputs['X'])} + self.use_mkldnn = True + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', max_relative_error=0.007) + + if __name__ == "__main__": unittest.main() -- GitLab