未验证 提交 586b5875 编写于 作者: A Adam 提交者: GitHub

Add isCached() check in Softmax handler (#24637)

* Update isCached() to be thread freindly
test=develop

* Add isCached() check inside Softmax handler
test=develop

* Fix PaddleEnforce() message
test=develop
上级 3cf117db
......@@ -25,12 +25,12 @@ using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; // Note: paddle has also "memory" namespace
using mkldnn::primitive;
using mkldnn::prop_kind;
using mkldnn::softmax_backward;
using mkldnn::softmax_forward;
using mkldnn::stream;
using dnnl::memory; // Note: paddle has also "memory" namespace
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::softmax_backward;
using dnnl::softmax_forward;
using dnnl::stream;
using platform::to_void_cast;
template <typename T>
......@@ -38,19 +38,30 @@ class SoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> {
public:
SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat fmt, const int& axis,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
Tensor* output, const int axis,
const std::string uniq_name)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique
platform::CreateKey(dims, axis, uniq_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
platform::CreateKey(framework::vectorize(input->dims()), axis,
uniq_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
platform::errors::InvalidArgument(
"The shape of input and output tensor must be identical."));
auto softmax_tz = framework::vectorize(input->dims());
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
}
}
SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
......@@ -76,30 +87,25 @@ template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
"The shape of softmax's input and output must be identical.");
auto dims = input->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx,
ctx.GetPlace(), ctx.OutputName("Out"));
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
input, output, axis, ctx.OutputName("Out"));
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
auto softmax_p = handler.AcquireForwardPrimitive();
// For Inplace src and and dst are the same memory object
auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
? softmax_src_memory_p
: handler.AcquireDstMemory(output);
auto softmax_p = handler.AcquireForwardPrimitive();
mkldnn::stream astream(dev_ctx.GetEngine());
softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
{DNNL_ARG_DST, *softmax_dst_memory_p}});
......
......@@ -114,7 +114,9 @@ class MKLDNNHandlerT {
const std::string key_pd = key_common_ + "@forward_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
return (fwd_pd_ != nullptr);
const std::string key_p = key_ + "@forward_p";
return (dev_ctx_.GetBlob(key_p) != nullptr);
}
template <typename... Args>
......@@ -367,7 +369,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z,
const std::string uniq_name)
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册