提交 c2efdfd5 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Extending reusing to Elementwise_add_mkldnn op (#18146)

* - Reusing of reorder used in elementwise_add_mkldnn

- Added MKL-DNN sum prim reusing

test=develop

- Compilation fixes

test=develop

- Yet another compilation fix

test=develop

- Yet another compilation fix

test=develop

- Yet another linking fix

test=develop

- Final compilation fix

test=develop

- lint fixes

test=develop

- Lint fixes

test=develop

* - Fixes after review

test=develop
上级 9047ac68
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
......@@ -65,21 +65,27 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
(src_x_tz.size() == 5 &&
x->format() != (format = memory::format::ncdhw))) {
_x.Resize(x_dims);
auto user_x_memory_pd = memory::primitive_desc(
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
auto x_memory_pd = memory::primitive_desc(
{{src_x_tz}, memory::data_type::f32, format}, mkldnn_engine);
auto size = x_memory_pd.get_size();
_x.mutable_data<T>(ctx.GetPlace(), size);
auto user_x_memory =
memory(user_x_memory_pd, paddle::platform::to_void_cast<T>(x_data));
auto x_memory = memory(x_memory_pd,
paddle::platform::to_void_cast<T>(_x.data<T>()));
auto x_reorder = reorder(user_x_memory, x_memory);
mkldnn::memory::data_type in_type = platform::MKLDNNGetDataType<T>();
auto out_format = platform::MKLDNNFormatForSize(
x_dims.size(), mkldnn::memory::format::nchw);
const std::string key = platform::ReorderMKLDNNHandler::GetHash(
src_x_tz, x->format(), out_format, std::to_string(in_type));
platform::ReorderMKLDNNHandler handler(src_x_tz, x->type(), in_type,
dev_ctx, mkldnn_engine, key);
auto user_x_memory_p = handler.AcquireSrcMemory(
x->format(), paddle::platform::to_void_cast(x_data));
auto x_memory_p =
handler.AcquireDstMemory(&_x, out_format, ctx.GetPlace());
auto x_reorder = handler.AcquireReorder(x_memory_p, user_x_memory_p);
std::vector<primitive> pipeline;
pipeline.push_back(x_reorder);
pipeline.push_back(*x_reorder);
stream(stream::kind::eager).submit(pipeline).wait();
} else {
format = x->format();
......@@ -125,46 +131,41 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int> dst_tz = framework::vectorize2int(z_dims);
std::vector<memory::primitive_desc> srcs_pd;
std::vector<memory> srcs;
std::vector<float> scales = {1.0f, 1.0f};
auto src_x_pd = memory::primitive_desc(
{{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine);
auto src_y_pd = memory::primitive_desc(
{{src_y_tz}, memory::data_type::f32, y->format()}, mkldnn_engine);
auto src_x_memory =
memory(src_x_pd, paddle::platform::to_void_cast(x_data));
auto src_y_memory =
memory(src_y_pd, paddle::platform::to_void_cast(y_data));
const std::string key = platform::MKLDNNHandler::GetHash(
src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) +
std::to_string(y->format()));
platform::SumMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
auto src_x_memory = handler.AcquireSrcMemory(
{{src_x_tz}, platform::MKLDNNGetDataType<T>(), x->format()},
paddle::platform::to_void_cast(x_data));
srcs_pd.push_back(src_x_pd);
srcs_pd.push_back(src_y_pd);
srcs.push_back(src_x_memory);
srcs.push_back(src_y_memory);
auto src_y_memory = handler.AcquireSecondSrcMemory(
{{src_y_tz}, platform::MKLDNNGetDataType<T>(), y->format()},
paddle::platform::to_void_cast(y_data));
auto dst_md =
memory::desc({dst_tz}, memory::data_type::f32, memory::format::any);
auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType<T>(),
memory::format::any);
// create primitive descriptor for sum
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
{src_x_memory, src_y_memory}, scales, dst_md);
// create mkldnn memory for dst
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
std::vector<primitive::at> inputs;
inputs.push_back(srcs[0]);
inputs.push_back(srcs[1]);
std::vector<primitive::at> inputs({*src_x_memory, *src_y_memory});
// create sum primitive
auto sum_prim = sum(sum_pd, inputs, dst_memory);
auto sum_prim = handler.AcquireSum(dst_memory, &inputs);
std::vector<primitive> pipeline;
pipeline.push_back(sum_prim);
pipeline.push_back(*sum_prim);
stream(stream::kind::eager).submit(pipeline).wait();
z->set_layout(DataLayout::kMKLDNN);
z->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
(memory::format)dst_memory->get_primitive_desc().desc().data.format);
}
}
};
......
......@@ -45,6 +45,11 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_src_mem_p");
}
// Acquires the user-side memory object for a second source operand.
// Forwards `md` and `ptr` to AcquireMemory with the "@user_src2_mem_p" suffix,
// which keeps it distinct from the first source handled by AcquireSrcMemory
// (suffix "@user_src_mem_p").
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
    const mkldnn::memory::desc& md, void* ptr) {
  return AcquireMemory(md, ptr, "@user_src2_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
......@@ -265,6 +270,55 @@ class MKLDNNHandler {
static constexpr int MaxKeyLength = 256;
};
class SumMKLDNNHandler : public MKLDNNHandler {
public:
SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
std::shared_ptr<mkldnn::sum::primitive_desc> AcquireSumPrimitiveDescriptor(
const std::vector<std::shared_ptr<mkldnn::memory>>& src_mems,
const std::vector<float>& scales, const mkldnn::memory::desc& dst_md) {
const std::string key_sum_pd = key_ + "@sum_pd";
sum_pd_ = std::static_pointer_cast<mkldnn::sum::primitive_desc>(
dev_ctx_.GetBlob(key_sum_pd));
if (sum_pd_ == nullptr) {
// Get vector of inputs primitive descriptors
std::vector<mkldnn::memory::primitive_desc> src_pds;
for (auto& input_mem : src_mems) {
src_pds.push_back(input_mem->get_primitive_desc());
}
sum_pd_.reset(new mkldnn::sum::primitive_desc(dst_md, scales, src_pds));
dev_ctx_.SetBlob(key_sum_pd, sum_pd_);
}
return sum_pd_;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(sum_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::sum> AcquireSum(
std::shared_ptr<mkldnn::memory> dst_memory,
std::vector<mkldnn::primitive::at>* inputs) {
auto prim_key = key_ + "@sum_p";
auto sum_p =
std::static_pointer_cast<mkldnn::sum>(dev_ctx_.GetBlob(prim_key));
if (sum_p == nullptr) {
sum_p = std::make_shared<mkldnn::sum>(*(sum_pd_), *inputs, *(dst_memory));
dev_ctx_.SetBlob(prim_key, sum_p);
}
return sum_p;
}
private:
std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
};
class TransposeMKLDNNHandler : public MKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册