Unverified commit dcf17f48, authored by Adam, committed by GitHub

Add isCached() mechanism to elementwise_add DNNL (#24563)

* Add isCached() mechanism to elementwise_add
test=develop

* Hide code inside handler
test=develop
Parent db0c1ea8
......
@@ -25,8 +25,8 @@ namespace operators {
 using framework::DataLayout;
 using framework::Tensor;
 using mkldnn::memory;
-using mkldnn::reorder;
 using mkldnn::primitive;
+using mkldnn::reorder;
 using mkldnn::stream;
 using mkldnn::sum;
......
@@ -34,51 +34,29 @@ template <typename T>
 class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx =
+    const auto& dev_ctx =
         ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
+    const auto* x = ctx.Input<Tensor>("X");
+    const auto* y = ctx.Input<Tensor>("Y");
     auto* z = ctx.Output<Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(
-        x->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
-    PADDLE_ENFORCE_NE(
-        x->format(), MKLDNNMemoryFormat::undef,
-        platform::errors::InvalidArgument("Wrong format set for X tensor"));
-    PADDLE_ENFORCE_EQ(
-        y->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument("Wrong layout set for Y tensor"));
-    PADDLE_ENFORCE_NE(
-        y->format(), MKLDNNMemoryFormat::undef,
-        platform::errors::InvalidArgument("Wrong format set for Y tensor"));
-
-    auto src_x_tz = framework::vectorize<int64_t>(x->dims());
-    auto src_y_tz = framework::vectorize<int64_t>(y->dims());
-    auto dst_tz = framework::vectorize<int64_t>(z->dims());
-
-    // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
-    // TODO(jczaja): Binary primitive support broadcasting, so we can support
-    // this in kernel
     platform::BinaryMKLDNNHandler<T> handler(
-        dnnl::algorithm::binary_add, src_x_tz, x->format(), y->format(),
-        dev_ctx, ctx.GetPlace(), ctx.OutputName("Out"));
+        dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, ctx.OutputName("Out"));
 
-    auto src_x_memory = handler.AcquireSrcMemory(x);
-    auto src_y_memory = handler.AcquireSecondSrcMemory(y);
+    const auto src_x_memory = handler.AcquireSrcMemory(x);
+    const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
 
     // For Inplace src and dst are the same memory object
-    auto dst_memory =
+    const auto dst_memory =
         x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z);
 
-    auto binary_prim = handler.AcquireForwardPrimitive();
+    const auto binary_prim = handler.AcquireForwardPrimitive();
 
     mkldnn::stream astream(mkldnn_engine);
-    std::unordered_map<int, dnnl::memory> args = {
+    const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC_0, *src_x_memory},
         {DNNL_ARG_SRC_1, *src_y_memory},
         {DNNL_ARG_DST, *dst_memory}};
......
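The rest of this hunk is collapsed by the diff viewer at the point where `binary_prim` would consume `args`. As a self-contained reference for that pattern, here is a minimal binary_add program written against the stock oneDNN (DNNL 1.x) API; the engine, shapes, and buffer values are illustrative assumptions, not taken from the kernel:

#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream astream(eng);

  // Z <- X + Y for two equal-shaped 2x3 f32 tensors.
  const dnnl::memory::dims dims = {2, 3};
  const auto md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
                                     dnnl::memory::format_tag::nc);

  std::vector<float> x(6, 1.f), y(6, 2.f), z(6, 0.f);
  dnnl::memory src_x(md, eng, x.data());
  dnnl::memory src_y(md, eng, y.data());
  dnnl::memory dst(md, eng, z.data());

  // The objects the handler acquires: descriptor, primitive_desc, primitive.
  dnnl::binary::desc bd(dnnl::algorithm::binary_add, md, md, md);
  dnnl::binary::primitive_desc pd(bd, eng);
  dnnl::binary binary_prim(pd);

  // Same shape of argument map as the kernel's `args`, then execute + wait.
  const std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_SRC_0, src_x}, {DNNL_ARG_SRC_1, src_y}, {DNNL_ARG_DST, dst}};
  binary_prim.execute(astream, args);
  astream.wait();
  // Every element of z now holds 3.f.
}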
@@ -27,6 +27,8 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+using framework::DataLayout;
+using framework::Tensor;
 using user_function = std::function<std::shared_ptr<float>(const float*)>;
 using memory = mkldnn::memory;
......
@@ -108,6 +110,13 @@ class MKLDNNHandlerT {
   }
 
  protected:
+  bool isCached() {
+    const std::string key_pd = key_common_ + "@forward_pd";
+    fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
+        dev_ctx_.GetBlob(key_pd));
+
+    return (fwd_pd_ != nullptr);
+  }
+
   template <typename... Args>
   void AcquireForwardPrimitiveDescriptor(Args&&... args) {
     // Forward PD has to be passed to Grad op that
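The cache that isCached() consults is the device context's blob map: a string-keyed store of type-erased shared_ptrs, which is why the lookup needs a static_pointer_cast plus a nullptr check. A minimal sketch of that pattern follows; BlobStore, PrimitiveDesc, and the "key_common" literal are stand-ins, not Paddle types:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for the device context's blob map: string keys, type-erased values.
class BlobStore {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    if (it == blobs_.end()) return nullptr;
    return it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

struct PrimitiveDesc { int id; };  // stand-in for TForward::primitive_desc

int main() {
  BlobStore dev_ctx;
  const std::string key_pd = std::string("key_common") + "@forward_pd";

  // First iteration: the cast yields nullptr, so isCached() would return
  // false and the handler would run validation and create the descriptor.
  auto fwd_pd =
      std::static_pointer_cast<PrimitiveDesc>(dev_ctx.GetBlob(key_pd));
  std::cout << "cached before: " << (fwd_pd != nullptr) << "\n";  // 0
  dev_ctx.SetBlob(key_pd, std::make_shared<PrimitiveDesc>(PrimitiveDesc{1}));

  // Later iterations: the blob is found and reused; the expensive path is
  // skipped entirely.
  fwd_pd = std::static_pointer_cast<PrimitiveDesc>(dev_ctx.GetBlob(key_pd));
  std::cout << "cached after: " << (fwd_pd != nullptr) << "\n";  // 1
}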
......
@@ -355,22 +364,46 @@ class MKLDNNHandler {
 template <typename T>
 class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
  public:
-  BinaryMKLDNNHandler(const dnnl::algorithm algo,
-                      const std::vector<int64_t>& dims,
-                      const MKLDNNMemoryFormat src0_fmt,
-                      const MKLDNNMemoryFormat src1_fmt,
-                      const platform::MKLDNNDeviceContext& dev_ctx,
-                      platform::Place cpu_place, const std::string& uniq_name)
+  BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
+                      const mkldnn::engine engine, platform::Place cpu_place,
+                      const Tensor* x, const Tensor* y, Tensor* z,
+                      const std::string uniq_name)
       : platform::MKLDNNHandlerT<T, dnnl::binary>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            dev_ctx, engine, cpu_place,
+            platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          x->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for X tensor"));
+      PADDLE_ENFORCE_NE(
+          x->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for X tensor"));
+      PADDLE_ENFORCE_EQ(
+          y->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument("Wrong layout set for Y tensor"));
+      PADDLE_ENFORCE_NE(
+          y->format(), MKLDNNMemoryFormat::undef,
+          platform::errors::InvalidArgument("Wrong format set for Y tensor"));
+
+      const auto src_x_tz = framework::vectorize(x->dims());
+      const auto src_y_tz = framework::vectorize(y->dims());
+      const auto dst_tz = framework::vectorize(z->dims());
+
       // TODO(jczaja): Add function checking if data already exists
-    auto src0_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src0_fmt);
-    auto src1_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src1_fmt);
-    auto dst_md =
-        memory::desc(dims, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+      const auto src0_md = dnnl::memory::desc(
+          src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
+      const auto src1_md = dnnl::memory::desc(
+          src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
+      const auto dst_md = memory::desc(
+          dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
 
-    this->AcquireForwardPrimitiveDescriptor(algo, src0_md, src1_md, dst_md);
+      // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
+      // TODO(jczaja): Binary primitive support broadcasting, so we can support
+      // this in kernel
+      this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add,
+                                              src0_md, src1_md, dst_md);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
......
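A note on the design: the cache key behind isCached() is composed from the input dims plus the op's unique output name, so repeated iterations over the same shapes hit one cached primitive descriptor, while a shape change (e.g. a new batch size) yields a new key and re-runs the validation/creation branch. A hypothetical sketch of that composition; CreateKeySketch and the tensor name are illustrative, and platform::CreateKey's real formatting differs:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Toy key builder mimicking the dims + uniq_name composition used above.
std::string CreateKeySketch(const std::vector<int64_t>& dims,
                            const std::string& uniq_name) {
  std::ostringstream key;
  for (int64_t d : dims) key << d << "x";
  key << uniq_name;
  return key.str();
}

int main() {
  const std::string out_name = "elementwise_add_0.tmp_0";  // illustrative
  std::cout << CreateKeySketch({8, 16, 32, 32}, out_name) << "\n";
  // -> 8x16x32x32xelementwise_add_0.tmp_0  (same shape each step: cache hit)
  std::cout << CreateKeySketch({4, 16, 32, 32}, out_name) << "\n";
  // -> 4x16x32x32xelementwise_add_0.tmp_0  (new batch size: new key, miss)
}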