未验证 提交 4e233712 编写于 作者: J jakpiase 提交者: GitHub

Added clip BF16/FP32 FWD/BWD kernels (#35601)

* implemented clip op bf16/fp32

* added skipping if not cpu or bf16

* CI rerun after bf16 package change

* added parentheses to ensure formatting
上级 b4806644
......@@ -31,6 +31,21 @@ class ClipOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename AttrType>
......@@ -54,6 +69,14 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
"input(x)");
AddAttr<AttrType>("min", "float number, the minimum value to clip by.");
AddAttr<AttrType>("max", "float number, the maximum value to clip by.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddComment(R"DOC(
Clip Operator.
......@@ -81,6 +104,21 @@ class ClipOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace {
using paddle::framework::Tensor;
template <typename T>
class ClipMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const paddle::framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
paddle::platform::ActivationMKLDNNHandler<T> handler(
mkldnn::algorithm::eltwise_clip_v2, ctx, mkldnn_engine, ctx.GetPlace(),
x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
{MKLDNN_ARG_TO, *dst_memory_p}});
astream.wait();
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(paddle::platform::GetMKLDNNFormat(*dst_memory_p));
}
};
template <typename T>
class ClipGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const paddle::framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* dx = ctx.Output<Tensor>(paddle::framework::GradVarName("X"));
auto* dout = ctx.Input<Tensor>(paddle::framework::GradVarName("Out"));
paddle::platform::ActivationMKLDNNHandler<T> handler(
mkldnn::algorithm::eltwise_clip_v2, ctx, mkldnn_engine, ctx.GetPlace(),
x, dout);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_format(paddle::platform::GetMKLDNNFormat(*diff_dst_memory_p));
}
};
} // anonymous namespace
REGISTER_OP_KERNEL(clip, MKLDNN, paddle::platform::CPUPlace,
ClipMKLDNNKernel<float>,
ClipMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(clip_grad, MKLDNN, paddle::platform::CPUPlace,
ClipGradMKLDNNKernel<float>,
ClipGradMKLDNNKernel<paddle::platform::bfloat16>);
......@@ -977,8 +977,8 @@ class ActivationMKLDNNHandler
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// eltwise_linear means we are in scale op
if (algorithm == mkldnn::algorithm::eltwise_linear) {
if (ctx.Type() == "scale") {
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
......@@ -988,7 +988,14 @@ class ActivationMKLDNNHandler
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) beta *= alpha;
if (!bias_after_scale) {
beta *= alpha;
}
} else if (ctx.Type() == "clip") {
alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
: ctx.Attr<float>("min");
beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
: ctx.Attr<float>("max");
} else {
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
......@@ -1030,6 +1037,13 @@ class ActivationMKLDNNHandler
alpha = ctx.Attr<float>("threshold");
}
if (ctx.Type() == "clip_grad") {
alpha = ctx.HasInput("Min") ? ctx.Input<Tensor>("Min")->data<float>()[0]
: ctx.Attr<float>("min");
beta = ctx.HasInput("Max") ? ctx.Input<Tensor>("Max")->data<float>()[0]
: ctx.Attr<float>("max");
}
auto diff_dst_tz = framework::vectorize<int64_t>(out_grad->dims());
auto src_fmt =
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@OpTestTool.skip_if_not_cpu_bf16()
class TestClipOneDNNOp(OpTest):
def setUp(self):
self.op_type = "clip"
self.set_inputs()
self.set_attrs()
self.set_additional_inputs()
self.adjust_op_settings()
self.min = self.attrs[
'min'] if not 'Min' in self.inputs else self.inputs['Min']
self.max = self.attrs[
'max'] if not 'Max' in self.inputs else self.inputs['Max']
self.outputs = {'Out': np.clip(self.x_fp32, self.min, self.max)}
def set_inputs(self):
self.inputs = {'X': np.random.random((10, 10)).astype(np.float32) * 25}
self.x_fp32 = self.inputs['X']
def set_additional_inputs(self):
pass
def adjust_op_settings(self):
pass
def set_attrs(self):
self.attrs = {'min': 7.2, 'max': 9.6, 'use_mkldnn': True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestClipMinAsInputOneDNNOp(TestClipOneDNNOp):
def set_additional_inputs(self):
self.inputs['Min'] = np.array([6.8]).astype('float32')
class TestClipMaxAsInputOneDNNOp(TestClipOneDNNOp):
def set_additional_inputs(self):
self.inputs['Max'] = np.array([9.1]).astype('float32')
class TestClipMaxAndMinAsInputsOneDNNOp(TestClipOneDNNOp):
def set_additional_inputs(self):
self.inputs['Max'] = np.array([8.5]).astype('float32')
self.inputs['Min'] = np.array([7.1]).astype('float32')
# BF16 TESTS
def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16()
class TestClipBF16OneDNNOp(parent):
def set_inputs(self):
self.x_fp32 = np.random.random((10, 10)).astype(np.float32) * 25
self.inputs = {'X': convert_float_to_uint16(self.x_fp32)}
def adjust_op_settings(self):
self.dtype = np.uint16
self.attrs['mkldnn_data_type'] = "bfloat16"
def calculate_grads(self):
self.dout = self.outputs['Out']
self.dx = np.zeros(self.x_fp32.shape).astype("float32")
for i in range(self.dx.shape[0]):
for j in range(self.dx.shape[1]):
if self.x_fp32[j][i] > self.min and self.x_fp32[j][
i] < self.max:
self.dx[j][i] = self.dout[j][i]
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
user_defined_grads=[self.dx],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestClipBF16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestClipBF16OneDNNOp
create_bf16_test_class(TestClipOneDNNOp)
create_bf16_test_class(TestClipMinAsInputOneDNNOp)
create_bf16_test_class(TestClipMaxAsInputOneDNNOp)
create_bf16_test_class(TestClipMaxAndMinAsInputsOneDNNOp)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册