Unverified commit 586b5875, authored by Adam, committed by GitHub

Add isCached() check in Softmax handler (#24637)

* Update isCached() to be thread friendly
test=develop

* Add isCached() check inside Softmax handler
test=develop

* Fix PaddleEnforce() message
test=develop
Parent commit: 3cf117db
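The patch revolves around the cached-handler pattern used by the MKL-DNN handlers: the handler key is derived from the input shape, the softmax axis, and the output name, and the expensive primitive-descriptor setup is skipped when a forward primitive for that key is already in the device-context blob cache. Below is a minimal standalone sketch of that pattern; BlobMap, CachingHandlerSketch, and the placeholder structs are illustrative stand-ins for the real MKLDNNDeviceContext blob storage and handler classes, not Paddle APIs.

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for the per-device-context blob cache (the real MKLDNNDeviceContext
// stores type-erased shared_ptr blobs keyed by string).
using BlobMap = std::map<std::string, std::shared_ptr<void>>;

struct FwdPrimitiveDesc {};  // placeholder for softmax_forward::primitive_desc
struct FwdPrimitive {};      // placeholder for the executable forward primitive

class CachingHandlerSketch {
 public:
  CachingHandlerSketch(BlobMap* blobs, std::string key)
      : blobs_(blobs), key_(std::move(key)) {}

  // Mirrors the patched isCached(): the handler counts as cached only when
  // the forward *primitive* blob exists, not merely the primitive descriptor.
  bool isCached() const { return blobs_->count(key_ + "@forward_p") > 0; }

  void AcquireForwardPrimitiveDescriptor() {
    if (!isCached()) {
      // Expensive descriptor construction runs only on a cache miss.
      (*blobs_)[key_ + "@forward_pd"] = std::make_shared<FwdPrimitiveDesc>();
    }
  }

  std::shared_ptr<FwdPrimitive> AcquireForwardPrimitive() {
    const std::string key_p = key_ + "@forward_p";
    auto it = blobs_->find(key_p);
    if (it == blobs_->end()) {
      auto p = std::make_shared<FwdPrimitive>();
      (*blobs_)[key_p] = p;
      return p;
    }
    return std::static_pointer_cast<FwdPrimitive>(it->second);
  }

 private:
  BlobMap* blobs_;
  std::string key_;
};

int main() {
  BlobMap blobs;
  for (int step = 0; step < 2; ++step) {
    CachingHandlerSketch handler(&blobs, "softmax-axis1-Out");
    std::cout << "step " << step << " cached: " << handler.isCached() << "\n";
    handler.AcquireForwardPrimitiveDescriptor();
    handler.AcquireForwardPrimitive();
  }
  return 0;  // prints "cached: 0" on the first step and "cached: 1" on the second
}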
@@ -25,12 +25,12 @@ using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 using paddle::platform::MKLDNNMemDesc;
 
-using mkldnn::memory;  // Note: paddle has also "memory" namespace
-using mkldnn::primitive;
-using mkldnn::prop_kind;
-using mkldnn::softmax_backward;
-using mkldnn::softmax_forward;
-using mkldnn::stream;
+using dnnl::memory;  // Note: paddle has also "memory" namespace
+using dnnl::primitive;
+using dnnl::prop_kind;
+using dnnl::softmax_backward;
+using dnnl::softmax_forward;
+using dnnl::stream;
 using platform::to_void_cast;
 
 template <typename T>
@@ -38,19 +38,30 @@ class SoftmaxMKLDNNHandler
     : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                       mkldnn::softmax_backward> {
  public:
-  SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
-                       const MKLDNNMemoryFormat fmt, const int& axis,
-                       const platform::MKLDNNDeviceContext& dev_ctx,
-                       platform::Place cpu_place, const std::string& uniq_name)
+  SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                       const mkldnn::engine mkldnn_engine,
+                       platform::Place cpu_place, const Tensor* input,
+                       Tensor* output, const int axis,
+                       const std::string uniq_name)
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            dev_ctx, mkldnn_engine, cpu_place,
             // Softmax may be inplace then uniq_name is no longer unique
-            platform::CreateKey(dims, axis, uniq_name)) {
-    auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-
-    this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
-                                            axis);
+            platform::CreateKey(framework::vectorize(input->dims()), axis,
+                                uniq_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          input->dims(), output->dims(),
+          platform::errors::InvalidArgument(
+              "The shape of input and output tensor must be identical."));
+
+      auto softmax_tz = framework::vectorize(input->dims());
+      auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
+                             input->format());
+
+      this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
+                                              axis);
+    }
   }
 
   SoftmaxMKLDNNHandler(const std::vector<int64_t>& dims,
@@ -76,30 +87,25 @@ template <typename T>
 class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(
-        input->dims(), output->dims(),
-        "The shape of softmax's input and output must be identical.");
-
-    auto dims = input->dims();  // input and output share the same shape
-    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
 
-    auto softmax_tz = paddle::framework::vectorize<int64_t>(dims);
-    SoftmaxMKLDNNHandler<T> handler(softmax_tz, input->format(), axis, dev_ctx,
-                                    ctx.GetPlace(), ctx.OutputName("Out"));
+    SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
+                                    input, output, axis, ctx.OutputName("Out"));
 
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
+    auto softmax_p = handler.AcquireForwardPrimitive();
+
     // For Inplace src and and dst are the same memory object
     auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
                                     ? softmax_src_memory_p
                                     : handler.AcquireDstMemory(output);
-    auto softmax_p = handler.AcquireForwardPrimitive();
 
     mkldnn::stream astream(dev_ctx.GetEngine());
     softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
                                  {DNNL_ARG_DST, *softmax_dst_memory_p}});
...
@@ -114,7 +114,9 @@ class MKLDNNHandlerT {
     const std::string key_pd = key_common_ + "@forward_pd";
     fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
         dev_ctx_.GetBlob(key_pd));
-    return (fwd_pd_ != nullptr);
+
+    const std::string key_p = key_ + "@forward_p";
+    return (dev_ctx_.GetBlob(key_p) != nullptr);
   }
 
   template <typename... Args>
@@ -367,7 +369,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
   BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
                       const mkldnn::engine engine, platform::Place cpu_place,
                       const Tensor* x, const Tensor* y, Tensor* z,
-                      const std::string uniq_name)
+                      const std::string& uniq_name)
       : platform::MKLDNNHandlerT<T, dnnl::binary>(
             dev_ctx, engine, cpu_place,
             platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
...