Unverified commit 7004f65c authored by piotrekobi, committed by GitHub

Refactor elementwise op grad classes (#40187)

* Refactor elementwise op grad classes

* Add more refactor changes

* Revert set layout and format deletion

* Fix failing elementwise test
Parent: 2def79bc
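This commit deletes the four per-operator oneDNN grad kernels (EltwiseAddMKLDNNGradKernel, EltwiseSubMKLDNNGradKernel, EltwiseMulMKLDNNGradKernel, EltwiseDivMKLDNNGradKernel) and has every elementwise grad op register the shared EltwiseMKLDNNGradKernel from elementwise_mkldnn_op.h, templated on the element type and the dnnl::algorithm. A condensed sketch of the new registration pattern, taken from the hunks below (the add case; sub, mul and div follow the same shape):
// Condensed from this diff: one templated grad kernel covers all four ops.
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_add>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_add>)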
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,100 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto tz = phi::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler handler(
tz, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
if (dx) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
if (dy) {
// Direct copy
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
@@ -116,6 +24,8 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_add>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>)
REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseAddMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseAddMKLDNNGradKernel<float>)
REGISTER_OP_KERNEL(
elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_add>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_add>)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseDivMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Input<framework::Tensor>("Out");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout / y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = -dout * out / y
platform::BinaryMKLDNNHandler<T> y_handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(), y,
y, nullptr, 1.0f, 1.0f, 1.0f);
const auto y_memory = y_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div, y_memory->get_desc());
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_out_memory = handler.AcquireSecondSrcMemory(out);
// If broadcasting is in use, write to a temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_out_memory},
{DNNL_ARG_DST, *dst_dy_memory},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *y_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// Use the memory object holding the binary operation's result as the source
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// TODO(piotrekobi) add int8, uint8 support
REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_div>,
ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_div>)
REGISTER_OP_KERNEL(elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseDivMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseDivMKLDNNGradKernel<float>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_div>,
ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_div>)
REGISTER_OP_KERNEL(
elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_div>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_div>)
@@ -15,20 +15,35 @@
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using framework::DataLayout;
using framework::Tensor;
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
using framework::DataLayout;
using framework::Tensor;
inline std::vector<int64_t> CalculateBroadcastedDims(const Tensor* x,
const Tensor* y) {
const auto src_tz = phi::vectorize(x->dims());
const auto dst_tz = phi::vectorize(y->dims());
size_t j = 0;
std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);
for (size_t i = 0; i < src_tz.size(); ++i) {
dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
if (j == dst_tz.size()) break;
}
return dst_tz_ex;
}
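// Illustrative example (not part of the original source): for x dims
// {2, 3, 4, 5} and y dims {3, 4} this returns {1, 3, 4, 1}, i.e. y's shape
// expanded to x's rank with 1s on the broadcast axes, which is the dst
// shape expected by the reduction primitives used in the grad kernel below.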
template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
@@ -103,7 +118,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
// operation.
const bool reuse_x_memory =
x->numel() == z->numel() && x->IsSharedBufferWith(*z);
std::shared_ptr<dnnl::memory> dst_memory = nullptr;
std::shared_ptr<dnnl::memory> dst_memory;
if (reuse_x_memory) {
dst_memory = src_x_memory;
// NOTE(chenfeiyu): when the output reuses memory from other tensor rather
@@ -135,19 +150,193 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
}
};
inline std::vector<int64_t> CalculateBroadcastedDims(const Tensor* x,
const Tensor* y) {
const auto src_tz = phi::vectorize(x->dims());
const auto dst_tz = phi::vectorize(y->dims());
size_t j = 0;
std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);
for (size_t i = 0; i < src_tz.size(); ++i) {
dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
if (j == dst_tz.size()) break;
}
return dst_tz_ex;
}
template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
auto tz = phi::vectorize<int64_t>(dout->dims());
auto proto_type_dout = framework::TransToProtoVarType(dout->dtype());
platform::ReorderMKLDNNHandler reorder_handler(
tz, proto_type_dout, framework::ToMKLDNNDataType(proto_type_dout),
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
std::shared_ptr<dnnl::memory> dst_memory;
// elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) {
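// For add and sub, dx is simply dout, so a plain reorder (copy) of dout
// into dx is enough.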
dst_memory = reorder_handler.AcquireDstMemory(dx, dout->format(),
ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
}
// elementwise_mul & elementwise_div
else {
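// For mul, dx = dout * y; for div, dx = dout / y. The BINARY_OP template
// argument selects the matching oneDNN binary primitive directly.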
platform::BinaryMKLDNNHandler<T> binary_handler(
BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f,
1.0f, 1.0f);
const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
dst_memory = binary_handler.AcquireDstMemory(dx);
const auto binary_prim = binary_handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_memory}};
binary_prim->execute(astream, args);
}
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
if (dy) {
dnnl::primitive_attr broadcast_reduction_attr;
std::shared_ptr<dnnl::memory> broadcast_src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
// elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) {
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dy, dout->format(), ctx.GetPlace());
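// For add, dy = dout; for sub, dy = -dout, so the reorder below applies
// an output scale of -1 in the sub case.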
dnnl::primitive_attr reorder_attr;
std::vector<float> scales(1);
scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
dst_memory = reorder_dst_memory_p;
} else {
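// Shapes differ: skip the direct copy and let the reduction further down
// sum dout over the broadcast axes.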
broadcast_src_memory = reorder_src_memory_p;
}
}
// elementwise_mul & elementwise_div
else {
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::binary> binary_prim;
std::shared_ptr<dnnl::memory> post_op_memory;
std::shared_ptr<dnnl::memory> src_0_memory;
std::shared_ptr<dnnl::memory> src_1_memory;
platform::BinaryMKLDNNHandler<T> binary_handler(
dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
src_1_memory = binary_handler.AcquireSecondSrcMemory(x);
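// For mul, dy = dout * x. For div, the handler is rebuilt below so that
// dy = -dout * out / y: binary_mul(dout, out) with an output scale of -1,
// plus a binary_div post-op that divides by y.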
if (BINARY_OP == dnnl::algorithm::binary_div) {
platform::BinaryMKLDNNHandler<T> post_op_binary_handler(
dnnl::algorithm::binary_div, axis, onednn_engine, ctx.GetPlace(),
y, y, nullptr, 1.0f, 1.0f, 1.0f);
post_op_memory = post_op_binary_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div,
post_op_memory->get_desc());
binary_handler = platform::BinaryMKLDNNHandler<T>(
dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
}
src_0_memory = binary_handler.AcquireSrcMemory(dout);
const auto dst_dy_memory = (dout->dims() == dy->dims())
? binary_handler.AcquireDstMemory(dy)
: binary_handler.AcquireDstMemory();
binary_prim = binary_handler.AcquireForwardPrimitive();
args = {{DNNL_ARG_SRC_0, *src_0_memory},
{DNNL_ARG_SRC_1, *src_1_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
if (BINARY_OP == dnnl::algorithm::binary_div)
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*post_op_memory});
binary_prim->execute(astream, args);
broadcast_src_memory = dst_dy_memory;
dst_memory = dst_dy_memory;
}
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
if (dout->dims() != dy->dims()) {
// Broadcasting
if (BINARY_OP == dnnl::algorithm::binary_sub) {
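// For sub, the reduced gradient must be negated (dy = -sum(dout) over the
// broadcast axes), so an eltwise_linear post-op with alpha = -1 is
// attached to the reduction.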
dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
broadcast_reduction_attr.set_post_ops(po);
}
platform::ReductionMKLDNNHandler<T> reduction_handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy),
broadcast_reduction_attr);
dst_memory = reduction_handler.AcquireDstMemory(dy);
auto reduction_p = reduction_handler.AcquireForwardPrimitive();
reduction_p->execute(astream, {
{DNNL_ARG_SRC, *broadcast_src_memory},
{DNNL_ARG_DST, *dst_memory},
});
astream.wait();
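// The reduction output keeps dout's rank (with 1s on the reduced axes),
// so its descriptor is reshaped to dy's real dims before recording the
// format.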
dy->set_format(platform::GetMKLDNNFormat(dst_memory->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout*y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = dout*x
// The handler is passed nullptr instead of the output tensor because
// we want oneDNN to allocate the Dst buffer rather than write into the Tensor
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
// If broadcasting is in use, write to a temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_x_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// Use the memory object holding the binary operation's result as the source
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
@@ -132,6 +24,8 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>)
REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMulMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseMulMKLDNNGradKernel<float>)
REGISTER_OP_KERNEL(
elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_mul>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_mul>)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -13,113 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto tz = phi::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler handler(
tz, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
if (dx) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
if (dy) {
// Direct copy
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr;
std::vector<float> scales = {-1};
reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
dnnl::primitive_attr attr;
attr.set_post_ops(po);
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr);
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {
{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p},
});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
@@ -131,6 +24,8 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_sub>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_sub>)
REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseSubMKLDNNGradKernel<paddle::platform::bfloat16>,
ops::EltwiseSubMKLDNNGradKernel<float>)
REGISTER_OP_KERNEL(
elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_sub>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_sub>)
@@ -17,7 +17,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
......
@@ -23,7 +23,7 @@ import paddle.fluid.core as core
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.op import Operator
from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
class ElementwiseMulOp(OpTest):
......