From f2a880421ebdfb6e0c9b2b3809d8ba8449b09ea2 Mon Sep 17 00:00:00 2001 From: Michal Gallus Date: Tue, 4 Dec 2018 10:10:02 +0100 Subject: [PATCH] Fix style @ concat integration and tests test=develop --- paddle/fluid/operators/concat_mkldnn_op.cc | 33 +++++++++---------- paddle/fluid/operators/concat_op.cc | 33 ++++++++++--------- .../tests/unittests/test_concat_mkldnn_op.py | 2 ++ 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/concat_mkldnn_op.cc b/paddle/fluid/operators/concat_mkldnn_op.cc index 37b2788d6..b8456aac9 100644 --- a/paddle/fluid/operators/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/concat_mkldnn_op.cc @@ -30,15 +30,15 @@ using platform::to_void_cast; static void EnforceLayouts(const std::vector inputs) { for (auto* input : inputs) { const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN; - const bool is_format_defined = input->format() != - memory::format::format_undef; + const bool is_format_defined = + input->format() != memory::format::format_undef; PADDLE_ENFORCE(is_layout_correct && is_format_defined, "Wrong layout/format set for Input tensor"); } } -static memory::primitive_desc CreateMemPrimDesc( - const Tensor& input, const mkldnn::engine& engine) { +static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, + const mkldnn::engine& engine) { constexpr auto data_type = mkldnn::memory::f32; const auto dims = paddle::framework::vectorize2int(input.dims()); const auto format = input.format(); @@ -48,8 +48,8 @@ static memory::primitive_desc CreateMemPrimDesc( } static mkldnn::memory::format GetDstMemFormat( - const concat::primitive_desc& concat_pd) { - return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; + const concat::primitive_desc& concat_pd) { + return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; } static platform::CPUPlace GetCpuPlace( @@ -61,10 +61,9 @@ static platform::CPUPlace GetCpuPlace( } static const mkldnn::engine& GetMKLDNNEngine( - const paddle::framework::ExecutionContext& ctx) { - auto& dev_ctx = - ctx.template device_context(); - return dev_ctx.GetEngine(); + const paddle::framework::ExecutionContext& ctx) { + auto& dev_ctx = ctx.template device_context(); + return dev_ctx.GetEngine(); } template @@ -89,7 +88,7 @@ class ConcatPrimitiveFactory { memory::desc CreateDstMemDescriptor(Tensor* output) { auto dst_dims = paddle::framework::vectorize2int(output->dims()); return memory::desc(dst_dims, platform::MKLDNNGetDataType(), - memory::format::any); + memory::format::any); } mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, @@ -101,10 +100,10 @@ class ConcatPrimitiveFactory { void CreateSourcesDescriptors(const std::vector multi_input, const mkldnn::engine& mkldnn_engine) { for (size_t i = 0; i < multi_input.size(); i++) { - auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); - srcs_pd.push_back(mem_prim_desc); - srcs.push_back(memory(mem_prim_desc, - to_void_cast(multi_input[i]->data()))); + auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); + srcs_pd.push_back(mem_prim_desc); + srcs.push_back( + memory(mem_prim_desc, to_void_cast(multi_input[i]->data()))); } } @@ -134,8 +133,8 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { int64_t concat_axis = static_cast(ctx.Attr("axis")); ConcatPrimitiveFactory prim_creator; - auto concat_pd = prim_creator.CreateConcatPrimDescriptor(multi_input, - output, static_cast(concat_axis), mkldnn_engine); + auto concat_pd = prim_creator.CreateConcatPrimDescriptor( + multi_input, output, static_cast(concat_axis), mkldnn_engine); auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); stream(stream::kind::eager).submit({concat}).wait(); diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 7e58f9cde..7466107cf 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -14,9 +14,9 @@ limitations under the License. */ #include "paddle/fluid/operators/concat_op.h" +#include #include #include -#include namespace paddle { namespace operators { @@ -63,18 +63,19 @@ class ConcatOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); - - #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { - return framework::OpKernelType(input_data_type, ctx.GetPlace(), - framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN); - } - #endif - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); + +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { @@ -82,9 +83,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddOutput("Out", "Output tensor of concat operator."); - AddAttr("use_mkldnn", - "(bool, default false) Indicates if MKL-DNN kernel will be used") - .SetDefault(false); + AddAttr( + "use_mkldnn", + "(bool, default false) Indicates if MKL-DNN kernel will be used") + .SetDefault(false); AddAttr("axis", "The axis along which the input tensors will be concatenated.") .SetDefault(0); @@ -101,7 +103,6 @@ Examples: [5,6]] )DOC"); - } }; diff --git a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py index c590687a2..0ea44c0e4 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py @@ -29,6 +29,7 @@ class TestMKLDNNConcatOp(TestConcatOp): def init_kernel_type(self): self.use_mkldnn = True + class TestMKLDNNConcatOp2(TestConcatOp2): def setUp(self): super(TestMKLDNNConcatOp2, self).setUp() @@ -40,6 +41,7 @@ class TestMKLDNNConcatOp2(TestConcatOp2): def init_kernel_type(self): self.use_mkldnn = True + class TestMKLDNNConcatOp3(TestConcatOp3): def setUp(self): super(TestMKLDNNConcatOp3, self).setUp() -- GitLab