Unverified commit a2a45d8d, authored by jakpiase, committed by GitHub

Added cast op oneDNN kernel for bf16/fp32 data type casting (FWD/BWD) (#33056)

* added cast op functionality for fp32/bf16

* added newline

* added entries in static mode white list and unity build

* fixed failing tests

* changes after review

* added formatting

* updated test file as reviewer suggested

* changes after review

* minor change
Parent 009ff61b
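The kernel added below casts between fp32 and bf16 by issuing a single oneDNN reorder whose source and destination memory descriptors differ only in data type. For orientation, here is a minimal NumPy sketch (illustrative only, not Paddle code) of the value mapping involved: bfloat16 is the upper 16 bits of an IEEE-754 fp32 word, which is also why the tests below store bf16 tensors as uint16. The sketch assumes plain truncation; a real reorder may round to nearest instead.

import numpy as np

def fp32_to_bf16(x):
    # Drop the low 16 mantissa bits of each fp32 value (truncation).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_to_fp32(x):
    # Widen bf16 back to fp32 by shifting into the high 16 bits.
    return (x.astype(np.uint32) << 16).view(np.float32)

a = np.random.random([10, 10]).astype(np.float32)
assert np.allclose(a, bf16_to_fp32(fp32_to_bf16(a)), atol=0.03)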
paddle/fluid/operators/cast_op.cc
@@ -27,6 +27,9 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of cast op");
AddAttr<int>("out_dtype", "output data type");
AddAttr<int>("in_dtype", "input data type");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Cast Operator.
@@ -50,6 +53,7 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
grad->SetOutput("Out", this->InputGrad("X"));
grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
grad->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
}
};
@@ -77,6 +81,28 @@ class CastOp : public framework::OperatorWithKernel {
if (platform::is_cuda_pinned_place(tensor_place)) {
return framework::OpKernelType(tensor->type(), ctx.device_context());
}
#ifdef PADDLE_WITH_MKLDNN
int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");
auto MKLDNNSupportsCast = [&]() -> bool {
int dtype_fp32 = static_cast<int>(framework::proto::VarType::FP32);
int dtype_bf16 = static_cast<int>(framework::proto::VarType::BF16);
if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) ||
(out_dtype != dtype_fp32 && out_dtype != dtype_bf16))
return false;
return true;
};
if (this->CanMKLDNNBeUsed(ctx, tensor->type()) && MKLDNNSupportsCast()) {
return framework::OpKernelType(tensor->type(), ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}
};
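Restating the gate above: the oneDNN kernel is selected only when both in_dtype and out_dtype are FP32 or BF16 and CanMKLDNNBeUsed() holds; every other dtype pair falls through to the default kernel returned on the last line. The same predicate, condensed into Python for readability (an illustrative sketch, not generated code):

SUPPORTED = {'fp32', 'bf16'}

def mkldnn_supports_cast(in_dtype, out_dtype):
    # Mirrors the MKLDNNSupportsCast lambda in CastOp::GetExpectedKernelType.
    return in_dtype in SUPPORTED and out_dtype in SUPPORTED

assert mkldnn_supports_cast('fp32', 'bf16')
assert not mkldnn_supports_cast('fp32', 'int32')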
paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc (new file)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
template <typename T>
class CastMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");
auto x_paddle_type = framework::proto::VarType::Type(in_dtype);
auto out_paddle_type = framework::proto::VarType::Type(out_dtype);
mkldnn::memory::data_type x_type =
framework::ToMKLDNNDataType(x_paddle_type);
mkldnn::memory::data_type out_type =
framework::ToMKLDNNDataType(out_paddle_type);
auto x_tz = framework::vectorize(x->dims());
std::string key =
platform::CreateKey(dev_ctx, x_tz, x->format(), x->format(), x_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_tz, x_paddle_type, x_type, out_paddle_type, out_type, dev_ctx,
dev_ctx.GetEngine(), key);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->format(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x->format(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(cast, MKLDNN, paddle::platform::CPUPlace,
ops::CastMKLDNNKernel<float>,
ops::CastMKLDNNKernel<paddle::platform::bfloat16>);
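For completeness, a hypothetical static-graph snippet that would reach this kernel. Nothing below comes from the PR; it assumes the global FLAGS_use_mkldnn flag, read when paddle initializes, routes cast through oneDNN the same way it does for other mkldnn-enabled ops:

import os
os.environ['FLAGS_use_mkldnn'] = '1'  # must be set before paddle starts

import numpy as np
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[10, 10], dtype='float32')
y = paddle.cast(x, 'float32')  # fp32 -> fp32 is also admitted by the dispatch above

exe = paddle.static.Executor(paddle.CPUPlace())
out, = exe.run(feed={'x': np.random.random([10, 10]).astype('float32')},
               fetch_list=[y])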
paddle/fluid/operators/unity_build_rule.cmake
@@ -30,6 +30,7 @@ register_unity_group(cc
bmm_op.cc
bpr_loss_op.cc
cast_op.cc
mkldnn/cast_mkldnn_op.cc
cholesky_op.cc
chunk_eval_op.cc
clip_by_norm_op.cc
paddle/fluid/platform/mkldnn_reuse.h
@@ -926,7 +926,23 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype),
-        dtype_(dtype) {}
+        vtype_dst_(vtype),
+        dtype_(dtype),
+        dtype_dst_(dtype) {}
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype,
framework::proto::VarType::Type vtype_dst,
mkldnn::memory::data_type dtype_dst,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
vtype_(vtype),
vtype_dst_(vtype_dst),
dtype_(dtype),
dtype_dst_(dtype_dst) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const MKLDNNMemoryFormat& fmt, void* ptr) {
@@ -940,15 +956,16 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
-      auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
-      auto dst_data = output->mutable_data(place, vtype_, dst_md.get_size());
+      auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
+      auto dst_data =
+          output->mutable_data(place, vtype_dst_, dst_md.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
// Even if the memory object exists, we may be using it for a different tensor
auto dst_data =
-          output->mutable_data(place, vtype_, mem_p->get_desc().get_size());
+          output->mutable_data(place, vtype_dst_, mem_p->get_desc().get_size());
mem_p->set_data_handle(dst_data);
}
return mem_p;
@@ -970,8 +987,8 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
private:
std::vector<int64_t> dims_;
-  framework::proto::VarType::Type vtype_;
-  mkldnn::memory::data_type dtype_;
+  framework::proto::VarType::Type vtype_, vtype_dst_;
+  mkldnn::memory::data_type dtype_, dtype_dst_;
};
template <typename T>
python/paddle/fluid/tests/unittests/mkldnn/test_cast_mkldnn_op.py (new file)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestCastBF16ToFP32MKLDNNOp(OpTest):
def init_data(self):
self.out = np.random.random(size=[10, 10]).astype("float32")
self.x = convert_float_to_uint16(self.out)
def setUp(self):
self.init_data()
self.inputs = {'X': self.x}
self.outputs = {'Out': self.out}
prepare_dtype = lambda x: int(core.VarDesc.VarType.BF16 if x.dtype != np.float32 else core.VarDesc.VarType.FP32)
self.attrs = {
'in_dtype': prepare_dtype(self.x),
'out_dtype': prepare_dtype(self.out),
'use_mkldnn': True
}
self.op_type = 'cast'
def test_check_output(self):
self.check_output(check_dygraph=False)
def test_check_grad(self):
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
check_dygraph=False,
user_defined_grads=[self.inputs['X']],
user_defined_grad_outputs=[self.outputs['Out']])
class TestCastFP32ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
def init_data(self):
self.x = np.random.random(size=[2, 6]).astype("float32")
self.out = convert_float_to_uint16(self.x)
class TestCastBF16ToBF16MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
def init_data(self):
self.x = np.random.random(size=[6, 13]).astype("uint16")
self.out = self.x
class TestCastFP32ToFP32MKLDNNOp(TestCastBF16ToFP32MKLDNNOp):
def init_data(self):
self.x = np.random.random(size=[7, 15]).astype("float32")
self.out = self.x
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
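A note on the user_defined_grads above: CastOpGradMaker swaps in_dtype and out_dtype, so the backward pass is itself a cast. With Out fed back as the output gradient, the expected input gradient is cast(Out), which matches X whenever X has already survived one bf16 round trip; values truncated to bf16 once are fixed points of further round trips. A quick NumPy check using the op_test helpers (illustrative):

import numpy as np
from paddle.fluid.tests.unittests.op_test import (
    convert_float_to_uint16, convert_uint16_to_float)

x = np.random.random([10, 10]).astype(np.float32)
once = convert_uint16_to_float(convert_float_to_uint16(x))
twice = convert_uint16_to_float(convert_float_to_uint16(once))
assert not (once == x).all()  # first round trip loses low mantissa bits (almost surely)
assert (twice == once).all()  # second round trip is exact: bf16 values are fixed points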
python/paddle/fluid/tests/unittests/op_test.py
@@ -1191,8 +1191,12 @@ class OpTest(unittest.TestCase):
np.float32, np.float64
]:
actual_t = convert_uint16_to_float(actual_t)
-            atol = 0.03
+            atol = max(atol, 0.03)
+        if expect_t.dtype == np.uint16 and actual_t.dtype == np.uint16:
+            expect_t = convert_uint16_to_float(expect_t)
+            actual_t = convert_uint16_to_float(actual_t)
+            atol = max(atol, 0.03)
# NOTE(zhiqiu): np.allclose([], [1.]) returns True
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_t.size == 0:
@@ -1501,13 +1505,21 @@ class OpTest(unittest.TestCase):
# comparison of bf16 results will happen as fp32
# loop over list of grads and convert bf16 to fp32
-            fp32_grads = []
+            fp32_analytic_grads = []
             for grad in analytic_grads:
                 if grad.dtype == np.uint16:
                     grad = convert_uint16_to_float(grad)
                     max_relative_error = 0.03
-                fp32_grads.append(grad)
-            analytic_grads = fp32_grads
+                fp32_analytic_grads.append(grad)
+            analytic_grads = fp32_analytic_grads
+
+            fp32_numeric_grads = []
+            for grad in numeric_grads:
+                if grad.dtype == np.uint16:
+                    grad = convert_uint16_to_float(grad)
+                    max_relative_error = 0.03
+                fp32_numeric_grads.append(grad)
+            numeric_grads = fp32_numeric_grads
self._assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
max_relative_error,
......
tools/static_mode_white_list.py
@@ -589,6 +589,7 @@ STATIC_MODE_TESTING_LIST = [
'test_matmul_op_with_head',
'test_var_conv_2d',
'test_batch_norm_mkldnn_op',
'test_cast_mkldnn_op',
'test_concat_int8_mkldnn_op',
'test_concat_bf16_mkldnn_op',
'test_concat_mkldnn_op',