Unverified commit 3448afc1 authored by Paulina Gacek, committed by GitHub

[PHI] Sum op migration (#46239)

* Sum kernel migrated to phi

* Static cast added, file name changed

* OneDNNGetDataType to uppercase

* refactoring

* AddOneDNNHandler changed to SumOneDNNHandler
Parent ffc697ff
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,64 +12,39 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/*Licensed under the Apache License, Version 2.0(the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/lod_tensor_array.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/phi/kernels/add_n_kernel.h"
+
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
-class DenseTensor;
-}  // namespace phi
-
-namespace paddle {
-namespace operators {
-
-using paddle::platform::MKLDNNDeviceContext;
-using phi::CPUContext;
-using platform::to_void_cast;
-using Tensor = framework::Tensor;
-using SelectedRows = phi::SelectedRows;
-using LoDTensor = framework::LoDTensor;
+namespace funcs {
 
 template <typename T>
-class SumMKLDNNHandler
-    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
+class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
  public:
-  SumMKLDNNHandler(dnnl::engine engine,
-                   platform::Place cpu_place,
-                   const std::vector<framework::Variable*>& in_vars,
-                   framework::LoDTensor* z)
-      : platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
+  SumOneDNNHandler(dnnl::engine engine,
+                   const Place& cpu_place,
+                   const std::vector<const TensorBase*>& x,
+                   DenseTensor* out)
+      : OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
         num_inputs_(0) {
-    auto dst_tz = phi::vectorize<int64_t>(z->dims());
+    auto dst_tz = vectorize<int64_t>(out->dims());
     auto src_tz = dst_tz;
 
     std::vector<dnnl::memory::desc> srcs_md;
-    srcs_md.reserve(in_vars.size());
-    for (size_t i = 0; i < in_vars.size(); i++) {
-      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
-      if (input_it.numel() == 0) {
+    srcs_md.reserve(x.size());
+    for (size_t i = 0; i < x.size(); i++) {
+      auto* input_it = (static_cast<const DenseTensor*>(x[i]));
+      if (input_it->numel() == 0) {
         continue;
       }
-      srcs_md.push_back(input_it.mem_desc());
+      srcs_md.push_back(input_it->mem_desc());
      ++num_inputs_;
    }
 
     std::vector<float> scales(num_inputs_, 1.0f);
     auto dst_md = dnnl::memory::desc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+        dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
 
     this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
   }
@@ -84,14 +59,14 @@ class SumMKLDNNHandler
         new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
   }
 
-  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
+  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input,
                                                  int i) {
-    const T* input_data = input.data<T>();
+    const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
                                             to_void_cast<T>(input_data));
   }
 
-  using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
+  using OneDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
 
   std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
@@ -102,73 +77,59 @@ class SumMKLDNNHandler
  private:
   int num_inputs_;
 };
+}  // namespace funcs
 
-template <typename T>
-class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
-                      true,
-                      paddle::platform::errors::PreconditionNotMet(
-                          "Operator DNNL Sum must use CPUPlace"));
-    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-    auto in_vars = ctx.MultiInputVar("X");
-
-    PADDLE_ENFORCE_NE(
-        in_vars.empty(),
-        true,
-        platform::errors::InvalidArgument("Input variable is empty."));
-    auto& input0 = in_vars[0]->Get<LoDTensor>();
-    LoDTensor* output = ctx.Output<LoDTensor>("Out");
-
-    bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);
-
-    SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);
-
-    // Create list of SRC MEMs
-    std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
-    srcs_mem.reserve(handler.GetNumInputs());
-    int input_index = 0;
-    for (size_t i = 0; i < in_vars.size(); i++) {
-      auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
-      if (input_it.numel() == 0) {
-        continue;
-      }
-      srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
-      ++input_index;
-    }
-
-    std::unordered_map<int, dnnl::memory> args;
-    std::shared_ptr<dnnl::memory> dst_mem;
-
-    for (size_t i = 0; i < srcs_mem.size(); ++i) {
-      args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
-    }
-
-    if (in_place) {
-      dst_mem = srcs_mem[0];
-    } else {
-      dst_mem = handler.AcquireDstMemory(output);
-    }
-    args.insert({DNNL_ARG_DST, *dst_mem});
-
-    auto sum_p = handler.AcquireForwardPrimitive();
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    sum_p->execute(astream, args);
-    astream.wait();
-
-    output->set_mem_desc(dst_mem->get_desc());
-  }
-};
-}  // namespace operators
-}  // namespace paddle
+template <typename T, typename Context>
+void AddNKernel(const Context& dev_ctx,
+                const std::vector<const TensorBase*>& x,
+                DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(
+      dev_ctx.GetPlace().GetType() == AllocationType::CPU,
+      true,
+      errors::PreconditionNotMet("oneDNN AddN kernel must use CPUPlace"));
+
+  const auto& onednn_engine = dev_ctx.GetEngine();
+
+  PADDLE_ENFORCE_NE(
+      x.empty(), true, errors::InvalidArgument("Input variable is empty."));
+  auto* input0 = (static_cast<const DenseTensor*>(x[0]));
+
+  bool in_place = (input0->numel() > 0) && input0->IsSharedBufferWith(*out);
+
+  funcs::SumOneDNNHandler<T> handler(onednn_engine, dev_ctx.GetPlace(), x, out);
+
+  // Create list of SRC MEMs
+  std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
+  srcs_mem.reserve(handler.GetNumInputs());
+  int input_index = 0;
+  for (size_t i = 0; i < x.size(); i++) {
+    auto* input_it = (static_cast<const DenseTensor*>(x[i]));
+    if (input_it->numel() == 0) {
+      continue;
+    }
+    srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
+    ++input_index;
+  }
+
+  std::unordered_map<int, dnnl::memory> args;
+
+  for (size_t i = 0; i < srcs_mem.size(); ++i) {
+    args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
+  }
+
+  auto dst_mem = in_place ? srcs_mem[0] : handler.AcquireDstMemory(out);
+
+  args.insert({DNNL_ARG_DST, *dst_mem});
+
+  auto sum_p = handler.AcquireForwardPrimitive();
+
+  auto& astream = OneDNNContext::tls().get_stream();
+  sum_p->execute(astream, args);
+  astream.wait();
 
-REGISTER_OP_KERNEL(
-    sum,
-    MKLDNN,
-    ::paddle::platform::CPUPlace,
-    paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
-    paddle::operators::SumMKLDNNOpKernel<float>);
+  out->set_mem_desc(dst_mem->get_desc());
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    add_n, OneDNN, ALL_LAYOUT, phi::AddNKernel, float, phi::dtype::bfloat16) {}
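
For context on what the handler above delegates to, here is a minimal, self-contained sketch — not part of this commit — of the raw dnnl::sum API that funcs::SumOneDNNHandler wraps. It is written against the oneDNN 2.x C++ API used by the diffed code (note the engine-last primitive_desc constructor, matching AcquireForwardPrimitiveDescriptorNonBlocking above); the tensor shapes, variable names, and two-input setup are illustrative assumptions.

// Illustrative sketch only: sum two 2x3 f32 tensors with unit scales,
// the same configuration AddNKernel sets up (scales of 1.0f per input).
#include <vector>

#include "dnnl.hpp"

int main() {
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream astream(engine);

  // One shared descriptor: 2x3, float32, row-major (nc) layout.
  dnnl::memory::desc md({2, 3},
                        dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nc);

  std::vector<float> a(6, 1.0f), b(6, 2.0f), dst(6, 0.0f);
  dnnl::memory a_mem(md, engine, a.data());
  dnnl::memory b_mem(md, engine, b.data());
  dnnl::memory dst_mem(md, engine, dst.data());

  // dst = 1.0f * a + 1.0f * b; the handler builds this same primitive_desc
  // from (dst_md, scales, srcs_md, engine) in its constructor.
  std::vector<float> scales{1.0f, 1.0f};
  dnnl::sum::primitive_desc sum_pd(md, scales, {md, md}, engine);

  dnnl::sum(sum_pd).execute(astream,
                            {{DNNL_ARG_MULTIPLE_SRC + 0, a_mem},
                             {DNNL_ARG_MULTIPLE_SRC + 1, b_mem},
                             {DNNL_ARG_DST, dst_mem}});
  astream.wait();  // every element of dst is now 3.0f
  return 0;
}

The migrated kernel differs from this sketch only in that it routes memory objects through the cached handler, skips zero-element inputs, and, in the in-place case, reuses the first source memory as the destination instead of acquiring a separate one.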