Unverified · Commit 84cc61b2 · authored by Jacek Czaja · committed by GitHub

[oneDNN] sum op refactor (#28318)

Parent 6f0f45f6
@@ -25,7 +25,7 @@
limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace framework {
@@ -51,6 +51,95 @@ using paddle::platform::CPUDeviceContext;
using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;
template <typename T>
class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
public:
SumMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
const std::vector<framework::Variable*>& in_vars,
framework::LoDTensor* z, const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::sum>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(z->dims()), uniq_name)),
num_inputs_(0) {
for (size_t i = 0; i < in_vars.size(); i++) {
srcs_suffix_.push_back(std::string("-") + std::to_string(i));
}
if (!this->isCached()) {
auto dst_tz = framework::vectorize<int64_t>(z->dims());
auto src_tz = dst_tz;
std::vector<memory::desc> srcs_md;
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
continue;
}
MKLDNNMemoryFormat input_format = input_it.format();
srcs_md.push_back(memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
input_format));
++num_inputs_;
}
std::vector<float> scales(num_inputs_, 1.0);
auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
}
}
// (jczaja) The oneDNN sum primitive has no .desc attribute, so
// we cannot use the base class AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor(
const memory::desc& dst_md, const std::vector<float>& scales,
const std::vector<memory::desc>& srcs_md) {
// Sum op has no backward pass, so nothing needs to be passed from FWD to BWD
const std::string key_pd = this->key_ + "@fwd_pd";
this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
this->dev_ctx_.GetBlob(key_pd));
if (this->fwd_pd_ == nullptr) {
this->fwd_pd_.reset(new mkldnn::sum::primitive_desc(
dst_md, scales, srcs_md, this->engine_));
this->dev_ctx_.SetBlob(key_pd, this->fwd_pd_);
}
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor& input, int i) {
const T* input_data = input.data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
to_void_cast<T>(input_data),
"@src_mem_p" + srcs_suffix_[i]);
}
using platform::MKLDNNHandlerT<T, dnnl::sum>::AcquireDstMemory;
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(),
"@dst_mem_p");
}
inline int GetNumInputs(void) { return num_inputs_; }
protected:
// isCached needs to be overloaded, as the base class version works on key_common
bool isCached() {
const std::string key_pd = this->key_ + "@fwd_pd";
this->fwd_pd_ = std::static_pointer_cast<dnnl::sum::primitive_desc>(
this->dev_ctx_.GetBlob(key_pd));
const std::string key_p = this->key_ + "@fwd_p";
return (this->dev_ctx_.GetBlob(key_p) != nullptr);
}
private:
int num_inputs_;
std::vector<std::string> srcs_suffix_;
};
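Why the caching above pays off: MKLDNNHandlerT keys every blob (primitive, primitive descriptor, memory objects) by a string derived from the output dims and the op's output name, and stores it in the device context, so the expensive descriptor creation runs once per distinct shape. A minimal sketch of that get-or-create pattern, with a plain unordered_map standing in (as an assumption) for the MKLDNNDeviceContext blob storage:

#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the device-context blob storage.
static std::unordered_map<std::string, std::shared_ptr<void>> blob_cache;

template <typename PD, typename Factory>
std::shared_ptr<PD> GetOrCreateBlob(const std::string& key, Factory make) {
  auto it = blob_cache.find(key);
  if (it != blob_cache.end()) {
    // Cache hit: skip the expensive primitive-descriptor construction.
    return std::static_pointer_cast<PD>(it->second);
  }
  auto pd = std::make_shared<PD>(make());
  blob_cache[key] = pd;  // shared_ptr<void> type-erases the stored blob
  return pd;
}

AcquireForwardPrimitiveDescriptor above is this pattern specialized to dnnl::sum::primitive_desc under the "@fwd_pd" key suffix.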
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
@@ -59,85 +148,67 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Sum must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto in_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_NE(in_vars.empty(), true, platform::errors::InvalidArgument(
"Input variable is empty."));
auto& input0 = in_vars[0]->Get<LoDTensor>();
LoDTensor* output = ctx.Output<LoDTensor>("Out");
bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);
SumMKLDNNHandler<T> handler(dev_ctx, ctx.GetPlace(), in_vars, output,
ctx.OutputName("Out"));
// Create list of SRC MEMs
std::vector<std::shared_ptr<mkldnn::memory>> srcs_mem;
srcs_mem.reserve(handler.GetNumInputs());
int input_index = 0;
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
continue;
}
srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
++input_index;
}
auto dst_mem = in_place ? handler.AcquireDstMemory()
: handler.AcquireDstMemory(output);
auto sum_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> args;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
}
args.insert({MKLDNN_ARG_DST, *dst_mem});
mkldnn::stream astream(dev_ctx.GetEngine());
sum_p->execute(astream, args);
astream.wait();
// oneDNN sum has no true in-place execution, so we fake it:
// the result is reordered from the oneDNN dst memory back into the shared input buffer
if (in_place) {
const std::string reorder_key = platform::CreateKey(
framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I");
auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
auto output_tz = framework::vectorize<int64_t>(output->dims());
platform::ReorderMKLDNNHandler reorder_handler(
output_tz, output->type(), framework::ToMKLDNNDataType(in_out.type()),
dev_ctx, dev_ctx.GetEngine(), reorder_key);
auto target_mem = reorder_handler.AcquireDstMemory(
output, in_out.format(), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(target_mem, dst_mem);
reorder_p->execute(astream, *dst_mem, *target_mem);
astream.wait();
}
output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_mem));
}
};
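For reference, the oneDNN sum primitive the kernel ultimately drives can be exercised standalone. A minimal sketch assuming oneDNN v1.x (dnnl.hpp); the MKLDNN_ARG_MULTIPLE_SRC / MKLDNN_ARG_DST macros used above are aliases of the DNNL_ARG_* macros used here:

#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // Two f32 sources of shape {2, 3} and scales of 1.0, i.e. a plain addition.
  auto md = dnnl::memory::desc({2, 3}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nc);
  std::vector<float> a(6, 1.f), b(6, 2.f), out(6, 0.f);
  dnnl::memory src0(md, eng, a.data());
  dnnl::memory src1(md, eng, b.data());

  auto sum_pd = dnnl::sum::primitive_desc(md, {1.f, 1.f}, {md, md}, eng);
  dnnl::memory dst(sum_pd.dst_desc(), eng, out.data());

  // Sum takes a variable number of sources via DNNL_ARG_MULTIPLE_SRC + i.
  std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_MULTIPLE_SRC + 0, src0},
      {DNNL_ARG_MULTIPLE_SRC + 1, src1},
      {DNNL_ARG_DST, dst}};
  dnnl::sum(sum_pd).execute(strm, args);
  strm.wait();  // out now holds {3, 3, 3, 3, 3, 3}
  return 0;
}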
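The in-place branch at the end of Compute exists because, as the comment notes, oneDNN sum has no true in-place mode: results land in a separate dst memory and are then reordered back into the buffer the output shares with the first input. Reduced to its essence (same oneDNN v1.x assumption, names hypothetical):

// scratch_dst: the memory the sum primitive wrote into
// shared_buf:  memory wrapping the buffer that input 0 and the output share
void FakeInPlace(dnnl::stream& strm, dnnl::memory& scratch_dst,
                 dnnl::memory& shared_buf) {
  // reorder is oneDNN's copy / layout-conversion primitive
  dnnl::reorder(scratch_dst, shared_buf).execute(strm, scratch_dst, shared_buf);
  strm.wait();
}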
@@ -591,59 +591,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
}
};
class SumMKLDNNHandler : public MKLDNNHandler {
public:
SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
std::shared_ptr<mkldnn::sum::primitive_desc> AcquireSumPrimitiveDescriptor(
const std::vector<std::shared_ptr<mkldnn::memory>>& src_mems,
const std::vector<float>& scales, const mkldnn::memory::desc& dst_md) {
const std::string key_sum_pd = key_ + "@sum_pd";
sum_pd_ = std::static_pointer_cast<mkldnn::sum::primitive_desc>(
dev_ctx_.GetBlob(key_sum_pd));
if (sum_pd_ == nullptr) {
// Get vector of inputs primitive descriptors
std::vector<mkldnn::memory::desc> src_ds;
for (auto& input_mem : src_mems) {
src_ds.push_back(input_mem->get_desc());
}
sum_pd_.reset(
new mkldnn::sum::primitive_desc(dst_md, scales, src_ds, engine_));
dev_ctx_.SetBlob(key_sum_pd, sum_pd_);
}
return sum_pd_;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(sum_pd_->dst_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src2_mem_p");
}
std::shared_ptr<mkldnn::sum> AcquireSum() {
auto prim_key = key_ + "@sum_p";
auto sum_p =
std::static_pointer_cast<mkldnn::sum>(dev_ctx_.GetBlob(prim_key));
if (sum_p == nullptr) {
sum_p = std::make_shared<mkldnn::sum>(*sum_pd_);
dev_ctx_.SetBlob(prim_key, sum_p);
}
return sum_p;
}
private:
std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
};
template <typename T>
class ActivationMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
@@ -86,4 +86,6 @@ class TestMKLDNNSumInplaceOp(unittest.TestCase):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()