diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 20da74eca4ef879c8872703b65d86d5eed941bb5..4f1080952a11e393fa0d82dcfad949e82c22ee9b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2102,7 +2102,7 @@ PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = std::unordered_set( - {"concat", "conv2d", "fusion_gru", "reshape2", "transpose2"}); + {"concat", "conv2d", "fusion_gru", "reshape2", "transpose2", "sum"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc index 146e29249b7c610b8df9df17838b9db232e62fb8..4ca9724026a9cfa44b6a1b4ac53f3e1643eae4d3 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc @@ -44,6 +44,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("X", {inputs[0]}); } else if (type == "reshape2") { op->SetInput("X", {inputs[0]}); + } else if (type == "sum") { + op->SetInput("X", {inputs[0], inputs[1]}); } else { FAIL() << "Unexpected operator type."; } @@ -61,8 +63,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ProgramDesc BuildProgramDesc() { ProgramDesc prog; - for (auto& v : std::vector( - {"a", "b", "c", "f", "g", "h", "k", "l", "m", "n", "o", "p"})) { + for (auto& v : + std::vector({"a", "b", "c", "f", "g", "h", "k", "l", "m", + "n", "o", "p", "r", "s"})) { prog.MutableBlock(0)->Var(v); } @@ -75,6 +78,7 @@ ProgramDesc BuildProgramDesc() { SetOp(&prog, "concat", "concat2", {"l", "m"}, {"n"}); SetOp(&prog, "transpose2", "transpose", {"n"}, {"o"}); SetOp(&prog, "reshape2", "reshape", {"o"}, {"p"}); + SetOp(&prog, "sum", "sum", {"p", "r"}, {"s"}); return prog; } @@ -122,7 +126,7 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) { } TEST(Bfloat16PlacementPass, enable_all) { - MainTest({"conv2d", "pool2d", "relu", "concat"}, 7); + MainTest({"conv2d", "pool2d", "relu", "concat", "sum"}, 8); } TEST(Bfloat16PlacementPass, enabled_conv_and_pool) { @@ -130,7 +134,7 @@ TEST(Bfloat16PlacementPass, enabled_conv_and_pool) { MainTest({"conv2d", "pool2d"}, 3); } -TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(5); } +TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(6); } } // namespace ir } // namespace framework diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc index 3d3738d922f77b067341b8f68e3d70a040832d3a..4df7818072f0538305808cd14606ae45ea84238d 100644 --- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc @@ -40,13 +40,6 @@ class MKLDNNDeviceContext; namespace paddle { namespace operators { -using framework::DataLayout; -using mkldnn::memory; -using mkldnn::primitive; -using mkldnn::reorder; -using mkldnn::stream; -using mkldnn::sum; -using paddle::framework::Tensor; using paddle::platform::CPUDeviceContext; using paddle::platform::MKLDNNDeviceContext; using platform::to_void_cast; @@ -71,21 +64,21 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT { auto dst_tz = framework::vectorize(z->dims()); auto src_tz = dst_tz; - std::vector srcs_md; + std::vector srcs_md; for (size_t i = 0; i < in_vars.size(); i++) { auto& input_it = in_vars[i]->Get(); if (input_it.numel() == 0) { continue; } MKLDNNMemoryFormat input_format = input_it.format(); - srcs_md.push_back(memory::desc(src_tz, platform::MKLDNNGetDataType(), - input_format)); + srcs_md.push_back(mkldnn::memory::desc( + src_tz, platform::MKLDNNGetDataType(), input_format)); ++num_inputs_; } std::vector scales(num_inputs_, 1.0); - auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType(), - MKLDNNMemoryFormat::any); + auto dst_md = mkldnn::memory::desc( + dst_tz, platform::MKLDNNGetDataType(), MKLDNNMemoryFormat::any); this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md); } @@ -94,15 +87,15 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT { // (jczaja) sum oneDNN prim is not having .desc attribute so // we cannot use base AcquireForwardPrimitiveDescriptor void AcquireForwardPrimitiveDescriptor( - const memory::desc& dst_md, const std::vector& scales, - const std::vector& srcs_md) { + const mkldnn::memory::desc& dst_md, const std::vector& scales, + const std::vector& srcs_md) { // Sum op does not have backward so no passing from FWD to BWD is needed const std::string key_pd = this->key_ + "@fwd_pd"; this->fwd_pd_ = std::static_pointer_cast( this->dev_ctx_.GetBlob(key_pd)); if (this->fwd_pd_ == nullptr) { - this->fwd_pd_.reset(new mkldnn::sum::primitive_desc( - dst_md, scales, srcs_md, this->engine_)); + this->fwd_pd_.reset(new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, + this->engine_)); this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_); } } @@ -178,7 +171,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { auto sum_p = handler.AcquireForwardPrimitive(); - std::unordered_map args; + std::unordered_map args; for (size_t i = 0; i < srcs_mem.size(); ++i) { args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])}); } @@ -215,5 +208,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace, - paddle::operators::SumMKLDNNOpKernel); +REGISTER_OP_KERNEL( + sum, MKLDNN, ::paddle::platform::CPUPlace, + paddle::operators::SumMKLDNNOpKernel, + paddle::operators::SumMKLDNNOpKernel); diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 52c4c63b473c443bb97fb7962179ce27e06fb16c..faade79091c4afcc0d0bf9625619fca1815b6db9 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -148,16 +148,19 @@ class SumOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx) && - static_cast(dtype) == - framework::proto::VarType::FP32 && + (static_cast(dtype) == + framework::proto::VarType::FP32 || + static_cast(dtype) == + framework::proto::VarType::BF16) && ctx.OutputVar("Out")->IsType()) { if (std::all_of(x_vars.begin(), x_vars.end(), [](const framework::Variable* v) { return v->IsType(); })) { return framework::OpKernelType( - framework::proto::VarType::FP32, ctx.GetPlace(), - framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); + static_cast(dtype), + ctx.GetPlace(), framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } } #endif @@ -215,6 +218,11 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor of the input. If the input is LoDTensor, the output only shares LoD information with the first input.)DOC"); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..05d739ae1f3f34e96ecdc31055d32a21e5bb044e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py @@ -0,0 +1,59 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.test_sum_op import TestSumOp +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle import enable_static +import numpy as np +import paddle.fluid.op as fluid_op + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestSumMKLDNN(TestSumOp): + def setUp(self): + self.op_type = "sum" + self.use_mkldnn = True + self.mkldnn_data_type = "bfloat16" + + # float32 input to be use for reference + x0 = np.random.random((25, 8)).astype('float32') + x1 = np.random.random((25, 8)).astype('float32') + x2 = np.random.random((25, 8)).astype('float32') + + # actual input (bf16) to bf16 sum op + x0_bf16 = convert_float_to_uint16(x0) + x1_bf16 = convert_float_to_uint16(x1) + x2_bf16 = convert_float_to_uint16(x2) + + self.inputs = {"X": [("x0", x0_bf16), ("x1", x1_bf16), ("x2", x2_bf16)]} + + y = x0 + x1 + x2 + self.outputs = {'Out': convert_float_to_uint16(y)} + self.attrs = {'use_mkldnn': self.use_mkldnn} + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + pass + + +if __name__ == '__main__': + enable_static() + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index be1166371944123c92b5205768b18c330dd1a005..77e7372290d9c67301a36bde91f20283a324c3b0 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -602,6 +602,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_requantize_mkldnn_op', 'test_softmax_mkldnn_op', 'test_sum_mkldnn_op', + 'test_sum_bf16_mkldnn_op', 'test_transpose_int8_mkldnn_op', 'test_transpose_mkldnn_op', 'test_mkldnn_conv_activation_fuse_pass',