Unverified · Commit 7004f65c authored by piotrekobi, committed by GitHub

Refactor elementwise op grad classes (#40187)

* Refactor elementwise op grad classes

* Add more refactor changes

* Revert set layout and format deletion

* Fix failing elementwise test
Parent 2def79bc
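The four per-op oneDNN grad kernels below (add, sub, mul, div) are collapsed into a single EltwiseMKLDNNGradKernel<T, BINARY_OP> class template. The following is a minimal standalone sketch of that dispatch pattern and of the gradient formulas the kernels implement; Algo, EltwiseGradSketch and the plain-vector Compute signature are illustrative stand-ins, not Paddle or oneDNN APIs, and broadcasting/reduction is omitted.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for dnnl::algorithm; the real kernel is parameterized on that enum.
enum class Algo { add, sub, mul, div };

template <typename T, Algo op>
struct EltwiseGradSketch {
  // Gradients of Out = X op Y, elementwise.
  void Compute(const std::vector<T>& x, const std::vector<T>& y,
               const std::vector<T>& out, const std::vector<T>& dout,
               std::vector<T>* dx, std::vector<T>* dy) const {
    for (std::size_t i = 0; i < dout.size(); ++i) {
      if (op == Algo::add || op == Algo::sub) {
        (*dx)[i] = dout[i];                                 // dX = dOut
        (*dy)[i] = (op == Algo::add) ? dout[i] : -dout[i];  // dY = +/- dOut
      } else if (op == Algo::mul) {
        (*dx)[i] = dout[i] * y[i];                          // dX = dOut * Y
        (*dy)[i] = dout[i] * x[i];                          // dY = dOut * X
      } else {                                              // Algo::div
        (*dx)[i] = dout[i] / y[i];                          // dX = dOut / Y
        (*dy)[i] = -dout[i] * out[i] / y[i];                // dY = -dOut * Out / Y
      }
    }
  }
};

int main() {
  std::vector<float> x{2.f, 4.f}, y{1.f, 2.f}, dout{1.f, 1.f};
  std::vector<float> out{x[0] / y[0], x[1] / y[1]};
  std::vector<float> dx(2), dy(2);
  EltwiseGradSketch<float, Algo::div>{}.Compute(x, y, out, dout, &dx, &dy);
  std::cout << dx[0] << " " << dy[1] << "\n";  // prints "1 -1"
  return 0;
}

Because BINARY_OP is a template parameter, the per-op branches are resolved at compile time, and each elementwise_*_grad op simply registers a different instantiation, as the registrations in the diff below show.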
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,100 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto tz = phi::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler handler(
tz, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
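// dX of Out = X + Y is simply dOut, so it is produced with a plain oneDNN reorder of dOut.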
if (dx) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
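// dY is also dOut; when Y was broadcast in the forward pass, dOut is instead summed back to Y's shape with a reduction primitive.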
if (dy) {
// Direct copy
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

@@ -116,6 +24,8 @@ REGISTER_OP_KERNEL(
    ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_add>,
    ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>)

-REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::EltwiseAddMKLDNNGradKernel<paddle::platform::bfloat16>,
-                   ops::EltwiseAddMKLDNNGradKernel<float>)
+REGISTER_OP_KERNEL(
+    elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
+    ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
+                                 dnnl::algorithm::binary_add>,
+    ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_add>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"

namespace paddle {
namespace framework {
class ExecutionContext;
}  // namespace framework
namespace platform {
class CPUDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {
template <typename T>
class EltwiseDivMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Input<framework::Tensor>("Out");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout / y
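// (d(X/Y)/dX = 1 / Y)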
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = -dout * out / y
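// (d(X/Y)/dY = -X / Y^2 = -Out / Y)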
platform::BinaryMKLDNNHandler<T> y_handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(), y,
y, nullptr, 1.0f, 1.0f, 1.0f);
const auto y_memory = y_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div, y_memory->get_desc());
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_out_memory = handler.AcquireSecondSrcMemory(out);
// If broadcasting is in use then let's write to temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_out_memory},
{DNNL_ARG_DST, *dst_dy_memory},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *y_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

// TODO(piotrekobi) add int8, uint8 support
REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace,
                   ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_div>,
                   ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
                                            dnnl::algorithm::binary_div>)

-REGISTER_OP_KERNEL(elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace,
-                   ops::EltwiseDivMKLDNNGradKernel<paddle::platform::bfloat16>,
-                   ops::EltwiseDivMKLDNNGradKernel<float>)
+REGISTER_OP_KERNEL(
+    elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
+                                 dnnl::algorithm::binary_div>,
+    ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_div>)
@@ -15,20 +15,35 @@
#pragma once

#include <string>
#include <unordered_map>

-#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
using framework::DataLayout;
using framework::Tensor;

inline std::vector<int64_t> CalculateBroadcastedDims(const Tensor* x,
                                                     const Tensor* y) {
  const auto src_tz = phi::vectorize(x->dims());
  const auto dst_tz = phi::vectorize(y->dims());

  size_t j = 0;
  std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);
  for (size_t i = 0; i < src_tz.size(); ++i) {
    dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
    if (j == dst_tz.size()) break;
  }
  return dst_tz_ex;
}
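// CalculateBroadcastedDims (above) yields the target dims for the reduction that folds a
// broadcast gradient back to the smaller tensor's shape, e.g. dOut dims [16, 8, 32, 64]
// and dY dims [16, 8] give [16, 8, 1, 1].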
template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
@@ -103,7 +118,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
    // operation.
    const bool reuse_x_memopry =
        x->numel() == z->numel() && x->IsSharedBufferWith(*z);
-   std::shared_ptr<dnnl::memory> dst_memory = nullptr;
+   std::shared_ptr<dnnl::memory> dst_memory;
    if (reuse_x_memopry) {
      dst_memory = src_x_memory;
      // NOTE(chenfeiyu): when the output reuses memory from other tensor rather
@@ -135,19 +150,193 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
  }
};
template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
auto tz = phi::vectorize<int64_t>(dout->dims());
auto proto_type_dout = framework::TransToProtoVarType(dout->dtype());
platform::ReorderMKLDNNHandler reorder_handler(
tz, proto_type_dout, framework::ToMKLDNNDataType(proto_type_dout),
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
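// dX: for add/sub it is dOut itself (copied with a reorder); for mul it is dOut * Y and
// for div it is dOut / Y (computed with a binary primitive).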
if (dx) {
std::shared_ptr<dnnl::memory> dst_memory;
// elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) {
dst_memory = reorder_handler.AcquireDstMemory(dx, dout->format(),
ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
}
// elementwise_mul & elementwise_div
else {
platform::BinaryMKLDNNHandler<T> binary_handler(
BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f,
1.0f, 1.0f);
const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
dst_memory = binary_handler.AcquireDstMemory(dx);
const auto binary_prim = binary_handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_memory}};
binary_prim->execute(astream, args);
}
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
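// dY: add/sub copy (or negate) dOut; mul computes dOut * X and div computes -dOut * Out / Y.
// If Y was broadcast, the intermediate result is reduced back to Y's shape below.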
if (dy) {
dnnl::primitive_attr broadcast_reduction_attr;
std::shared_ptr<dnnl::memory> broadcast_src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
// elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) {
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dy, dout->format(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr;
std::vector<float> scales(1);
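// elementwise_sub contributes -dOut to dY, so the reorder negates via an output scale of -1.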
scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
dst_memory = reorder_dst_memory_p;
} else {
broadcast_src_memory = reorder_src_memory_p;
}
}
// elementwise_mul & elementwise_div
else {
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::binary> binary_prim;
std::shared_ptr<dnnl::memory> post_op_memory;
std::shared_ptr<dnnl::memory> src_0_memory;
std::shared_ptr<dnnl::memory> src_1_memory;
platform::BinaryMKLDNNHandler<T> binary_handler(
dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
src_1_memory = binary_handler.AcquireSecondSrcMemory(x);
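// For division dY = -dOut * Out / Y: the multiplication of dOut and Out below is scaled by -1,
// and the extra division by Y is appended as a binary post-op.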
if (BINARY_OP == dnnl::algorithm::binary_div) {
platform::BinaryMKLDNNHandler<T> post_op_binary_handler(
dnnl::algorithm::binary_div, axis, onednn_engine, ctx.GetPlace(),
y, y, nullptr, 1.0f, 1.0f, 1.0f);
post_op_memory = post_op_binary_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div,
post_op_memory->get_desc());
binary_handler = platform::BinaryMKLDNNHandler<T>(
dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
}
src_0_memory = binary_handler.AcquireSrcMemory(dout);
const auto dst_dy_memory = (dout->dims() == dy->dims())
? binary_handler.AcquireDstMemory(dy)
: binary_handler.AcquireDstMemory();
binary_prim = binary_handler.AcquireForwardPrimitive();
args = {{DNNL_ARG_SRC_0, *src_0_memory},
{DNNL_ARG_SRC_1, *src_1_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
if (BINARY_OP == dnnl::algorithm::binary_div)
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*post_op_memory});
binary_prim->execute(astream, args);
broadcast_src_memory = dst_dy_memory;
dst_memory = dst_dy_memory;
}
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
if (dout->dims() != dy->dims()) {
// Broadcasting
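// Sum the intermediate result back to Y's shape; for subtraction the reduction also negates
// dOut through an eltwise_linear(-1) post-op.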
if (BINARY_OP == dnnl::algorithm::binary_sub) {
dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
broadcast_reduction_attr.set_post_ops(po);
}
platform::ReductionMKLDNNHandler<T> reduction_handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy),
broadcast_reduction_attr);
dst_memory = reduction_handler.AcquireDstMemory(dy);
auto reduction_p = reduction_handler.AcquireForwardPrimitive();
reduction_p->execute(astream, {
{DNNL_ARG_SRC, *broadcast_src_memory},
{DNNL_ARG_DST, *dst_memory},
});
astream.wait();
dy->set_format(platform::GetMKLDNNFormat(dst_memory->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
}
}
};
}  // namespace operators
}  // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout*y
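// (d(X*Y)/dX = Y)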
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = dout*x
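// (d(X*Y)/dY = X); if Y was broadcast, the product is reduced back to Y's shape below.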
// Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
// If broadcasting is in use then let's write to temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_x_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

@@ -132,6 +24,8 @@ REGISTER_OP_KERNEL(
    ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>,
    ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>)

-REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::EltwiseMulMKLDNNGradKernel<paddle::platform::bfloat16>,
-                   ops::EltwiseMulMKLDNNGradKernel<float>)
+REGISTER_OP_KERNEL(
+    elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
+    ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
+                                 dnnl::algorithm::binary_mul>,
+    ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_mul>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -13,113 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto tz = phi::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler handler(
tz, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
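// dX of Out = X - Y is dOut, copied with a plain reorder.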
if (dx) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
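// dY is -dOut: negated through a reorder output scale when shapes match, or through the
// reduction's post-op when Y was broadcast.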
if (dy) {
// Direct copy
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr;
std::vector<float> scales = {-1};
reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
dnnl::primitive_attr attr;
attr.set_post_ops(po);
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr);
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {
{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p},
});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

@@ -131,6 +24,8 @@ REGISTER_OP_KERNEL(
    ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_sub>,
    ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_sub>)

-REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::EltwiseSubMKLDNNGradKernel<paddle::platform::bfloat16>,
-                   ops::EltwiseSubMKLDNNGradKernel<float>)
+REGISTER_OP_KERNEL(
+    elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace,
+    ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
+                                 dnnl::algorithm::binary_sub>,
+    ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_sub>)
@@ -17,7 +17,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
-from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
...

@@ -23,7 +23,7 @@ import paddle.fluid.core as core
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.op import Operator
-from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16

class ElementwiseMulOp(OpTest):
...