未验证 提交 25fc2a1f 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Added Elementwise Mul grad fp32/bf16 (#31647)

上级 878e117b
......@@ -276,7 +276,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
auto CanMKLDNNElementwiseGradBeUsed = [&]() {
auto dx_dims = ctx.Input<Tensor>("X")->dims();
auto dy_dims = ctx.Input<Tensor>("Y")->dims();
// No broadcast or broadcasting of data on inner dims is supported
......@@ -284,8 +284,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
};
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
(ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
CanMKLDNNElementwiseGradBeUsed()) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
......@@ -61,6 +61,9 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
if (dy) {
......@@ -75,6 +78,9 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
......@@ -86,6 +92,11 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
paddle::framework::vectorize<int64_t>(dy->dims()))));
}
}
}
......
......@@ -15,7 +15,6 @@
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
......
......@@ -14,6 +14,118 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
struct CPUPlace;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout*y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f,
ctx.InputName(framework::GradVarName("Out")));
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = dout*x
// Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f,
ctx.InputName(framework::GradVarName("Out")));
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
// If broadcasting is in use then let's write to temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_x_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
paddle::framework::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
......@@ -23,3 +135,7 @@ REGISTER_OP_KERNEL(
dnnl::algorithm::binary_mul>,
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>)
REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMulMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseMulMKLDNNGradKernel<float>)
......@@ -87,6 +87,11 @@ class MKLDNNHandlerT {
"@dst_mem_p");
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), "@dstt_mem_p");
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const framework::Tensor* output) {
......@@ -561,7 +566,10 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
const auto src_x_tz = framework::vectorize(x->dims());
const auto src_y_tz = framework::vectorize(y->dims());
const auto dst_tz = framework::vectorize(z->dims());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
const auto dst_tz =
(z == nullptr) ? src_x_tz : framework::vectorize(z->dims());
const auto src0_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
......
......@@ -30,10 +30,9 @@ class TestElementwiseMulBf16MklDNNOp(OpTest):
self.axis = -1
self.generate_data()
self.inputs = {
'X': convert_float_to_uint16(self.x),
'Y': convert_float_to_uint16(self.y)
}
self.x_bf16 = convert_float_to_uint16(self.x)
self.y_bf16 = convert_float_to_uint16(self.y)
self.inputs = {'X': self.x_bf16, 'Y': self.y_bf16}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': convert_float_to_uint16(self.out)}
......@@ -46,13 +45,66 @@ class TestElementwiseMulBf16MklDNNOp(OpTest):
self.check_output_with_place(core.CPUPlace())
def test_check_grad_normal(self):
pass
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
check_dygraph=False,
user_defined_grads=[
np.multiply(self.x, self.y), np.multiply(self.x, self.x)
],
user_defined_grad_outputs=[self.x_bf16])
def test_check_grad_ingore_x(self):
pass
self.check_grad_with_place(
core.CPUPlace(), ["Y"],
"Out",
check_dygraph=False,
user_defined_grads=[np.multiply(self.y, self.x)],
user_defined_grad_outputs=[self.y_bf16])
def test_check_grad_ingore_y(self):
pass
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
check_dygraph=False,
user_defined_grads=[np.multiply(self.x, self.y)],
user_defined_grad_outputs=[self.x_bf16])
class TestElementwiseMulBroadcastingBf16MklDNNOp(
TestElementwiseMulBf16MklDNNOp):
def generate_data(self):
self.x = np.random.uniform(1, 2, [1, 2, 3, 100]).astype(np.float32)
self.y = np.random.uniform(1, 2, [100]).astype(np.float32)
self.out = np.multiply(self.x, self.y)
# Compute partial sums along all axes but last one
def compute_reduced_gradients(self, out_grads):
part_sum = np.add.reduceat(out_grads, [0], axis=0)
part_sum = np.add.reduceat(part_sum, [0], axis=1)
part_sum = np.add.reduceat(part_sum, [0], axis=2)
return part_sum.flatten()
def test_check_grad_normal(self):
self.check_grad_with_place(
core.CPUPlace(), ["X", "Y"],
"Out",
check_dygraph=False,
user_defined_grads=[
np.multiply(self.x, self.y),
self.compute_reduced_gradients(np.multiply(self.x, self.x))
],
user_defined_grad_outputs=[self.x_bf16])
def test_check_grad_ingore_x(self):
self.check_grad_with_place(
core.CPUPlace(), ["Y"],
"Out",
check_dygraph=False,
user_defined_grads=[
self.compute_reduced_gradients(np.multiply(self.x, self.x))
],
user_defined_grad_outputs=[self.x_bf16])
if __name__ == '__main__':
......
......@@ -17,6 +17,7 @@ import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
from paddle.fluid.tests.unittests.test_elementwise_mul_op import ElementwiseMulOp
from paddle import enable_static
class TestMKLDNNElementwiseMulOp(ElementwiseMulOp):
......@@ -51,13 +52,17 @@ class TestMKLDNNElementwiseMulOp4(TestMKLDNNElementwiseMulOp):
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestMKLDNNElementwiseMulOp5(TestMKLDNNElementwiseMulOp):
def init_input_output(self):
self.x = np.random.uniform(1, 2, [2, 3, 4, 100]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
''' INT8 Tests '''
......@@ -140,4 +145,5 @@ class TestUint8Scales(TestInt8Scales):
if __name__ == '__main__':
enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册