Commit cb65439d authored by Adam, committed by Tao Luo

Add support for other axes in MKLDNN softmax op (#19907)

* Initial, functional commit

* Clean commit related files
test=develop
Parent 45425411
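This change routes the softmax op's `axis` attribute through to the MKL-DNN softmax primitive instead of hard-coding dimension 1 of a flattened 2-D (NC) view. The kernels below first canonicalize the attribute with `CanonicalAxis`; as a hedged sketch (not part of this patch), the assumed behaviour is simply that a negative axis counts back from the tensor rank:

    // Assumed semantics of CanonicalAxis as used by the kernels in this diff:
    // a negative axis is taken relative to the end, e.g. axis = -1 on a
    // 4-D tensor maps to 3, the innermost dimension.
    inline int CanonicalAxis(int axis, int rank) {
      return axis < 0 ? axis + rank : axis;
    }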
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <iostream>
+#include <numeric>
 #include "mkldnn.hpp"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
@@ -38,37 +39,37 @@ class SoftmaxMKLDNNHandler
                                       mkldnn::softmax_backward> {
  public:
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
-                       const MKLDNNMemoryFormat fmt,
+                       const MKLDNNMemoryFormat fmt, const int& axis,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            platform::CreateKey(dims, axis, uniq_name)) {
     auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
 
     this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
-                                            1 /*dim: C*/);
+                                            axis);
   }
 
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
                        const MKLDNNMemoryFormat fmt,
-                       const MKLDNNMemoryFormat diff_fmt,
+                       const MKLDNNMemoryFormat diff_fmt, const int& axis,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            platform::CreateKey(dims, axis, uniq_name)) {
     auto data_softmax_md =
         mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
     auto diff_softmax_md =
         mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
 
     this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
-                                            data_softmax_md, 1 /*dim: C*/);
+                                            data_softmax_md, axis);
     this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
-                                             1 /* dim: C*/);
+                                             axis);
   }
 };
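For reference, `AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md, axis)` ultimately builds an MKL-DNN softmax descriptor whose last argument selects the normalization axis. A rough, hedged sketch of the MKL-DNN 0.x calls the handler wraps (the helper name `MakeSoftmaxPd` and its parameters are illustrative, not taken from the patch):

    // Hedged sketch: raw MKL-DNN 0.x softmax primitive-descriptor creation.
    // The axis is forwarded verbatim instead of the previously fixed value 1.
    mkldnn::softmax_forward::primitive_desc MakeSoftmaxPd(
        const std::vector<int>& dims, mkldnn::memory::format fmt, int axis,
        const mkldnn::engine& engine) {
      mkldnn::memory::desc md(dims, mkldnn::memory::data_type::f32, fmt);
      mkldnn::softmax_forward::desc desc(mkldnn::prop_kind::forward_scoring,
                                         md, axis);
      return mkldnn::softmax_forward::primitive_desc(desc, engine);
    }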
@@ -85,18 +86,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
         input->dims(), output->dims(),
         "The shape of softmax's input and output must be identical.");
-    // flatten input and output to 2-D matrixs
     auto dims = input->dims();  // input and output share the same shape
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
-    auto src_tz = paddle::framework::vectorize<int>(flattened_dims);
-    auto dst_tz = src_tz;
     // Same memory descriptor to be used for input and output
-    memory::dims softmax_tz = {src_tz[0], src_tz[1]};
+    auto softmax_tz = paddle::framework::vectorize<int>(dims);
 
-    SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx,
                                     ctx.GetPlace(), ctx.op().Output("Out"));
-    // Currently only NC data format is supported
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
     auto softmax_dst_memory_p = handler.AcquireDstMemory(output);
     auto softmax_p = handler.AcquireForwardPrimitive(*softmax_src_memory_p,
@@ -105,14 +102,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     std::vector<primitive> pipeline{*softmax_p};
     stream(stream::kind::eager).submit(pipeline).wait();
 
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
     const bool is_test = ctx.Attr<bool>("is_test");
     if (!is_test) {
-      T threshold = exp(-64);
-      for (int i = 0; i < dst_tz[0] * dst_tz[1]; ++i) {
-        output_data[i] =
-            output_data[i] < threshold ? threshold : output_data[i];
-      }
+      T* output_data = output->mutable_data<T>(ctx.GetPlace());
+      int size = std::accumulate(begin(softmax_tz), end(softmax_tz), 1,
+                                 std::multiplies<int>());
+      std::for_each(output_data, &output_data[size], [](T& val) {
+        val = std::max(val, static_cast<T>(exp(-64)));
+      });
     }
 
     output->set_layout(framework::DataLayout::kMKLDNN);
@@ -139,31 +136,26 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
         "The shape of softmax_grad's input and output must be identical.");
 
     auto dims = dout->dims();  // input and output share the same shape
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    std::vector<int> dst_tz = paddle::framework::vectorize<int>(flattened_dims);
-    std::vector<int> src_tz(dst_tz);
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
 
     // Same memory descriptor to be used for input and output
-    memory::dims softmax_tz = {src_tz[0], src_tz[1]};
+    std::vector<int> softmax_tz = paddle::framework::vectorize<int>(dims);
 
-    // TODO(jczaja): Add layouts support when there is a need to do so
-    // Two dimensional softmax does support NC format
-    // Normalization is made after innermost dimension eg. C out of NC
-    SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc,
-                                    MKLDNNMemoryFormat::nc, dev_ctx,
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, output->format(),
+                                    dout->format(), axis, dev_ctx,
                                     ctx.GetPlace(), ctx.op().Input("Out"));
 
     auto dst_memory_p = handler.AcquireDstMemory(output);
     auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
     auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
 
     // Get primitve from device context
     auto softmax_bwd_p = handler.AcquireBackwardPrimitive(
         *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
 
     std::vector<primitive> pipeline{*softmax_bwd_p};
     stream(stream::kind::eager).submit(pipeline).wait();
 
     dx->set_layout(framework::DataLayout::kMKLDNN);
     dx->set_format(dout->format());
   }
 };
} // namespace operators
......
@@ -47,10 +47,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
                       "R is the rank of Input(X).");
 
     auto use_cudnn = ctx->Attrs().Get<bool>("use_cudnn");
-    auto use_mkldnn = ctx->Attrs().Get<bool>("use_mkldnn");
     if (axis != rank_x - 1 && axis != -1) {
       PADDLE_ENFORCE(!use_cudnn, "CUDNN kernel only support axis as -1.");
-      PADDLE_ENFORCE(!use_mkldnn, "MKLDNN kernel only support axis as -1.");
     }
 
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
......
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 from paddle.fluid.tests.unittests.op_test import OpTest
 import paddle.fluid.core as core
-from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp, stable_softmax
+from paddle.fluid.tests.unittests.test_softmax_op import *
 from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
@@ -27,9 +27,29 @@ class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
         self.use_mkldnn = True
 
-class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
-    def get_x_shape(self):
-        return [2, 3, 4, 5]
+class TestSoftmaxMKLDNNOp2(TestSoftmaxOp2):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp3(TestSoftmaxOp3):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp4(TestSoftmaxOp4):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp5(TestSoftmaxOp5):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestSoftmaxMKLDNNOp6(TestSoftmaxOp6):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
 
 
 # Check if primitives already exist in backward
......
@@ -103,7 +103,7 @@ class TestSoftmaxOp5(TestSoftmaxOp):
         return 2
 
-class TestSoftmaxOp5(TestSoftmaxOp):
+class TestSoftmaxOp6(TestSoftmaxOp):
     def get_x_shape(self):
         return [2, 3, 4, 5]
......