未验证 提交 3448afc1 编写于 作者: P Paulina Gacek 提交者: GitHub

[PHI] Sum op migration (#46239)

* Sum kernel migrated to phi

* Static cast added, file name changed

* OneDNNGetDataType to uppercase

* refactoring

* AddOneDNNHandler changed to SumOneDNNHandler
上级 ffc697ff
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,64 +12,39 @@
// See the License for the specific language governing permissions and
// limitations under the License.
/*Licensed under the Apache License, Version 2.0(the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace operators {
using paddle::platform::MKLDNNDeviceContext;
using phi::CPUContext;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
namespace funcs {
template <typename T>
class SumMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::sum> {
class SumOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::sum> {
public:
SumMKLDNNHandler(dnnl::engine engine,
platform::Place cpu_place,
const std::vector<framework::Variable*>& in_vars,
framework::LoDTensor* z)
SumOneDNNHandler(dnnl::engine engine,
const Place& cpu_place,
const std::vector<const TensorBase*>& x,
DenseTensor* out)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
: OneDNNHandlerNoCachingT<T, dnnl::sum>(engine, cpu_place),
num_inputs_(0) {
auto dst_tz = phi::vectorize<int64_t>(z->dims());
auto dst_tz = vectorize<int64_t>(out->dims());
auto src_tz = dst_tz;
std::vector<dnnl::memory::desc> srcs_md;
srcs_md.reserve(in_vars.size());
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
srcs_md.reserve(x.size());
for (size_t i = 0; i < x.size(); i++) {
auto* input_it = (static_cast<const DenseTensor*>(x[i]));
if (input_it->numel() == 0) {
continue;
}
srcs_md.push_back(input_it.mem_desc());
srcs_md.push_back(input_it->mem_desc());
++num_inputs_;
}
std::vector<float> scales(num_inputs_, 1.0f);
auto dst_md = dnnl::memory::desc(
dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, scales, srcs_md);
}
......@@ -84,14 +59,14 @@ class SumMKLDNNHandler
new dnnl::sum::primitive_desc(dst_md, scales, srcs_md, this->engine_));
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const framework::Tensor& input,
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor* input,
int i) {
const T* input_data = input.data<T>();
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
to_void_cast<T>(input_data));
}
using platform::MKLDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
using OneDNNHandlerNoCachingT<T, dnnl::sum>::AcquireDstMemory;
std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc());
......@@ -102,37 +77,34 @@ class SumMKLDNNHandler
private:
int num_inputs_;
};
template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
} // namespace funcs
template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
const std::vector<const TensorBase*>& x,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType() == AllocationType::CPU,
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Sum must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto in_vars = ctx.MultiInputVar("X");
errors::PreconditionNotMet("oneDNN AddN kernel must use CPUPlace"));
const auto& onednn_engine = dev_ctx.GetEngine();
PADDLE_ENFORCE_NE(
in_vars.empty(),
true,
platform::errors::InvalidArgument("Input variable is empty."));
auto& input0 = in_vars[0]->Get<LoDTensor>();
LoDTensor* output = ctx.Output<LoDTensor>("Out");
x.empty(), true, errors::InvalidArgument("Input variable is empty."));
auto* input0 = (static_cast<const DenseTensor*>(x[0]));
bool in_place = (input0.numel() > 0) && input0.IsSharedBufferWith(*output);
bool in_place = (input0->numel() > 0) && input0->IsSharedBufferWith(*out);
SumMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), in_vars, output);
funcs::SumOneDNNHandler<T> handler(onednn_engine, dev_ctx.GetPlace(), x, out);
// Create list of SRC MEMs
std::vector<std::shared_ptr<dnnl::memory>> srcs_mem;
srcs_mem.reserve(handler.GetNumInputs());
int input_index = 0;
for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<framework::LoDTensor>();
if (input_it.numel() == 0) {
for (size_t i = 0; i < x.size(); i++) {
auto* input_it = (static_cast<const DenseTensor*>(x[i]));
if (input_it->numel() == 0) {
continue;
}
srcs_mem.push_back(handler.AcquireSrcMemory(input_it, input_index));
......@@ -140,35 +112,24 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::memory> dst_mem;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs_mem[i])});
}
if (in_place) {
dst_mem = srcs_mem[0];
} else {
dst_mem = handler.AcquireDstMemory(output);
}
auto dst_mem = in_place ? srcs_mem[0] : handler.AcquireDstMemory(out);
args.insert({DNNL_ARG_DST, *dst_mem});
auto sum_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto& astream = OneDNNContext::tls().get_stream();
sum_p->execute(astream, args);
astream.wait();
output->set_mem_desc(dst_mem->get_desc());
}
};
} // namespace operators
} // namespace paddle
out->set_mem_desc(dst_mem->get_desc());
}
} // namespace phi
REGISTER_OP_KERNEL(
sum,
MKLDNN,
::paddle::platform::CPUPlace,
paddle::operators::SumMKLDNNOpKernel<paddle::platform::bfloat16>,
paddle::operators::SumMKLDNNOpKernel<float>);
PD_REGISTER_KERNEL(
add_n, OneDNN, ALL_LAYOUT, phi::AddNKernel, float, phi::dtype::bfloat16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册