未验证 提交 479689f6 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Refactoring of softmax grad onednn kernel to match common API (#32851)

上级 42aad304
......@@ -15,15 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
// Forward declarations of types from other paddle namespaces; only the class
// names are introduced here, not their definitions.
// NOTE(review): the hunk header above (-15,15 +15,6) suggests these lines were
// dropped by the commit being displayed -- confirm against the repository.
namespace paddle {
namespace framework {
// Declared only; defined elsewhere.
class Tensor;
} // namespace framework
namespace platform {
// Declared only; defined elsewhere.
class MKLDNNDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
......@@ -74,23 +65,35 @@ class SoftmaxMKLDNNHandler
}
}
// Constructor of SoftmaxMKLDNNHandler used for the softmax gradient (it ends by
// acquiring a backward primitive descriptor).
// NOTE(review): this span reads like a rendered diff with removed and added
// lines interleaved and +/- markers stripped. The signature directly below
// (dims, fmt, diff_fmt, axis, ..., uniq_name) is presumably the pre-refactoring
// form -- confirm against the repository before relying on it.
SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt, const int& axis,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
// Signature that receives the execution context and the tensors themselves:
//   ctx         - read for the "axis" attribute in the body
//   dev_ctx     - supplies the engine and is part of the cache key
//   cpu_place   - forwarded to the base handler
//   out         - its dims feed the cache key, its format the data descriptor
//   out_grad    - its dims/format define the diff descriptor
//   in_x_grad   - only its dims are read here, for the shape check
//   unique_name - part of the cache key
SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const Tensor* out,
const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name)
// The base handler is initialized with the device context, its engine, the
// place and a key produced by platform::CreateKey.
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
// Key built from caller-supplied dims and uniq_name (matches the dims-based
// signature above).
platform::CreateKey(dev_ctx, dims, uniq_name)) {
// Memory descriptors built from the caller-supplied dims, fmt and diff_fmt.
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
// Key built from out->dims() and unique_name (matches the tensor-based
// signature above).
platform::CreateKey(dev_ctx, framework::vectorize(out->dims()),
unique_name)) {
// Descriptor setup runs only when isBwdCached() returns false; presumably that
// means no backward descriptor is cached for this key yet -- TODO confirm in
// MKLDNNHandlerT.
if (!this->isBwdCached()) {
// Reject mismatched gradient shapes before building any descriptor.
PADDLE_ENFORCE_EQ(
out_grad->dims(), in_x_grad->dims(),
platform::errors::InvalidArgument("The shape of softmax_grad's input "
"and output must be identical."));
auto dims = out_grad->dims(); // input and output share the same shape
// The "axis" attribute is passed through CanonicalAxis together with the
// tensor rank (dims.size()).
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = framework::vectorize<int64_t>(dims);
// Data descriptor takes its format from `out`; diff descriptor takes its
// format from `out_grad`. Both use the same dims and element type T.
auto data_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
// Argument order is (diff descriptor, data descriptor, axis).
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}
}
};
template <typename T>
......@@ -145,27 +148,15 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL SoftmaxGrad must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const Tensor* output = ctx.Input<Tensor>("Out");
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(
dout->dims(), dx->dims(),
platform::errors::InvalidArgument(
"The shape of softmax_grad's input and output must be identical."));
auto dims = dout->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));
SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(),
dout->format(), axis, dev_ctx,
ctx.GetPlace(), ctx.InputName("Out"));
SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output,
out_grad, in_x_grad, ctx.InputName("Out"));
auto dst_memory_p = handler.AcquireDstMemory(output);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(in_x_grad);
auto softmax_bwd_p = handler.AcquireBackwardPrimitive();
......@@ -176,8 +167,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(dout->format());
in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
}
};
} // namespace operators
......
......@@ -129,4 +129,6 @@ class TestSoftmaxMKLDNNPrimitivesAlreadyExist(unittest.TestCase):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册