Unverified Commit ca415414, authored by Jacek Czaja, committed by GitHub

[oneDNN] Sum bf16 kernel (#28382)

* - Added sum bf16 oneDNN

test=develop

* - Fix to UT of sum bf16

test=develop
Parent: 648b92c0
......@@ -2102,7 +2102,7 @@ PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "fusion_gru", "reshape2", "transpose2"});
{"concat", "conv2d", "fusion_gru", "reshape2", "transpose2", "sum"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
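Note the selection semantics in the hunk above: the built-in list (which now includes "sum") is only the default, and a non-empty user-supplied bfloat16_enabled_op_types replaces it outright rather than extending it. A minimal standalone C++ sketch of that rule follows; ResolveBf16Ops is an illustrative name of ours, not a Paddle function.

#include <iostream>
#include <string>
#include <unordered_set>

// Sketch of the pass's op-list resolution: a user-provided set overrides
// the default list entirely (override, not union).
std::unordered_set<std::string> ResolveBf16Ops(
    const std::unordered_set<std::string>& user_ops) {
  std::unordered_set<std::string> supported = {
      "concat", "conv2d", "fusion_gru", "reshape2", "transpose2", "sum"};
  if (!user_ops.empty()) supported = user_ops;
  return supported;
}

int main() {
  for (const auto& op : ResolveBf16Ops({}))
    std::cout << op << "\n";  // default list, with "sum" now included
}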
......@@ -44,6 +44,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("X", {inputs[0]});
} else if (type == "reshape2") {
op->SetInput("X", {inputs[0]});
+  } else if (type == "sum") {
+    op->SetInput("X", {inputs[0], inputs[1]});
} else {
FAIL() << "Unexpected operator type.";
}
......@@ -61,8 +63,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
-  for (auto& v : std::vector<std::string>(
-           {"a", "b", "c", "f", "g", "h", "k", "l", "m", "n", "o", "p"})) {
+  for (auto& v :
+       std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
+                                 "n", "o", "p", "r", "s"})) {
prog.MutableBlock(0)->Var(v);
}
......@@ -75,6 +78,7 @@ ProgramDesc BuildProgramDesc() {
SetOp(&prog, "concat", "concat2", {"l", "m"}, {"n"});
SetOp(&prog, "transpose2", "transpose", {"n"}, {"o"});
SetOp(&prog, "reshape2", "reshape", {"o"}, {"p"});
SetOp(&prog, "sum", "sum", {"p", "r"}, {"s"});
return prog;
}
......@@ -122,7 +126,7 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
}
TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "relu", "concat"}, 7);
MainTest({"conv2d", "pool2d", "relu", "concat", "sum"}, 8);
}
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
......@@ -130,7 +134,7 @@ TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
MainTest({"conv2d", "pool2d"}, 3);
}
- TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(5); }
+ TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(6); }
} // namespace ir
} // namespace framework
......
......@@ -40,13 +40,6 @@ class MKLDNNDeviceContext;
namespace paddle {
namespace operators {
using framework::DataLayout;
- using mkldnn::memory;
- using mkldnn::primitive;
- using mkldnn::reorder;
- using mkldnn::stream;
- using mkldnn::sum;
- using paddle::framework::Tensor;
- using paddle::platform::CPUDeviceContext;
using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;
......@@ -71,21 +64,21 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
auto dst_tz = framework::vectorize<int64_t>(z->dims());
auto src_tz = dst_tz;
-    std::vector<memory::desc> srcs_md;
+    std::vector<mkldnn::memory::desc> srcs_md;
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
continue;
}
MKLDNNMemoryFormat input_format = input_it.format();
-      srcs_md.push_back(memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
-                                     input_format));
+      srcs_md.push_back(mkldnn::memory::desc(
+          src_tz, platform::MKLDNNGetDataType<T>(), input_format));
++num_inputs_;
}
std::vector<float> scales(num_inputs_, 1.0);
-    auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
-                               MKLDNNMemoryFormat::any);
+    auto dst_md = mkldnn::memory::desc(
+        dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
}
......@@ -94,15 +87,15 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
// (jczaja) The oneDNN sum primitive has no .desc attribute, so
// we cannot use the base AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor(
-      const memory::desc& dst_md, const std::vector<float>& scales,
-      const std::vector<memory::desc>& srcs_md) {
+      const mkldnn::memory::desc& dst_md, const std::vector<float>& scales,
+      const std::vector<mkldnn::memory::desc>& srcs_md) {
// The sum op has no backward pass, so nothing is passed from FWD to BWD
const std::string key_pd = this->key_ + "@fwd_pd";
this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
this->dev_ctx_.GetBlob(key_pd));
if (this->fwd_pd_ == nullptr) {
-      this->fwd_pd_.reset(new mkldnn::sum::primitive_desc(
-          dst_md, scales, srcs_md, this->engine_));
+      this->fwd_pd_.reset(new dnnl::sum::primitive_desc(dst_md, scales, srcs_md,
+                                                        this->engine_));
this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_);
}
}
......@@ -178,7 +171,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto sum_p = handler.AcquireForwardPrimitive();
-    std::unordered_map<int, memory> args;
+    std::unordered_map<int, mkldnn::memory> args;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
}
......@@ -215,5 +208,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
- REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace,
-                    paddle::operators::SumMKLDNNOpKernel<float>);
+ REGISTER_OP_KERNEL(
+     sum, MKLDNN, ::paddle::platform::CPUPlace,
+     paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
+     paddle::operators::SumMKLDNNOpKernel<float>);
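For readers unfamiliar with the API being wrapped: SumMKLDNNHandler ultimately builds a dnnl::sum primitive from N source descriptors, per-source scales, and one destination descriptor. Below is a minimal standalone sketch against the oneDNN 1.x API (shapes and values are illustrative, not from the patch); scales of 1.0, as in the kernel, give a plain element-wise sum.

#include <iostream>
#include <unordered_map>
#include <vector>

#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // Two 2x3 f32 sources and one destination, all in plain row-major layout.
  dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::ab);
  std::vector<float> a(6, 1.f), b(6, 2.f), out(6, 0.f);
  dnnl::memory src0(md, eng, a.data());
  dnnl::memory src1(md, eng, b.data());
  dnnl::memory dst(md, eng, out.data());

  // Scales of 1.0 reproduce the plain element-wise sum used by the kernel.
  dnnl::sum::primitive_desc pd(md, {1.f, 1.f}, {md, md}, eng);
  dnnl::sum(pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                               {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                               {DNNL_ARG_DST, dst}});
  strm.wait();

  std::cout << out[0] << std::endl;  // 3
}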
......@@ -148,16 +148,19 @@ class SumOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx) &&
-        static_cast<framework::proto::VarType::Type>(dtype) ==
-            framework::proto::VarType::FP32 &&
+        (static_cast<framework::proto::VarType::Type>(dtype) ==
+             framework::proto::VarType::FP32 ||
+         static_cast<framework::proto::VarType::Type>(dtype) ==
+             framework::proto::VarType::BF16) &&
ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
if (std::all_of(x_vars.begin(), x_vars.end(),
[](const framework::Variable* v) {
return v->IsType<framework::LoDTensor>();
})) {
return framework::OpKernelType(
-            framework::proto::VarType::FP32, ctx.GetPlace(),
-            framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN);
+            static_cast<framework::proto::VarType::Type>(dtype),
+            ctx.GetPlace(), framework::DataLayout::kMKLDNN,
+            framework::LibraryType::kMKLDNN);
}
}
#endif
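The hunk above widens the oneDNN dispatch gate from FP32-only to FP32-or-BF16, and makes the returned OpKernelType carry the actual input dtype instead of hard-coded FP32. A standalone sketch of the widened gate, where the DType enum and function name are illustrative stand-ins rather than Paddle types:

#include <iostream>

// Stand-ins for framework::proto::VarType::Type values.
enum class DType { FP32, BF16, INT32 };

// FP32 and BF16 now both take the oneDNN sum path; other dtypes fall back.
bool TakesOneDnnSumPath(DType dtype) {
  return dtype == DType::FP32 || dtype == DType::BF16;
}

int main() {
  std::cout << TakesOneDnnSumPath(DType::BF16) << "\n";   // 1 (new in this PR)
  std::cout << TakesOneDnnSumPath(DType::INT32) << "\n";  // 0 (plain kernel)
}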
......@@ -215,6 +218,11 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
+    AddAttr<std::string>(
+        "mkldnn_data_type",
+        "(string, default \"float32\"). Data type of mkldnn kernel")
+        .SetDefault("float32")
+        .InEnum({"float32", "bfloat16"});
AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor
of the input. If the input is LoDTensor, the output only
shares LoD information with the first input.)DOC");
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.test_sum_op import TestSumOp
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle import enable_static
import numpy as np
import paddle.fluid.op as fluid_op


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestSumMKLDNN(TestSumOp):
    def setUp(self):
        self.op_type = "sum"
        self.use_mkldnn = True
        self.mkldnn_data_type = "bfloat16"

        # float32 inputs used as the reference for the expected output
        x0 = np.random.random((25, 8)).astype('float32')
        x1 = np.random.random((25, 8)).astype('float32')
        x2 = np.random.random((25, 8)).astype('float32')

        # actual (bf16) inputs to the bf16 sum op
        x0_bf16 = convert_float_to_uint16(x0)
        x1_bf16 = convert_float_to_uint16(x1)
        x2_bf16 = convert_float_to_uint16(x2)
        self.inputs = {"X": [("x0", x0_bf16), ("x1", x1_bf16), ("x2", x2_bf16)]}

        y = x0 + x1 + x2
        self.outputs = {'Out': convert_float_to_uint16(y)}
        self.attrs = {'use_mkldnn': self.use_mkldnn}

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())

    def test_check_grad(self):
        pass


if __name__ == '__main__':
    enable_static()
    unittest.main()
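The test builds its bf16 inputs with convert_float_to_uint16 because bf16 tensors are carried as uint16 arrays on the Python side. Conceptually, bfloat16 is the upper 16 bits of an IEEE-754 float32, so the round trip can be sketched as below; this uses simple truncation, whereas Paddle's helper may round to nearest even, and both function names are ours.

#include <cstdint>
#include <cstring>
#include <iostream>

// bf16 keeps float32's sign and exponent plus 7 mantissa bits, so a float
// can be narrowed by dropping the low 16 bits (rounding omitted here).
uint16_t FloatToBf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

float Bf16ToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  std::cout << Bf16ToFloat(FloatToBf16(1.5f)) << "\n";  // 1.5 exactly
}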
......@@ -602,6 +602,7 @@ STATIC_MODE_TESTING_LIST = [
'test_requantize_mkldnn_op',
'test_softmax_mkldnn_op',
'test_sum_mkldnn_op',
+    'test_sum_bf16_mkldnn_op',
'test_transpose_int8_mkldnn_op',
'test_transpose_mkldnn_op',
'test_mkldnn_conv_activation_fuse_pass',
......