diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c0245379ac481d922ee936c75bfc6b63a81be5fd..9c367dd14519e4272ee25f0ae80f81d196ceea39 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -153,7 +153,12 @@ function(op_library TARGET)
 
   # pybind USE_OP_DEVICE_KERNEL for MKLDNN
   if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
+    # Append first implemented MKLDNN activation operator
+    if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
+      file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
+    else()
     file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
+    endif()
   endif()
 
   # pybind USE_OP
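
Note on the special case above: op_library normally appends one USE_OP_DEVICE_KERNEL line named after the target, but activation_mkldnn_op.cc registers its kernels under the individual operator names (relu, relu_grad, tanh, ...) rather than under an "activation" op, so the generated pybind line must name a kernel that is actually registered; relu is simply the first one implemented. Referencing a single registrar suffices, because touching any symbol in the object file makes the linker keep the whole translation unit, so the static registrars for the other activations run as well. A schematic of that touch idiom, with illustrative identifiers only (this is not Paddle's actual macro expansion):

    // What a USE_OP_DEVICE_KERNEL-style line boils down to, schematically:
    extern int TouchReluMKLDNNKernelRegistrar();  // hypothetical symbol defined by registration
    static int use_relu_mkldnn_kernel = TouchReluMKLDNNKernelRegistrar();
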
NCHW"); + std::vector src_tz = framework::vectorize2int(src->dims()); + + // create memory description + // TODO(kbinias-intel): support more formats + auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, + mkldnn::memory::format::nchw); + + // create memory primitives + auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data); + auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data); + + auto forward_desc = mkldnn::eltwise_forward::desc( + mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta); + + // save prim desc into global device context to be referred in backward path + const std::string key = ctx.op().Output("Out"); + const std::string key_eltwise_pd = key + "@eltwise_pd"; + auto forward_pd = std::make_shared( + forward_desc, mkldnn_engine); + dev_ctx.SetBlob(key_eltwise_pd, forward_pd); + + auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory); + + // push primitive to stream and wait until it's executed + std::vector pipeline = {eltwise}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); +} + +template +void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, + const T alpha = 0, const T beta = 0) { + auto &dev_ctx = ctx.template device_context(); + const auto &mkldnn_engine = dev_ctx.GetEngine(); + + // get buffers + const auto *x = ctx.template Input("X"); + const auto *src = x->template data(); + + auto *dout = ctx.template Input(framework::GradVarName("Out")); + const auto *diff_dst = dout->template data(); + + auto *dx = + ctx.template Output(framework::GradVarName("X")); + const T *diff_src = dx->template mutable_data(ctx.GetPlace()); + + // get memory dim + std::vector src_tz = framework::vectorize2int(x->dims()); + + // create memory description + auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, + mkldnn::memory::format::nchw); + + // create memory primitives + auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src); + auto diff_src_memory = + mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src); + auto diff_dst_memory = + mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst); + + auto backward_desc = + mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta); + + // retrieve eltwise primitive desc from device context + const std::string key = ctx.op().Input("Out"); + const std::string key_eltwise_pd = key + "@eltwise_pd"; + const std::shared_ptr forward_pd = dev_ctx.GetBlob(key_eltwise_pd); + PADDLE_ENFORCE(forward_pd != nullptr, + "Fail to find eltwise_pd in device context"); + auto *p_forward_pd = + static_cast(forward_pd.get()); + + auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc( + backward_desc, mkldnn_engine, *p_forward_pd); + + auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory, + diff_dst_memory, diff_src_memory); + + // push primitive to stream and wait until it's executed + std::vector pipeline = {eltwise_bwd}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); +} +} // anonymous namespace + +template +struct MKLDNNActivationFunc : public BaseActivationFunctor { + template + void operator()(const ExecContext &ctx) const { + eltwise_forward(ctx, algorithm); + } +}; + +template +struct MKLDNNActivationGradFunc : public BaseActivationFunctor { + template + void operator()(const ExecContext &ctx) const { + eltwise_grad(ctx, algorithm); + } +}; + +template +using ReluMkldnnFunctor = + MKLDNNActivationFunc; + 
+
+template <typename T, mkldnn::algorithm algorithm>
+struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
+  template <typename ExecContext>
+  void operator()(const ExecContext &ctx) const {
+    eltwise_forward<T>(ctx, algorithm);
+  }
+};
+
+template <typename T, mkldnn::algorithm algorithm>
+struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
+  template <typename ExecContext>
+  void operator()(const ExecContext &ctx) const {
+    eltwise_grad<T>(ctx, algorithm);
+  }
+};
+
+template <typename T>
+using ReluMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
+
+template <typename T>
+using TanhMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
+
+template <typename T>
+using SqrtMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;
+
+template <typename T>
+using AbsMkldnnFunctor =
+    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
+
+template <typename T>
+using ReluMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
+
+template <typename T>
+using TanhMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
+
+template <typename T>
+using SqrtMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;
+
+template <typename T>
+using AbsMkldnnGradFunctor =
+    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
+  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
+                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
+  REGISTER_OP_KERNEL(                                                      \
+      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
+      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
+
+#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)            \
+  __macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor); \
+  __macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor); \
+  __macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor); \
+  __macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
+
+FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
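
For reference, the two macros above expand, for relu, to the following registrations (a mechanical expansion of the macros, not extra code in the patch):

    REGISTER_OP_KERNEL(relu, MKLDNN, ::paddle::platform::CPUPlace,
                       ops::MKLDNNActivationKernel<ops::ReluMkldnnFunctor<float>>);
    REGISTER_OP_KERNEL(
        relu_grad, MKLDNN, ::paddle::platform::CPUPlace,
        ops::MKLDNNActivationGradKernel<ops::ReluMkldnnGradFunctor<float>>);

Also note the forward/backward handshake: eltwise_forward caches its primitive_desc in the device context under the key "<Out variable name>@eltwise_pd", and eltwise_grad rebuilds the same key from its "Out" input, so the backward kernel relies on the forward kernel having already run for that output variable.
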
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index ec637658c03ad94624ee9a4f5def6a84387d293e..979115eee0dbe157dbcf2293d914cc250b35d22e 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/activation_op.h"
+#include "paddle/fluid/operators/mkldnn_activation_op.h"
 
 namespace paddle {
 namespace operators {
@@ -87,6 +88,9 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu operator");
     AddOutput("Out", "Output of Relu operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Relu Activation Operator.
 
@@ -140,6 +144,9 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Tanh operator");
     AddOutput("Out", "Output of Tanh operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Tanh Activation Operator.
 
@@ -193,6 +200,9 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Sqrt operator");
     AddOutput("Out", "Output of Sqrt operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Sqrt Activation Operator.
 
@@ -208,6 +218,9 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Abs operator");
     AddOutput("Out", "Output of Abs operator");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Abs Activation Operator.
 
@@ -524,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
 REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
             ops::ActivationOpGrad);
 
-REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
-            ops::ActivationOpGrad);
+REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad,
+            ops::ActivationWithMKLDNNOpGrad);
 
-REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
-            ops::ActivationOpGrad);
+REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad,
+            ops::ActivationWithMKLDNNOpGrad);
 
 REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
             tanh_shrink_grad, ops::ActivationOpGrad);
@@ -536,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
 REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
             softshrink_grad, ops::ActivationOpGrad);
 
-REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
-            ops::ActivationOpGrad);
+REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad,
+            ops::ActivationWithMKLDNNOpGrad);
 
-REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
-            ops::ActivationOpGrad);
+REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad,
+            ops::ActivationWithMKLDNNOpGrad);
 
 REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
             ops::ActivationOpGrad);
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index b95e793586219b7c413d0c7adb835081874d9363..4c575b4a7b551be2d1288f7fec0a2821fc10c40d 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,6 +17,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
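
The `use_mkldnn` attribute added to the four makers is what `platform::CanMKLDNNBeUsed` consults when `ActivationWithMKLDNNOp` (introduced in the next file) picks a kernel. A minimal sketch of that check, assuming it only needs the attribute and a CPU place; see paddle/fluid/platform/mkldnn_helper.h for the authoritative definition:

    // Hedged sketch of the dispatch condition, not the verbatim helper:
    inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
      bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");  // attribute added above
      return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
    }
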
diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn_activation_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..083d03ebe610521c5a4beb7b977a8179700bcf40
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn_activation_op.h
@@ -0,0 +1,111 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
+
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename Functor>
+class MKLDNNActivationKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(context.Input<framework::Tensor>("X") != nullptr,
+                   "Cannot get input tensor X, variable name = %s",
+                   context.op().Input("X"));
+    PADDLE_ENFORCE(context.Output<framework::Tensor>("Out") != nullptr,
+                   "Cannot find output tensor Out, variable name = %s",
+                   context.op().Output("Out"));
+    Functor functor;
+
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = context.Attr<float>(attr.first);
+    }
+    functor(context);
+  }
+};
+
+template <typename Functor>
+class MKLDNNActivationGradKernel
+    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    Functor functor;
+
+    auto attrs = functor.GetAttrs();
+    for (auto& attr : attrs) {
+      *attr.second = context.Attr<float>(attr.first);
+    }
+    functor(context);
+  }
+};
+
+namespace {
+framework::OpKernelType GetKernelType(
+    const framework::ExecutionContext& ctx,
+    const framework::OperatorWithKernel& oper) {
+  framework::LibraryType library{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+  if (library == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library = framework::LibraryType::kMKLDNN;
+  }
+#endif
+  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+      ctx.GetPlace(), layout, library);
+}
+}  // anonymous namespace
+
+class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this);
+  }
+};
+
+class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
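
None of the four functors registered in this patch declare attributes, so the GetAttrs() loop in both Compute methods is a no-op today; it exists so parameterized activations can reuse the same kernels. A hypothetical functor sketching the contract the loop assumes, assuming BaseActivationFunctor keeps its usual AttrPair alias (a vector of name/float* pairs, which is what the float loop above implies); the name LeakyReluMkldnnFunctor and the "alpha" attribute are illustrative, not part of this patch:

    // Hypothetical: a functor whose attribute is filled in by the kernel's loop.
    template <typename T>
    struct LeakyReluMkldnnFunctor : public BaseActivationFunctor<T> {
      float alpha;  // written via *attr.second = context.Attr<float>("alpha")
      typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
        return {{"alpha", &alpha}};
      }
      template <typename ExecContext>
      void operator()(const ExecContext& ctx) const {
        // eltwise_relu with a negative-slope alpha
        eltwise_forward<T>(ctx, mkldnn::algorithm::eltwise_relu, alpha);
      }
    };
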
diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py
index 58b668227168c5c5e080f3928035ad98303bbae9..d771837fc545167f7c32fcf914dd1c3c3ae64fb3 100644
--- a/python/paddle/fluid/layer_helper.py
+++ b/python/paddle/fluid/layer_helper.py
@@ -403,6 +403,8 @@ class LayerHelper(object):
         if 'use_mkldnn' in self.kwargs:
             act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
         act_type = act.pop('type')
+        if 'use_mkldnn' in self.kwargs:
+            act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
         self.append_op(
             type=act_type,
             inputs={"X": [input_var]},
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 1e3decfbaf0691e912b96b415b68353e626cf51e..4a2b35322dd4b9718c83eb5ee679ada382938441 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -506,5 +506,54 @@ class TestSwish(OpTest):
         self.check_grad(['X'], 'Out', max_relative_error=0.008)
 
 
+#--------------------test MKLDNN--------------------
+class TestMKLDNNRelu(TestRelu):
+    def setUp(self):
+        super(TestMKLDNNRelu, self).setUp()
+
+        x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
+        # The same reason with TestAbs
+        x[np.abs(x) < 0.005] = 0.02
+        out = np.maximum(x, 0)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+        self.attrs = {"use_mkldnn": True}
+
+
+class TestMKLDNNTanh(TestTanh):
+    def setUp(self):
+        super(TestMKLDNNTanh, self).setUp()
+
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
+        }
+        self.outputs = {'Out': np.tanh(self.inputs['X'])}
+        self.attrs = {"use_mkldnn": True}
+
+
+class TestMKLDNNSqrt(TestSqrt):
+    def setUp(self):
+        super(TestMKLDNNSqrt, self).setUp()
+
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
+        }
+        self.outputs = {'Out': np.sqrt(self.inputs['X'])}
+        self.attrs = {"use_mkldnn": True}
+
+
+class TestMKLDNNAbs(TestAbs):
+    def setUp(self):
+        super(TestMKLDNNAbs, self).setUp()
+
+        x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
+        # The same reason with TestAbs
+        x[np.abs(x) < 0.005] = 0.02
+        self.inputs = {'X': x}
+        self.outputs = {'Out': np.abs(self.inputs['X'])}
+        self.attrs = {"use_mkldnn": True}
+
+
 if __name__ == "__main__":
     unittest.main()