未验证 提交 e92e6b06 编写于 作者: P piotrekobiIntel 提交者: GitHub

Added fp32 / bf16 forward and backward elementwise_div_mkldnn operator (#36158)

* Add WIP version of elementwise_div_mkldnn without working dy grad

* Add dy gradient calculation implementation, disable broadcast tests

* Readd removed tests from static_mode_white_list

* Add bfloat16 gradient tests, remove int8 and uint8 support

* - Change the way dy grad is calculated to improve performance
- Refactor BinaryMKLDNNHandler to use a default parameter

* Change copyright year

* Refactor as suggested

* Attempt to bypass CI Approval
not accepting max_relative_error

* Fix formatting issue
上级 9a1cc609
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
struct CPUPlace;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseDivMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Input<framework::Tensor>("Out");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout / y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = -dout * out / y
platform::BinaryMKLDNNHandler<T> y_handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(), y,
y, nullptr, 1.0f, 1.0f, 1.0f);
const auto y_memory = y_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div, y_memory->get_desc());
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_out_memory = handler.AcquireSecondSrcMemory(out);
// If broadcasting is in use then let's write to temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_out_memory},
{DNNL_ARG_DST, *dst_dy_memory},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *y_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
framework::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// TODO(piotrekobi) add int8, uint8 support
REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_div>,
ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_div>)
REGISTER_OP_KERNEL(elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseDivMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseDivMKLDNNGradKernel<float>)
...@@ -614,7 +614,8 @@ class BinaryMKLDNNHandler ...@@ -614,7 +614,8 @@ class BinaryMKLDNNHandler
BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis, BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis,
const mkldnn::engine engine, platform::Place cpu_place, const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z, const Tensor* x, const Tensor* y, Tensor* z,
float scale_x, float scale_y, float scale_z) float scale_x, float scale_y, float scale_z,
const dnnl::post_ops& post_ops = dnnl::post_ops())
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) { : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN, x->layout(), DataLayout::kMKLDNN,
...@@ -663,10 +664,11 @@ class BinaryMKLDNNHandler ...@@ -663,10 +664,11 @@ class BinaryMKLDNNHandler
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::any);
auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z); auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z);
attributes.set_post_ops(post_ops);
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md, this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
dst_md); dst_md);
} }
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle import enable_static
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
from paddle.fluid.framework import _current_expected_place
import paddle.fluid.core as core
@OpTestTool.skip_if(not (isinstance(_current_expected_place(), core.CPUPlace)),
"GPU is not supported")
class TestMKLDNNElementwiseDivOp(OpTest):
def setUp(self):
self.op_type = "elementwise_div"
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': self.out}
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', None, 0.005, False, 0.02)
def test_check_grad_ignore_x(self):
self.check_grad(['Y'], 'Out', set("X"), 0.005, False, 0.02)
def test_check_grad_ignore_y(self):
self.check_grad(['X'], 'Out', set('Y'), 0.005, False, 0.02)
def init_axis(self):
self.axis = -1
def init_kernel_type(self):
self.use_mkldnn = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
class TestMKLDNNElementwiseDivOp2(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
class TestMKLDNNElementwiseDivOp3(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
class TestMKLDNNElementwiseDivOp4(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.random.uniform(1, 2, [2, 3, 4, 32]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [4, 32]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
# TODO(piotrekobiIntel): Enable when grad is ready
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
class TestMKLDNNElementwiseDivOp5(TestMKLDNNElementwiseDivOp):
def init_input_output(self):
self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
# TODO(piotrekobiIntel): Enable when grad is ready
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
@OpTestTool.skip_if_not_cpu_bf16()
class TestBf16(TestMKLDNNElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.x_bf16 = convert_float_to_uint16(self.x)
self.y_bf16 = convert_float_to_uint16(self.y)
self.inputs = {'X': self.x_bf16, 'Y': self.y_bf16}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': convert_float_to_uint16(self.out)}
def init_dtype(self):
self.dtype = np.float32
self.mkldnn_data_type = "bfloat16"
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.out = np.divide(self.x, self.y)
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad_normal(self):
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
user_defined_grads=[
np.divide(self.x, self.y), np.divide(
(np.multiply(-self.x, self.x)), np.multiply(self.y, self.y))
],
user_defined_grad_outputs=[self.x_bf16])
def test_check_grad_ignore_x(self):
self.check_grad_with_place(
core.CPUPlace(), ["Y"],
"Out",
user_defined_grads=[
np.divide((np.multiply(-self.x, self.y)),
np.multiply(self.y, self.y))
],
user_defined_grad_outputs=[self.y_bf16])
def test_check_grad_ignore_y(self):
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
user_defined_grads=[np.divide(self.x, self.y)],
user_defined_grad_outputs=[self.x_bf16])
class TestBf16Broadcasting(TestBf16):
def init_input_output(self):
self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
self.out = np.subtract(self.x, self.y)
def test_check_grad_normal(self):
pass
def test_check_grad_ignore_x(self):
pass
if __name__ == '__main__':
enable_static()
unittest.main()
...@@ -610,6 +610,7 @@ STATIC_MODE_TESTING_LIST = [ ...@@ -610,6 +610,7 @@ STATIC_MODE_TESTING_LIST = [
'test_dequantize_mkldnn_op', 'test_dequantize_mkldnn_op',
'test_elementwise_add_mkldnn_op', 'test_elementwise_add_mkldnn_op',
'test_elementwise_add_bf16_mkldnn_op', 'test_elementwise_add_bf16_mkldnn_op',
'test_elementwise_div_mkldnn_op',
'test_elementwise_sub_mkldnn_op', 'test_elementwise_sub_mkldnn_op',
'test_elementwise_mul_mkldnn_op', 'test_elementwise_mul_mkldnn_op',
'test_elementwise_mul_bf16_mkldnn_op', 'test_elementwise_mul_bf16_mkldnn_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册