未验证 提交 479689f6 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Refactoring of softmax grad onednn kernel to match common API (#32851)

上级 42aad304
...@@ -15,15 +15,6 @@ limitations under the License. */ ...@@ -15,15 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
// Forward declarations only: they name the types without supplying their
// definitions.
namespace paddle {
namespace platform {
class MKLDNNDeviceContext;  // forward-declared; defined elsewhere
}  // namespace platform
namespace framework {
class Tensor;  // forward-declared; defined elsewhere
}  // namespace framework
}  // namespace paddle
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -74,22 +65,34 @@ class SoftmaxMKLDNNHandler ...@@ -74,22 +65,34 @@ class SoftmaxMKLDNNHandler
} }
} }
// Constructor of the softmax backward (gradient) handler.
// NOTE(review): this listing is a side-by-side diff rendering, so each line
// below holds the pre-change text followed by the post-change text; the
// remarks here describe the post-change (right-hand) version.
// - The handler key is created from dev_ctx, the dims of `out`, and
//   unique_name.
// - Only when isBwdCached() returns false (presumably: no backward descriptor
//   is cached for this key yet -- confirm in MKLDNNHandlerT) does it:
//     * enforce that out_grad and in_x_grad have identical dims,
//     * canonicalize the "axis" attribute against the number of dims,
//     * build memory descriptors from the formats of `out` and `out_grad`,
//     * acquire the backward primitive descriptor from those descriptors.
SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims, SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNMemoryFormat fmt, const MKLDNNDeviceContext& dev_ctx,
const MKLDNNMemoryFormat diff_fmt, const int& axis, platform::Place cpu_place, const Tensor* out,
const platform::MKLDNNDeviceContext& dev_ctx, const Tensor* out_grad, Tensor* in_x_grad,
platform::Place cpu_place, const std::string& uniq_name) const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, uniq_name)) { platform::CreateKey(dev_ctx, framework::vectorize(out->dims()),
auto data_softmax_md = unique_name)) {
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); if (!this->isBwdCached()) {
auto diff_softmax_md = PADDLE_ENFORCE_EQ(
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt); out_grad->dims(), in_x_grad->dims(),
platform::errors::InvalidArgument("The shape of softmax_grad's input "
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, "and output must be identical."));
axis);
auto dims = out_grad->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = framework::vectorize<int64_t>(dims);
auto data_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}
} }
}; };
...@@ -145,27 +148,15 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -145,27 +148,15 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL SoftmaxGrad must use CPUPlace")); "Operator DNNL SoftmaxGrad must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const Tensor* output = ctx.Input<Tensor>("Out"); const Tensor* output = ctx.Input<Tensor>("Out");
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out")); auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* dx = auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(
dout->dims(), dx->dims(),
platform::errors::InvalidArgument(
"The shape of softmax_grad's input and output must be identical."));
auto dims = dout->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(), SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output,
dout->format(), axis, dev_ctx, out_grad, in_x_grad, ctx.InputName("Out"));
ctx.GetPlace(), ctx.InputName("Out"));
auto dst_memory_p = handler.AcquireDstMemory(output); auto dst_memory_p = handler.AcquireDstMemory(output);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); auto diff_src_memory_p = handler.AcquireDiffSrcMemory(in_x_grad);
auto softmax_bwd_p = handler.AcquireBackwardPrimitive(); auto softmax_bwd_p = handler.AcquireBackwardPrimitive();
...@@ -176,8 +167,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -176,8 +167,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}}); {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN); in_x_grad->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(dout->format()); in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -129,4 +129,6 @@ class TestSoftmaxMKLDNNPrimitivesAlreadyExist(unittest.TestCase): ...@@ -129,4 +129,6 @@ class TestSoftmaxMKLDNNPrimitivesAlreadyExist(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册