未验证 提交 3d59fee5 编写于 作者: P Piotr Paturej 提交者: GitHub

[PHI] Migrate concat+grad, expand+grad, fill_constant, nearest_interp and...

[PHI] Migrate concat+grad, expand+grad, fill_constant, nearest_interp and bilinear_interp oneDNN kernels (#45863)

* Migrate concat+grad, expand+grad, fill_constant, nearest_interp_v2 and bilinear_interp_v2 oneDNN kernels to PHI

* Remove old namespace variable

* Fix invalid out dims error

* Add mutable_data method to concat output

* Add check for -1 dim before computing out_dims

* Capitalize oneDNNGetDataType function name

* Change fill_constant kernel to correct PHI kernel

* Attempt to fix dims error

* Fix fill_constant (full) kernel
上级 c9a7a3bc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using dnnl::concat;
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
using framework::DataLayout;
using framework::LoDTensor;
using framework::Tensor;
using platform::to_void_cast;
template <typename T>
class ConcatMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::concat> {
public:
ConcatMKLDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine mkldnn_engine,
const std::vector<const Tensor*>& inputs,
Tensor* output)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::concat>(mkldnn_engine,
ctx.GetPlace()) {
int concat_axis = ctx.Attr<int>("axis");
const int rank = inputs[0]->dims().size();
PADDLE_ENFORCE_EQ(
concat_axis >= -rank && concat_axis < rank,
true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
concat_axis));
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
concat_axis = GetDataFromTensor(axis_tensor)[0];
auto out_dims = inputs[0]->dims();
for (size_t i = 1; i < inputs.size(); ++i) {
out_dims[concat_axis] += inputs[i]->dims()[concat_axis];
}
output->Resize(out_dims);
}
if (concat_axis < 0) {
concat_axis = concat_axis + rank;
}
memory::data_type dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(inputs[0]->dtype()));
std::vector<memory::desc> srcs_md;
srcs_md.reserve(inputs.size());
// Create memory descriptors for each of inputs
for (size_t i = 0; i < inputs.size(); ++i) {
srcs_md.push_back(inputs[i]->mem_desc());
}
auto dst_dims = phi::vectorize<int64_t>(output->dims());
dnnl::memory::desc dst_md =
memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
}
// (jczaja) concat oneDNN prim is not having .desc attribute so
// we cannot use base AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor(
const memory::desc& dst_md,
const int concat_axis,
const std::vector<memory::desc>& srcs_md) {
this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
dst_md, concat_axis, srcs_md, this->engine_));
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const Tensor& input, int i) {
const T* input_data = input.data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
to_void_cast<T>(input_data));
}
};
static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
for (auto* input : inputs) {
PADDLE_ENFORCE_EQ(
input->layout(),
DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for Input tensor"));
}
}
// From a multi-input, gather only nonempty inputs
static const std::vector<const Tensor*> ReduceMultiInput(
const std::vector<const Tensor*>& inputs) {
std::vector<const Tensor*> reduced(inputs.size());
auto end_it = std::copy_if(
inputs.begin(), inputs.end(), reduced.begin(), [](const Tensor* t) {
return t->numel() > 0;
});
reduced.resize(std::distance(reduced.begin(), end_it));
return reduced;
}
template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
// If any of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
ConcatMKLDNNHandler<T> handler(ctx, mkldnn_engine, multi_input, output);
std::vector<std::shared_ptr<memory>> srcs;
srcs.reserve(multi_input.size());
auto dst_mem = handler.AcquireDstMemory(output);
auto concat_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
std::unordered_map<int, memory> args;
for (size_t i = 0; i < multi_input.size(); ++i) {
srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
}
args.insert({DNNL_ARG_DST, *dst_mem});
concat_p->execute(astream, args);
astream.wait();
output->set_mem_desc(dst_mem->get_desc());
}
};
template <typename T>
class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
const auto x = ctx.MultiInput<LoDTensor>("X");
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto dx = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
for (size_t i = 0; i < dx.size(); ++i) {
if (dx[i] != nullptr) {
dx[i]->set_lod(x[i]->lod());
}
}
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
axis = GetDataFromTensor<int>(axis_tensor)[0];
}
auto dout_vec_dims = phi::vectorize(dout->dims());
axis = ComputeAxis(axis, dout_vec_dims.size());
std::vector<int64_t> offset(dout_vec_dims.size(), 0);
dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims,
framework::TransToProtoVarType(dout->dtype()),
dout_type,
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
for (size_t i = 0; i < dx.size(); ++i) {
if (out_var_names[i] != framework::kEmptyVarName &&
dx[i]->numel() != 0UL) {
auto dx_vec_dims = phi::vectorize(dx[i]->dims());
auto slice_mem_p = reorder_handler.AcquireSubmemory(
dx_vec_dims, offset, reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx[i],
dx_vec_dims,
platform::GetPlainMKLDNNFormat(dx_vec_dims.size()),
ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
offset[axis] += dx[i]->dims()[axis];
dx[i]->set_mem_desc(reorder_dst_memory_p->get_desc());
}
}
astream.wait();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(concat,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConcatMKLDNNOpKernel<float>,
ops::ConcatMKLDNNOpKernel<paddle::platform::bfloat16>,
ops::ConcatMKLDNNOpKernel<int8_t>,
ops::ConcatMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL(concat_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::ConcatGradMKLDNNOpKernel<float>,
ops::ConcatGradMKLDNNOpKernel<paddle::platform::bfloat16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/expand_v2_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace {
using paddle::framework::ExecutionContext;
using paddle::framework::GradVarName;
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using phi::vectorize;
template <typename T>
class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto x_vec_dims = vectorize(x->dims());
auto out_new_dims = paddle::operators::get_expand_shape(ctx);
for (size_t i = 0; i < out_new_dims.size(); ++i) {
out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i];
}
if (x_vec_dims.size() != out_new_dims.size()) {
x_vec_dims = GetExtendedXDims(x_vec_dims, out_new_dims.size());
}
out->Resize(phi::make_ddim(out_new_dims));
paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add,
onednn_engine,
ctx.GetPlace(),
x,
out,
0.0f,
1.0f,
x_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireZeroedDstMemory(out);
auto binary_p = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *dst_memory_p},
{DNNL_ARG_SRC_1, *src_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = MKLDNNDeviceContext::tls().get_stream();
binary_p->execute(astream, args);
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
private:
std::vector<int64_t> GetExtendedXDims(const std::vector<int64_t>& x_vec_dims,
int new_size) const {
std::vector<int64_t> extended_x_dims(new_size, 1);
std::copy(x_vec_dims.begin(),
x_vec_dims.end(),
extended_x_dims.begin() + new_size - x_vec_dims.size());
return extended_x_dims;
}
};
template <typename T>
class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(GradVarName("X"));
auto dx_vec_dims = vectorize(dx->dims());
auto dout_vec_dims = vectorize(dout->dims());
if (dx_vec_dims.size() != dout_vec_dims.size()) {
dx_vec_dims.insert(
dx_vec_dims.begin(), dout_vec_dims.size() - dx_vec_dims.size(), 1);
}
auto& astream = MKLDNNDeviceContext::tls().get_stream();
if (dout_vec_dims == dx_vec_dims) {
dnnl::memory::data_type dout_type = paddle::framework::ToMKLDNNDataType(
paddle::framework::TransToProtoVarType(dout->dtype()));
paddle::platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims,
paddle::framework::TransToProtoVarType(dout->dtype()),
dout_type,
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), paddle::platform::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx,
paddle::platform::GetPlainMKLDNNFormat(dx_vec_dims.size()),
ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_mem_desc(reorder_dst_memory_p->get_desc());
} else {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
dout,
dx,
dx_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(dout);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_layout(paddle::framework::DataLayout::kMKLDNN);
dx->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(dx->dims())));
}
}
};
} // anonymous namespace
REGISTER_OP_KERNEL(expand_v2,
MKLDNN,
paddle::platform::CPUPlace,
ExpandMKLDNNKernel<float>,
ExpandMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(expand_v2_grad,
MKLDNN,
paddle::platform::CPUPlace,
ExpandGradMKLDNNKernel<float>,
ExpandGradMKLDNNKernel<paddle::platform::bfloat16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
class FillConstantMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
FillConstantMKLDNNHandler(Tensor* out,
dnnl::engine engine,
platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_md =
dnnl::memory::desc({out->numel(), sizeof(T)},
platform::MKLDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attrs;
attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f});
this->AcquireForwardPrimitiveDescriptor(
attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md);
}
static const dnnl::memory::desc src1_md;
};
template <typename T>
const dnnl::memory::desc FillConstantMKLDNNHandler<T>::src1_md(
{1, sizeof(T)},
platform::MKLDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab);
template <typename T>
class FillConstantMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine();
auto* out = ctx.Output<Tensor>("Out");
T fill_value = CalculateFillValue(ctx);
auto shape = GetShape(ctx);
out->Resize(shape);
FillConstantMKLDNNHandler<T> handler(out, dnnl_engine, ctx.GetPlace());
dnnl::memory constant_value_memory =
dnnl::memory(FillConstantMKLDNNHandler<T>::src1_md,
dnnl_engine,
reinterpret_cast<uint8_t*>(&fill_value));
auto src0_memory_p = handler.AcquireDstMemory(out);
auto fill_constant_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
fill_constant_p->execute(astream,
{{DNNL_ARG_SRC_0, *src0_memory_p},
{DNNL_ARG_SRC_1, constant_value_memory},
{DNNL_ARG_DST, *src0_memory_p}});
astream.wait();
// src0_memory_p's md was just to allow the usage of a binary
// primitive as a memset, and now we need to create a real one
out->set_mem_desc({phi::vectorize(shape),
platform::MKLDNNGetDataType<T>(),
platform::GetPlainMKLDNNFormat(shape.size())});
}
T CalculateFillValue(const framework::ExecutionContext& ctx) const {
const auto str_value = ctx.Attr<std::string>("str_value");
const auto float_value = ctx.Attr<float>("value");
T value;
if (str_value.empty()) {
value = static_cast<T>(float_value);
} else {
// handle NaN/Inf first, which cannot be read from stream
if (str_value == "inf") {
value = static_cast<T>(std::numeric_limits<float>::infinity());
} else if (str_value == "-inf") {
value = static_cast<T>(-std::numeric_limits<float>::infinity());
} else if (str_value == "nan") {
value = static_cast<T>(std::numeric_limits<float>::quiet_NaN());
} else {
std::stringstream convert_stream(str_value);
double tmp_value;
convert_stream >> tmp_value;
value = static_cast<T>(tmp_value);
}
}
if (ctx.HasInput("ValueTensor")) {
const auto* value_tensor = ctx.Input<Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ(
value_tensor->numel(),
1,
platform::errors::InvalidArgument(
"When use Tensor as value to set Tensor value in fill_constant, "
"value input(ValueTensor) size must be 1, but got %d",
value_tensor->numel()));
value = value_tensor->data<T>()[0];
}
return value;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(fill_constant,
MKLDNN,
paddle::platform::CPUPlace,
ops::FillConstantMKLDNNKernel<float>);
......@@ -181,15 +181,3 @@ REGISTER_OP_KERNEL(bilinear_interp,
MKLDNN,
::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
REGISTER_OP_KERNEL(nearest_interp_v2,
MKLDNN,
::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>,
ops::InterpolateMKLDNNKernel<paddle::platform::bfloat16>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(bilinear_interp_v2,
MKLDNN,
::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/concat_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi {
template <typename T, typename Context>
void ConcatGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const DenseTensor& out_grad,
const Scalar& axis_scalar,
std::vector<DenseTensor*> x_grad) {
const auto& onednn_engine = dev_ctx.GetEngine();
auto& astream = OneDNNContext::tls().get_stream();
for (size_t i = 0; i < x_grad.size(); ++i) {
if (x_grad[i] != nullptr) {
x_grad[i]->set_lod(x[i]->lod());
}
}
int axis = axis_scalar.to<int>();
auto out_grad_vec_dims = vectorize(out_grad.dims());
axis = funcs::ComputeAxis(axis, out_grad_vec_dims.size());
std::vector<int64_t> offset(out_grad_vec_dims.size(), 0);
dnnl::memory::data_type out_grad_type =
funcs::ToOneDNNDataType(out_grad.dtype());
funcs::ReorderOneDNNHandler reorder_handler(
out_grad_vec_dims, out_grad.dtype(), out_grad_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
out_grad.mem_desc(), funcs::to_void_cast(out_grad.data<T>()));
for (size_t i = 0; i < x_grad.size(); ++i) {
if (x_grad[i]->numel() != 0UL) {
auto x_grad_vec_dims = vectorize(x_grad[i]->dims());
auto slice_mem_p = reorder_handler.AcquireSubmemory(
x_grad_vec_dims, offset, reorder_src_memory_p);
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
x_grad[i],
x_grad_vec_dims,
funcs::GetPlainOneDNNFormat(x_grad_vec_dims.size()),
dev_ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
offset[axis] += x_grad[i]->dims()[axis];
x_grad[i]->set_mem_desc(reorder_dst_memory_p->get_desc());
}
}
astream.wait();
}
} // namespace phi
PD_REGISTER_KERNEL(concat_grad,
OneDNN,
ALL_LAYOUT,
phi::ConcatGradKernel,
float,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi {
using memory = dnnl::memory;
namespace funcs {
template <typename T>
class ConcatOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::concat> {
public:
ConcatOneDNNHandler(Place cpu_place,
int concat_axis,
const dnnl::engine onednn_engine,
const std::vector<const DenseTensor*>& inputs,
DenseTensor* output)
: OneDNNHandlerNoCachingT<T, dnnl::concat>(onednn_engine, cpu_place) {
const int rank = inputs[0]->dims().size();
PADDLE_ENFORCE_EQ(
concat_axis >= -rank && concat_axis < rank,
true,
errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
concat_axis));
if (concat_axis < 0) {
concat_axis = concat_axis + rank;
}
memory::data_type dt = ToOneDNNDataType(inputs[0]->dtype());
std::vector<memory::desc> srcs_md;
srcs_md.reserve(inputs.size());
// Create memory descriptors for each of inputs
for (size_t i = 0; i < inputs.size(); ++i) {
srcs_md.push_back(inputs[i]->mem_desc());
}
auto dst_dims = vectorize<int64_t>(output->dims());
memory::desc dst_md = memory::desc(dst_dims, dt, OneDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(dst_md, concat_axis, srcs_md);
}
// (jczaja) concat oneDNN prim is not having .desc attribute so
// we cannot use base AcquireForwardPrimitiveDescriptor
void AcquireForwardPrimitiveDescriptor(
const memory::desc& dst_md,
const int concat_axis,
const std::vector<memory::desc>& srcs_md) {
this->fwd_pd_.reset(new dnnl::concat::primitive_desc(
dst_md, concat_axis, srcs_md, this->engine_));
}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const DenseTensor& input,
int i) {
const T* input_data = input.data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src_desc(i),
to_void_cast<T>(input_data));
}
};
} // namespace funcs
static void EnforceLayouts(const std::vector<const DenseTensor*> inputs) {
for (auto* input : inputs) {
PADDLE_ENFORCE_EQ(
input->layout(),
DataLayout::ONEDNN,
errors::InvalidArgument("Wrong layout set for Input tensor"));
}
}
// From a multi-input, gather only nonempty inputs
static const std::vector<const DenseTensor*> ReduceMultiInput(
const std::vector<const DenseTensor*>& inputs) {
std::vector<const DenseTensor*> reduced(inputs.size());
auto end_it = std::copy_if(
inputs.begin(), inputs.end(), reduced.begin(), [](const DenseTensor* t) {
return t->numel() > 0;
});
reduced.resize(std::distance(reduced.begin(), end_it));
return reduced;
}
template <typename T, typename Context>
void ConcatKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();
// If any of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
auto multi_input = ReduceMultiInput(x);
EnforceLayouts(multi_input);
auto out_dims_vec = vectorize(out->dims());
if (std::any_of(out_dims_vec.begin(), out_dims_vec.end(), [](int64_t i) {
return i < 0;
})) {
std::vector<phi::DDim> x_dims;
x_dims.reserve(x.size());
for (size_t i = 0; i < x.size(); ++i) {
x_dims.push_back(x[i]->dims());
}
DDim out_dims =
funcs::ComputeAndCheckShape(true, x_dims, axis.to<size_t>());
out->Resize(out_dims);
}
funcs::ConcatOneDNNHandler<T> handler(
dev_ctx.GetPlace(), axis.to<int>(), onednn_engine, multi_input, out);
std::vector<std::shared_ptr<memory>> srcs;
srcs.reserve(multi_input.size());
auto dst_mem = handler.AcquireDstMemory(out);
auto concat_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
std::unordered_map<int, memory> args;
for (size_t i = 0; i < multi_input.size(); ++i) {
srcs.push_back(handler.AcquireSrcMemory(*(multi_input[i]), i));
args.insert({DNNL_ARG_MULTIPLE_SRC + i, *(srcs.at(i))});
}
args.insert({DNNL_ARG_DST, *dst_mem});
concat_p->execute(astream, args);
astream.wait();
out->set_mem_desc(dst_mem->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(concat,
OneDNN,
ALL_LAYOUT,
phi::ConcatKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ExpandGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& shape,
DenseTensor* in_grad) {
const auto& onednn_engine = dev_ctx.GetEngine();
auto in_grad_vec_dims = vectorize(in_grad->dims());
auto out_grad_vec_dims = vectorize(out_grad.dims());
if (in_grad_vec_dims.size() != out_grad_vec_dims.size()) {
in_grad_vec_dims.insert(in_grad_vec_dims.begin(),
out_grad_vec_dims.size() - in_grad_vec_dims.size(),
1);
}
auto& astream = OneDNNContext::tls().get_stream();
if (out_grad_vec_dims == in_grad_vec_dims) {
dnnl::memory::data_type out_grad_type =
funcs::ToOneDNNDataType(out_grad.dtype());
funcs::ReorderOneDNNHandler reorder_handler(
out_grad_vec_dims, out_grad.dtype(), out_grad_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
out_grad.mem_desc(), funcs::to_void_cast(out_grad.data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
in_grad,
funcs::GetPlainOneDNNFormat(in_grad_vec_dims.size()),
dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
in_grad->set_mem_desc(reorder_dst_memory_p->get_desc());
} else {
funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
onednn_engine,
dev_ctx.GetPlace(),
&out_grad,
in_grad,
in_grad_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(&out_grad);
auto dst_memory_p = handler.AcquireDstMemory(in_grad);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
in_grad->set_layout(DataLayout::ONEDNN);
in_grad->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(in_grad->dims())));
}
}
} // namespace phi
PD_REGISTER_KERNEL(expand_grad,
OneDNN,
ALL_LAYOUT,
phi::ExpandGradKernel,
float,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
std::vector<int64_t> GetExtendedXDims(const std::vector<int64_t>& x_vec_dims,
int new_size) {
std::vector<int64_t> extended_x_dims(new_size, 1);
std::copy(x_vec_dims.begin(),
x_vec_dims.end(),
extended_x_dims.begin() + new_size - x_vec_dims.size());
return extended_x_dims;
}
template <typename T, typename Context>
void ExpandKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();
auto x_vec_dims = vectorize(x.dims());
auto out_new_dims = shape.GetData();
for (size_t i = 0; i < out_new_dims.size(); ++i) {
out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i];
}
if (x_vec_dims.size() != out_new_dims.size()) {
x_vec_dims = GetExtendedXDims(x_vec_dims, out_new_dims.size());
}
out->Resize(make_ddim(out_new_dims));
funcs::BroadcastDataOneDNNHandler<T> handler(dnnl::algorithm::binary_add,
onednn_engine,
dev_ctx.GetPlace(),
&x,
out,
0.0f,
1.0f,
x_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireZeroedDstMemory(out);
auto binary_p = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *dst_memory_p},
{DNNL_ARG_SRC_1, *src_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
binary_p->execute(astream, args);
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(expand,
OneDNN,
ALL_LAYOUT,
phi::ExpandKernel,
float,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace funcs {
template <typename T>
class FillConstantOneDNNHandler
: public OneDNNHandlerNoCachingT<T, dnnl::binary> {
public:
FillConstantOneDNNHandler(DenseTensor* out,
dnnl::engine engine,
Place cpu_place)
: OneDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
const auto src0_md = dnnl::memory::desc({out->numel(), sizeof(T)},
OneDNNGetDataType<uint8_t>(),
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attrs;
attrs.set_scales(DNNL_ARG_SRC_0, /* mask = */ 0, {0.0f});
this->AcquireForwardPrimitiveDescriptor(
attrs, dnnl::algorithm::binary_add, src0_md, src1_md, src0_md);
}
static const dnnl::memory::desc src1_md;
};
template <typename T>
const dnnl::memory::desc FillConstantOneDNNHandler<T>::src1_md(
{1, sizeof(T)}, OneDNNGetDataType<uint8_t>(), dnnl::memory::format_tag::ab);
} // namespace funcs
template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
const IntArray& shape,
const Scalar& val,
DataType dtype,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();
T fill_value = val.to<T>();
out->Resize(make_ddim(shape.GetData()));
funcs::FillConstantOneDNNHandler<T> handler(
out, onednn_engine, dev_ctx.GetPlace());
dnnl::memory constant_value_memory =
dnnl::memory(funcs::FillConstantOneDNNHandler<T>::src1_md,
onednn_engine,
reinterpret_cast<uint8_t*>(&fill_value));
auto src0_memory_p = handler.AcquireDstMemory(out);
auto fill_constant_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
fill_constant_p->execute(astream,
{{DNNL_ARG_SRC_0, *src0_memory_p},
{DNNL_ARG_SRC_1, constant_value_memory},
{DNNL_ARG_DST, *src0_memory_p}});
astream.wait();
// src0_memory_p's md was just to allow the usage of a binary
// primitive as a memset, and now we need to create a real one
out->set_mem_desc({vectorize(out->dims()),
funcs::OneDNNGetDataType<T>(),
funcs::GetPlainOneDNNFormat(out->dims().size())});
}
} // namespace phi
PD_REGISTER_KERNEL(full, OneDNN, ALL_LAYOUT, phi::FullKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/interpolate_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/interpolate_function.h"
namespace phi {
namespace funcs {
template <typename T = float>
class InterpolateOneDNNHandler
: public OneDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
public:
InterpolateOneDNNHandler(const dnnl::algorithm algo,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x,
DenseTensor* out)
: OneDNNHandlerNoCachingT<T, dnnl::resampling_forward>(engine,
cpu_place) {
const auto dst_tz = vectorize(out->dims());
const auto dst_md = dnnl::memory::desc(
dst_tz, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, algo, x->mem_desc(), dst_md);
}
};
} // namespace funcs
std::vector<int> ComputeOutputShape(
const DenseTensor* x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale_attr) {
const auto& in_dims = x->dims();
const DDim in_dhw_dims = slice_ddim(in_dims, 2, in_dims.size());
std::vector<int> out_dims;
out_dims.reserve(5);
if (in_dhw_dims.size() == 1) {
out_dims.push_back(out_w);
} else if (in_dhw_dims.size() == 2) {
out_dims.push_back(out_h);
out_dims.push_back(out_w);
} else if (in_dhw_dims.size() == 3) {
out_dims.push_back(out_d);
out_dims.push_back(out_h);
out_dims.push_back(out_w);
}
if (size_tensor && size_tensor.get().size() > 0) {
auto new_size = funcs::get_new_shape(size_tensor.get());
if (new_size.size() == out_dims.size()) {
out_dims = new_size;
}
} else if (out_size) {
auto out_size_data =
funcs::get_new_data_from_tensor<int>(out_size.get_ptr());
if (out_size_data.size() == out_dims.size()) {
out_dims = out_size_data;
}
} else {
std::vector<float> scale;
scale.reserve(3);
if (scale_tensor) {
auto scale_data =
funcs::get_new_data_from_tensor<float>(scale_tensor.get_ptr());
scale.resize(3, scale_data[0]);
std::copy(scale_data.begin(), scale_data.end(), scale.begin());
} else {
if (scale_attr.size() > 0) {
scale.resize(3, scale_attr[0]);
std::copy(scale_attr.begin(), scale_attr.end(), scale.begin());
}
}
if (scale.size() == 3 && scale[0] > 0.0f && scale[1] > 0.0f &&
scale[2] > 0.0f) {
int j = 0;
std::vector<int64_t> in_dhw_vec = vectorize(in_dhw_dims);
std::transform(
in_dhw_vec.begin(),
in_dhw_vec.end(),
out_dims.begin(),
[&](int64_t i) -> int { return static_cast<int>(i * scale[j++]); });
}
}
PADDLE_ENFORCE_GT(
std::all_of(
out_dims.begin(), out_dims.end(), [](int i) { return i > 0; }),
0,
errors::InvalidArgument("out_d, out_h, out_w of Op(interpolate) "
"should be greater than 0."));
const std::vector<int64_t> nc_dims = {in_dims[0], in_dims[1]};
out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end());
return out_dims;
}
template <typename T, typename Context>
void InterpolateKernel(
const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();
const dnnl::algorithm algo = (interp_method == "nearest")
? dnnl::algorithm::resampling_nearest
: dnnl::algorithm::resampling_linear;
const auto out_dims_vec = ComputeOutputShape(&x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale);
DDim dim_out = make_ddim(out_dims_vec);
out->Resize(dim_out);
funcs::InterpolateOneDNNHandler<T> handler(
algo, onednn_engine, dev_ctx.GetPlace(), &x, out);
auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto resampling_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
resampling_prim->execute(astream, args);
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
template <typename T, typename Context>
void BilinearInterpKernel(
const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
output);
}
template <typename T, typename Context>
void NearestInterpKernel(
const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& out_size,
const paddle::optional<std::vector<const DenseTensor*>>& size_tensor,
const paddle::optional<DenseTensor>& scale_tensor,
const std::string& data_layout,
int out_d,
int out_h,
int out_w,
const std::vector<float>& scale,
const std::string& interp_method,
bool align_corners,
int align_mode,
DenseTensor* output) {
InterpolateKernel<T, Context>(ctx,
x,
out_size,
size_tensor,
scale_tensor,
data_layout,
out_d,
out_h,
out_w,
scale,
interp_method,
output);
}
} // namespace phi
PD_REGISTER_KERNEL(
bilinear_interp, OneDNN, ALL_LAYOUT, phi::BilinearInterpKernel, float) {}
PD_REGISTER_KERNEL(nearest_interp,
OneDNN,
ALL_LAYOUT,
phi::NearestInterpKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册