未验证 提交 ca415414 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN]Sum bf16 kernel (#28382)

* - Added sum bf16 oneDNN

test=develop

* - Fix to UT of sum bf16

test=develop
上级 648b92c0
...@@ -2102,7 +2102,7 @@ PDNode *patterns::Bfloat16Placement::operator()( ...@@ -2102,7 +2102,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) { const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types = std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>( std::unordered_set<std::string>(
{"concat", "conv2d", "fusion_gru", "reshape2", "transpose2"}); {"concat", "conv2d", "fusion_gru", "reshape2", "transpose2", "sum"});
if (!bfloat16_enabled_op_types.empty()) { if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types; supported_op_types = bfloat16_enabled_op_types;
} }
......
...@@ -44,6 +44,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -44,6 +44,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
} else if (type == "reshape2") { } else if (type == "reshape2") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
} else if (type == "sum") {
op->SetInput("X", {inputs[0], inputs[1]});
} else { } else {
FAIL() << "Unexpected operator type."; FAIL() << "Unexpected operator type.";
} }
...@@ -61,8 +63,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -61,8 +63,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
ProgramDesc BuildProgramDesc() { ProgramDesc BuildProgramDesc() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : std::vector<std::string>( for (auto& v :
{"a", "b", "c", "f", "g", "h", "k", "l", "m", "n", "o", "p"})) { std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
"n", "o", "p", "r", "s"})) {
prog.MutableBlock(0)->Var(v); prog.MutableBlock(0)->Var(v);
} }
...@@ -75,6 +78,7 @@ ProgramDesc BuildProgramDesc() { ...@@ -75,6 +78,7 @@ ProgramDesc BuildProgramDesc() {
SetOp(&prog, "concat", "concat2", {"l", "m"}, {"n"}); SetOp(&prog, "concat", "concat2", {"l", "m"}, {"n"});
SetOp(&prog, "transpose2", "transpose", {"n"}, {"o"}); SetOp(&prog, "transpose2", "transpose", {"n"}, {"o"});
SetOp(&prog, "reshape2", "reshape", {"o"}, {"p"}); SetOp(&prog, "reshape2", "reshape", {"o"}, {"p"});
SetOp(&prog, "sum", "sum", {"p", "r"}, {"s"});
return prog; return prog;
} }
...@@ -122,7 +126,7 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) { ...@@ -122,7 +126,7 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
} }
TEST(Bfloat16PlacementPass, enable_all) { TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "relu", "concat"}, 7); MainTest({"conv2d", "pool2d", "relu", "concat", "sum"}, 8);
} }
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) { TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
...@@ -130,7 +134,7 @@ TEST(Bfloat16PlacementPass, enabled_conv_and_pool) { ...@@ -130,7 +134,7 @@ TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
MainTest({"conv2d", "pool2d"}, 3); MainTest({"conv2d", "pool2d"}, 3);
} }
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(5); } TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(6); }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -40,13 +40,6 @@ class MKLDNNDeviceContext; ...@@ -40,13 +40,6 @@ class MKLDNNDeviceContext;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using mkldnn::sum;
using paddle::framework::Tensor;
using paddle::platform::CPUDeviceContext; using paddle::platform::CPUDeviceContext;
using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast; using platform::to_void_cast;
...@@ -71,21 +64,21 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> { ...@@ -71,21 +64,21 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
auto dst_tz = framework::vectorize<int64_t>(z->dims()); auto dst_tz = framework::vectorize<int64_t>(z->dims());
auto src_tz = dst_tz; auto src_tz = dst_tz;
std::vector<memory::desc> srcs_md; std::vector<mkldnn::memory::desc> srcs_md;
for (size_t i = 0; i < in_vars.size(); i++) { for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>(); auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) { if (input_it.numel() == 0) {
continue; continue;
} }
MKLDNNMemoryFormat input_format = input_it.format(); MKLDNNMemoryFormat input_format = input_it.format();
srcs_md.push_back(memory::desc(src_tz, platform::MKLDNNGetDataType<T>(), srcs_md.push_back(mkldnn::memory::desc(
input_format)); src_tz, platform::MKLDNNGetDataType<T>(), input_format));
++num_inputs_; ++num_inputs_;
} }
std::vector<float> scales(num_inputs_, 1.0); std::vector<float> scales(num_inputs_, 1.0);
auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(), auto dst_md = mkldnn::memory::desc(
MKLDNNMemoryFormat::any); dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md); this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
} }
...@@ -94,15 +87,15 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> { ...@@ -94,15 +87,15 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
// (jczaja) sum oneDNN prim is not having .desc attribute so // (jczaja) sum oneDNN prim is not having .desc attribute so
// we cannot use base AcquireForwardPrimitiveDescriptor // we cannot use base AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor( void AcquireForwardPrimitiveDescriptor(
const memory::desc& dst_md, const std::vector<float>& scales, const mkldnn::memory::desc& dst_md, const std::vector<float>& scales,
const std::vector<memory::desc>& srcs_md) { const std::vector<mkldnn::memory::desc>& srcs_md) {
// Sum op does not have backward so no passing from FWD to BWD is needed // Sum op does not have backward so no passing from FWD to BWD is needed
const std::string key_pd = this->key_ + "@fwd_pd"; const std::string key_pd = this->key_ + "@fwd_pd";
this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>( this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
this->dev_ctx_.GetBlob(key_pd)); this->dev_ctx_.GetBlob(key_pd));
if (this->fwd_pd_ == nullptr) { if (this->fwd_pd_ == nullptr) {
this->fwd_pd_.reset(new mkldnn::sum::primitive_desc( this->fwd_pd_.reset(new dnnl::sum::primitive_desc(dst_md, scales, srcs_md,
dst_md, scales, srcs_md, this->engine_)); this->engine_));
this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_); this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_);
} }
} }
...@@ -178,7 +171,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -178,7 +171,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto sum_p = handler.AcquireForwardPrimitive(); auto sum_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> args; std::unordered_map<int, mkldnn::memory> args;
for (size_t i = 0; i < srcs_mem.size(); ++i) { for (size_t i = 0; i < srcs_mem.size(); ++i) {
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])}); args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
} }
...@@ -215,5 +208,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -215,5 +208,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(
paddle::operators::SumMKLDNNOpKernel<float>); sum, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
paddle::operators::SumMKLDNNOpKernel<float>);
...@@ -148,16 +148,19 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -148,16 +148,19 @@ class SumOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx) && platform::CanMKLDNNBeUsed(ctx) &&
static_cast<framework::proto::VarType::Type>(dtype) == (static_cast<framework::proto::VarType::Type>(dtype) ==
framework::proto::VarType::FP32 && framework::proto::VarType::FP32 ||
static_cast<framework::proto::VarType::Type>(dtype) ==
framework::proto::VarType::BF16) &&
ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) { ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
if (std::all_of(x_vars.begin(), x_vars.end(), if (std::all_of(x_vars.begin(), x_vars.end(),
[](const framework::Variable* v) { [](const framework::Variable* v) {
return v->IsType<framework::LoDTensor>(); return v->IsType<framework::LoDTensor>();
})) { })) {
return framework::OpKernelType( return framework::OpKernelType(
framework::proto::VarType::FP32, ctx.GetPlace(), static_cast<framework::proto::VarType::Type>(dtype),
framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); ctx.GetPlace(), framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
} }
#endif #endif
...@@ -215,6 +218,11 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -215,6 +218,11 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor
of the input. If the input is LoDTensor, the output only of the input. If the input is LoDTensor, the output only
shares LoD information with the first input.)DOC"); shares LoD information with the first input.)DOC");
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.test_sum_op import TestSumOp
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle import enable_static
import numpy as np
import paddle.fluid.op as fluid_op
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestSumMKLDNN(TestSumOp):
def setUp(self):
self.op_type = "sum"
self.use_mkldnn = True
self.mkldnn_data_type = "bfloat16"
# float32 input to be use for reference
x0 = np.random.random((25, 8)).astype('float32')
x1 = np.random.random((25, 8)).astype('float32')
x2 = np.random.random((25, 8)).astype('float32')
# actual input (bf16) to bf16 sum op
x0_bf16 = convert_float_to_uint16(x0)
x1_bf16 = convert_float_to_uint16(x1)
x2_bf16 = convert_float_to_uint16(x2)
self.inputs = {"X": [("x0", x0_bf16), ("x1", x1_bf16), ("x2", x2_bf16)]}
y = x0 + x1 + x2
self.outputs = {'Out': convert_float_to_uint16(y)}
self.attrs = {'use_mkldnn': self.use_mkldnn}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
if __name__ == '__main__':
enable_static()
unittest.main()
...@@ -602,6 +602,7 @@ STATIC_MODE_TESTING_LIST = [ ...@@ -602,6 +602,7 @@ STATIC_MODE_TESTING_LIST = [
'test_requantize_mkldnn_op', 'test_requantize_mkldnn_op',
'test_softmax_mkldnn_op', 'test_softmax_mkldnn_op',
'test_sum_mkldnn_op', 'test_sum_mkldnn_op',
'test_sum_bf16_mkldnn_op',
'test_transpose_int8_mkldnn_op', 'test_transpose_int8_mkldnn_op',
'test_transpose_mkldnn_op', 'test_transpose_mkldnn_op',
'test_mkldnn_conv_activation_fuse_pass', 'test_mkldnn_conv_activation_fuse_pass',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册