未验证 提交 375e5618 编写于 作者: J jakpiase 提交者: GitHub

Added PRelu BF16/FP32 FWD/BWD kernels (#33878)

* added prelu bf16/fp32 fwd/bwd kernel
上级 a0666b9d
......@@ -2262,11 +2262,26 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
"elementwise_mul", "fc", "fusion_gru", "fusion_lstm", "gelu",
"layer_norm", "matmul", "matmul_v2", "pool2d", "relu", "reshape2",
"softmax", "split", "sum", "transpose2"});
std::unordered_set<std::string>({"concat",
"conv2d",
"conv2d_transpose",
"elementwise_add",
"elementwise_mul",
"fc",
"fusion_gru",
"fusion_lstm",
"gelu",
"layer_norm",
"matmul",
"matmul_v2",
"pool2d",
"prelu",
"relu",
"reshape2",
"softmax",
"split",
"sum",
"transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using dnnl::memory;
using framework::Tensor;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
namespace {
// oneDNN handler for PReLU. It creates the forward (and optionally backward)
// primitive descriptors through the MKLDNNHandlerT base class and provides
// helpers that wrap the Alpha tensor and its gradient in oneDNN memory objects.
template <typename T>
class PReluMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, dnnl::prelu_forward,
                                      dnnl::prelu_backward> {
 public:
  // The base class is given a key built from `dev_ctx`, the dims of `x` and
  // `uniq_name`. If isCached() already reports true for that key nothing else
  // happens here; otherwise the descriptors are created:
  //   * forward  - always requested with prop_kind::forward_training;
  //   * backward - only when `is_test` is false.
  // `mode` is only consulted when the rank of `weights` differs from the rank
  // of `x` (see below).
  PReluMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
                     const mkldnn::engine engine, platform::Place cpu_place,
                     const Tensor* x, const Tensor* weights,
                     const std::string& uniq_name, const std::string& mode,
                     bool is_test = false)
      : platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
            dev_ctx, engine, cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
                                uniq_name)) {
    if (this->isCached()) {
      return;
    }

    auto src_md = memory::desc(framework::vectorize(x->dims()),
                               MKLDNNGetDataType<T>(), x->format());

    auto alpha_dims = framework::vectorize(weights->dims());
    // Only the "element" case is expected to arrive with weights of the same
    // rank as X. For any other rank, describe the weights with X's rank, using
    // 1 for every dimension except (in "channel" mode) dimension 1, which
    // receives the largest of the original weight dimensions.
    if (weights->dims().size() != x->dims().size()) {
      std::vector<int64_t> broadcast_dims(x->dims().size(), 1);
      if (mode == "channel") {
        broadcast_dims[1] =
            *std::max_element(alpha_dims.begin(), alpha_dims.end());
      }
      alpha_dims = std::move(broadcast_dims);
    }

    // format_tag::any leaves the choice of the weights layout to oneDNN.
    auto alpha_md = memory::desc(alpha_dims, MKLDNNGetDataType<T>(),
                                 memory::format_tag::any);

    this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
                                            src_md, alpha_md);
    if (!is_test) {
      this->AcquireBackwardPrimitiveDescriptor(src_md, alpha_md, src_md,
                                               alpha_md);
    }
  }

  // Returns a memory object for the Alpha tensor `input` laid out as the
  // forward primitive descriptor's weights_desc(). Tensors of rank other than
  // one go through AcquireMemoryWithReorder (with `is_test` forwarded);
  // rank-one tensors are wrapped directly.
  std::shared_ptr<memory> AcquireWeightsMemoryPossiblyWithReorder(
      const Tensor* input, const bool is_test) {
    const T* alpha_data = input->data<T>();

    if (input->dims().size() != 1) {
      auto user_alpha_md =
          memory::desc(framework::vectorize(input->dims()),
                       MKLDNNGetDataType<T>(), input->format());
      return this->AcquireMemoryWithReorder(
          user_alpha_md, this->fwd_pd_->weights_desc(),
          to_void_cast<T>(alpha_data), "@alpha_mem_p", is_test);
    }

    // For 1D weights every format tag is correct, so whatever format_tag::any
    // resolved to is accepted and no reorder is needed.
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
                                            to_void_cast<T>(alpha_data),
                                            "@alpha_mem_p");
  }

  // Allocates `output` with the byte size of the backward descriptor's
  // diff_weights_desc() and returns a memory object over that buffer.
  std::shared_ptr<memory> AcquireDiffWeightsMemory(Tensor* output) {
    T* diff_alpha_data = output->mutable_data<T>(
        this->place_, this->bwd_pd_->diff_weights_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(),
                                            diff_alpha_data,
                                            "@diff_weights_mem_p");
  }
};
} // anonymous namespace
// Forward PReLU kernel executed through oneDNN.
template <typename T>
class PReluMKLDNNKernel : public framework::OpKernel<T> {
 public:
  // Framework entry point; all work is delegated to RunKernel().
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx);
  }

  // Reads inputs "X" and "Alpha" plus attributes "is_test" and "mode", runs
  // the forward PReLU primitive into output "Out", and tags "Out" with the
  // kMKLDNN layout and the format of the destination memory.
  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    const auto* input = ctx.Input<Tensor>("X");
    const auto* slope = ctx.Input<Tensor>("Alpha");
    auto* output = ctx.Output<Tensor>("Out");
    const bool is_test = ctx.Attr<bool>("is_test");
    const auto mode = ctx.Attr<std::string>("mode");

    // The name of the "X" variable is used as the unique part of the key.
    PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(),
                                  input, slope, ctx.InputName("X"), mode,
                                  is_test);

    auto src_mem = handler.AcquireSrcMemory(input);
    auto alpha_mem =
        handler.AcquireWeightsMemoryPossiblyWithReorder(slope, is_test);
    auto dst_mem = handler.AcquireDstMemory(output);
    auto forward_prim = handler.AcquireForwardPrimitive();

    auto& astream = MKLDNNDeviceContext::tls().get_stream();
    forward_prim->execute(astream, {{DNNL_ARG_SRC, *src_mem},
                                    {DNNL_ARG_WEIGHTS, *alpha_mem},
                                    {DNNL_ARG_DST, *dst_mem}});
    // Wait on the stream before publishing the output's layout and format.
    astream.wait();

    output->set_layout(framework::DataLayout::kMKLDNN);
    output->set_format(GetMKLDNNFormat(*dst_mem));
  }
};
// Backward PReLU kernel executed through oneDNN.
template <typename T>
class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
 public:
  // Framework entry point; all work is delegated to RunKernel().
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx);
  }

  // Reads "X", "Alpha" and the gradient of "Out", runs the backward PReLU
  // primitive, and writes the gradients of "X" and "Alpha". Only the gradient
  // of "X" gets its layout/format updated afterwards.
  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* input = ctx.Input<Tensor>("X");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* slope_grad = ctx.Output<Tensor>(framework::GradVarName("Alpha"));
    auto* slope = ctx.Input<Tensor>("Alpha");
    const bool is_test = ctx.Attr<bool>("is_test");
    const auto mode = ctx.Attr<std::string>("mode");

    // The handler is built with its default is_test (false), so the backward
    // primitive descriptor is requested as well.
    // NOTE(review): the unique name passed here is the constant
    // GradVarName("X") rather than a variable name (the forward kernel passes
    // ctx.InputName("X")); presumably different prelu_grad ops with equal X
    // dims then map to the same key - confirm that this is intended.
    PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(),
                                  input, slope, framework::GradVarName("X"),
                                  mode);

    auto src_mem = handler.AcquireSrcMemory(input);
    auto alpha_mem =
        handler.AcquireWeightsMemoryPossiblyWithReorder(slope, is_test);
    auto diff_src_mem = handler.AcquireDiffSrcMemory(input_grad);
    auto diff_alpha_mem = handler.AcquireDiffWeightsMemory(slope_grad);
    auto diff_dst_mem = handler.AcquireDiffDstMemory(output_grad);
    auto backward_prim = handler.AcquireBackwardPrimitive();

    auto& astream = MKLDNNDeviceContext::tls().get_stream();
    backward_prim->execute(astream,
                           {{DNNL_ARG_SRC, *src_mem},
                            {DNNL_ARG_WEIGHTS, *alpha_mem},
                            {DNNL_ARG_DIFF_DST, *diff_dst_mem},
                            {DNNL_ARG_DIFF_SRC, *diff_src_mem},
                            {DNNL_ARG_DIFF_WEIGHTS, *diff_alpha_mem}});
    // Wait on the stream before publishing the gradient's layout and format.
    astream.wait();

    input_grad->set_layout(framework::DataLayout::kMKLDNN);
    input_grad->set_format(GetMKLDNNFormat(*diff_src_mem));
  }
};
} // namespace operators
} // namespace paddle
// Shorthand for the operators namespace used by the registrations below.
namespace ops = paddle::operators;
// Register the MKLDNN forward PReLU kernel on CPUPlace for float and bfloat16.
REGISTER_OP_KERNEL(prelu, MKLDNN, paddle::platform::CPUPlace,
ops::PReluMKLDNNKernel<float>,
ops::PReluMKLDNNKernel<paddle::platform::bfloat16>);
// Register the MKLDNN backward PReLU kernel on CPUPlace for float and bfloat16.
REGISTER_OP_KERNEL(prelu_grad, MKLDNN, paddle::platform::CPUPlace,
ops::PReluGradMKLDNNKernel<float>,
ops::PReluGradMKLDNNKernel<paddle::platform::bfloat16>);
......@@ -95,9 +95,17 @@ class PReluOp : public framework::OperatorWithKernel {
protected:
// Chooses the kernel type for prelu from the data type of input "X".
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the MKLDNN layout and library are requested; otherwise a plain
// kernel type for the current place is returned.
//
// The previous unconditional `return framework::OpKernelType(...,
// ctx.device_context());` that preceded this logic has been removed: it made
// the MKLDNN branch unreachable.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -126,6 +134,18 @@ There are modes:
)DOC");
AddAttr<std::string>("mode", "The mode for inputs to share weights.")
.SetDefault("all");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
}
};
......@@ -153,9 +173,17 @@ class PReluGradOp : public framework::OperatorWithKernel {
protected:
// Chooses the kernel type for prelu_grad from the data type of input "X".
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the MKLDNN layout and library are requested; otherwise a plain
// kernel type for the current place is returned.
//
// The previous unconditional `return framework::OpKernelType(...,
// ctx.device_context());` that preceded this logic has been removed: it made
// the MKLDNN branch unreachable.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
def ref_prelu(x, weight, mode):
    """NumPy reference for PReLU: keep positive values, scale the rest.

    Args:
        x: input array.
        weight: slope array; how it is indexed depends on ``mode``.
        mode: "all" uses ``weight[0]`` for every element, "channel" uses one
            slope per index along axis 1 (``weight[0, i]`` when ``weight`` has
            more than one dimension, ``weight[i]`` otherwise), "element"
            multiplies by ``weight`` as a whole.

    Returns:
        The activated array. For any other ``mode`` a plain copy of ``x``.
    """
    if mode == "all":
        return np.where(x > 0, x, x * weight[0])
    if mode == "element":
        return np.where(x[:] > 0, x[:], x[:] * weight)

    out = x.copy()
    if mode == "channel":
        multi_dim_weight = len(weight.shape) > 1
        for ch in range(x.shape[1]):
            slope = weight[0, ch] if multi_dim_weight else weight[ch]
            plane = x[:, ch]
            out[:, ch] = np.where(plane > 0, plane, plane * slope)
    return out
class TestPReluModeChannelOneDNNOp(OpTest):
    """Base oneDNN prelu test; subclasses customise it via ``init_attrs``.

    NOTE(review): despite the "Channel" in its name this class configures
    mode "element" - presumably a leftover; confirm before renaming or
    changing it, since the other test classes derive from it.
    """

    def init_attrs(self):
        """Hook: choose the prelu mode and the Alpha tensor."""
        self.mode = "element"
        alpha_shape = (1, 4, 5, 5)
        self.alpha = np.random.random(alpha_shape).astype("float32")

    def set_dtype_attr(self):
        """Hook: adjust ``self.attrs`` for a data type; nothing to do here."""

    def set_inputs(self):
        """Hook: build the op inputs from ``self.x`` and ``self.alpha``."""
        self.inputs = dict(X=self.x, Alpha=self.alpha)

    def setUp(self):
        self.op_type = "prelu"
        # Default X; init_attrs() of a subclass may replace it afterwards.
        x_shape = (2, 4, 5, 5)
        self.x = np.random.random(x_shape).astype("float32") + 1
        self.init_attrs()
        self.set_inputs()
        # attrs must exist before set_dtype_attr(), which may add entries.
        self.attrs = dict(mode=self.mode, use_mkldnn=True)
        self.set_dtype_attr()
        expected = ref_prelu(self.x, self.alpha, self.mode)
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        checked_inputs = ['X', 'Alpha']
        self.check_grad(checked_inputs, 'Out')
class TestPReluModeAllOneDNNOp(TestPReluModeChannelOneDNNOp):
    """oneDNN prelu test for mode "all" with a (1, 1, 1, 1) Alpha."""

    def init_attrs(self):
        self.mode = "all"
        alpha_shape = (1, 1, 1, 1)
        self.alpha = np.random.random(alpha_shape).astype("float32")

    def test_check_grad(self):
        # 'Alpha' is left out of the gradient check: in mode "all" it has to
        # be a single value, so a check that expects at least 100 values
        # would raise an error.
        checked_inputs = ['X']
        self.check_grad(checked_inputs, 'Out')
class TestPReluModeElementOneDNNOp(TestPReluModeChannelOneDNNOp):
    """oneDNN prelu test for mode "element" with a (1, 4, 5, 5) Alpha."""

    def init_attrs(self):
        self.mode = "element"
        alpha_shape = (1, 4, 5, 5)
        self.alpha = np.random.random(alpha_shape).astype("float32")
class TestPReluModeChannel3DOneDNNOp(TestPReluModeChannelOneDNNOp):
    """oneDNN prelu test for mode "channel" with 3-D X and 3-D Alpha."""

    def init_attrs(self):
        self.mode = "channel"
        # X and Alpha share the same (1, 100, 1) shape; X replaces the
        # default set in setUp().
        shape = (1, 100, 1)
        self.x = np.random.random(shape).astype("float32")
        self.alpha = np.random.random(shape).astype("float32")
class TestPReluModeChannelAlpha1DOneDNNOp(TestPReluModeChannelOneDNNOp):
    """oneDNN prelu test for mode "channel" with 3-D X and 1-D Alpha."""

    def init_attrs(self):
        self.mode = "channel"
        # X replaces the default set in setUp(); Alpha holds 100 values in a
        # single dimension.
        x_shape = (1, 100, 1)
        self.x = np.random.random(x_shape).astype("float32")
        self.alpha = np.random.random((100)).astype("float32")
class TestPReluModeAllAlpha1DOneDNNOp(TestPReluModeAllOneDNNOp):
    """oneDNN prelu test for mode "all" with 3-D X and a 1-D, one-value Alpha.

    The mode used to be "channel", which contradicted both the class name and
    the "all"-mode parent it derives from (whose gradient check skips Alpha
    because it is a single value). With "all" this case covers a 1-D Alpha in
    the mode the name promises.
    """

    def init_attrs(self):
        self.mode = "all"
        # X replaces the default set in setUp().
        self.x = np.random.random((1, 1, 100)).astype("float32")
        self.alpha = np.random.random((1)).astype("float32")
# BF16 TESTS
def create_bf16_test_class(parent):
    """Create a bfloat16 variant of the test class ``parent``.

    The generated class is renamed to ``<parent.__name__>_BF16`` and stored
    under that name in this module's globals. Nothing is returned.
    """

    class TestPReluBF16OneDNNOp(parent):
        def set_inputs(self):
            # Inputs are passed through convert_float_to_uint16.
            self.inputs = {
                'X': convert_float_to_uint16(self.x),
                'Alpha': convert_float_to_uint16(self.alpha)
            }

        def set_dtype_attr(self):
            self.attrs['mkldnn_data_type'] = "bfloat16"

        def calculate_grads(self):
            """Compute ``self.dx``, ``self.dalpha`` and ``self.dout``.

            The expected output ``self.outputs['Out']`` is used as the
            upstream gradient ``dout``.
            """
            dout = self.outputs['Out']

            self.dx = self.x.copy()
            self.dalpha = self.alpha.copy()

            if self.mode == "all":
                self.dx = np.where(self.x > 0, dout, dout * self.alpha[0])
            elif self.mode == "channel":
                if len(self.alpha.shape) > 1:
                    for i in range(self.x.shape[1]):
                        self.dx[:, i] = np.where(self.x[:, i] > 0, dout[:, i],
                                                 dout[:, i] * self.alpha[0, i])
                else:
                    for i in range(self.x.shape[1]):
                        self.dx[:, i] = np.where(self.x[:, i] > 0, dout[:, i],
                                                 dout[:, i] * self.alpha[i])
            elif self.mode == "element":
                self.dx = np.where(self.x[:] > 0, dout[:], dout[:] * self.alpha)

            # NOTE(review): this value has the broadcast shape of x and dout,
            # not necessarily the shape of alpha; presumably it needs to be
            # reduced to alpha's shape - confirm before re-enabling the tests.
            self.dalpha = np.where(self.x < 0, dout * self.x, 0)
            self.dout = dout

        def test_check_output(self):
            if core.is_compiled_with_cuda():
                self.skipTest(
                    "OneDNN doesn't support bf16 with CUDA, skipping UT" +
                    self.__class__.__name__)
            elif not core.supports_bfloat16():
                self.skipTest("Core doesn't support bf16, skipping UT" +
                              self.__class__.__name__)
            else:
                self.check_output_with_place(core.CPUPlace())

        def test_check_grad(self):
            if core.is_compiled_with_cuda() or not core.supports_bfloat16():
                self.skipTest(
                    "Core is compiled with cuda or doesn't support bf16, skipping UT"
                    + self.__class__.__name__)
            else:
                self.calculate_grads()
                self.check_grad_with_place(
                    core.CPUPlace(), ["X", "Alpha"],
                    "Out",
                    user_defined_grads=[self.dx, self.dalpha],
                    user_defined_grad_outputs=[
                        convert_float_to_uint16(self.dout)
                    ])

    cls_name = "{0}_{1}".format(parent.__name__, "BF16")
    TestPReluBF16OneDNNOp.__name__ = cls_name
    globals()[cls_name] = TestPReluBF16OneDNNOp
# TODO(jakpiase): re-enable the bf16 tests below once the oneDNN bf16 class
# is ready.
#create_bf16_test_class(TestPReluModeChannelOneDNNOp)
#create_bf16_test_class(TestPReluModeElementOneDNNOp)
#create_bf16_test_class(TestPReluModeChannel3DOneDNNOp)
#create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -360,7 +360,9 @@ class OpTest(unittest.TestCase):
def is_bfloat16_op(self):
    """Return True when this test case targets a bfloat16 op.

    That is the case when ``self.dtype`` equals ``np.uint16``, when the
    instance has a ``mkldnn_data_type`` attribute equal to "bfloat16", or when
    ``self.attrs`` contains ``'mkldnn_data_type'`` set to 'bfloat16'.
    """
    # Strings are compared with ==; the former `is "bfloat16"` tested object
    # identity, which is not guaranteed to hold for equal strings.
    return self.dtype == np.uint16 or (
        hasattr(self, 'mkldnn_data_type') and
        getattr(self, 'mkldnn_data_type') == "bfloat16") or (
            hasattr(self, 'attrs') and 'mkldnn_data_type' in self.attrs and
            self.attrs['mkldnn_data_type'] == 'bfloat16')
def infer_dtype_from_inputs_outputs(self, inputs, outputs):
def is_np_data(input):
......@@ -1436,6 +1438,9 @@ class OpTest(unittest.TestCase):
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
if self.is_bfloat16_op():
check_dygraph = False
self._check_grad_helper()
if self.dtype == np.float64 and \
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST:
......
......@@ -390,6 +390,7 @@ STATIC_MODE_TESTING_LIST = [
'test_positive_negative_pair_op',
'test_precision_recall_op',
'test_prelu_op',
'test_prelu_mkldnn_op',
'test_print_op',
'test_prior_box_op',
'test_profiler',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册