未验证 提交 1a929c31 编写于 作者: P piotrekobi 提交者: GitHub

[PHI] Migrate cast, clip+grad and pool+grad oneDNN kernels (#45775)

* gaussian random

* mkldnn to onednn renaming

* fix merge conflicts

* remove fluid code

* onednn renaming

* Move classes from mkldnn_reuse.h to onednn_reuse.h

* Migrate pool+grad, clip+grad and cast oneDNN kernels to PHI

* Refactor grad kernels into separate files

* Fix CI failures

* Fix Codestyle

* Implement reviewer suggestions

* Add new lines after includes for readability
Co-authored-by: NSilv3S <slawomir.siwek@intel.com>
上级 23998b75
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
template <typename T>
class CastMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");
auto x_paddle_type = framework::proto::VarType::Type(in_dtype);
auto out_paddle_type = framework::proto::VarType::Type(out_dtype);
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
dnnl::memory::data_type out_type =
framework::ToMKLDNNDataType(out_paddle_type);
auto x_tz = phi::vectorize(x->dims());
platform::ReorderMKLDNNHandler reorder_handler(x_tz,
x_paddle_type,
x_type,
out_paddle_type,
out_type,
dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, x->mem_desc(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_mem_desc(reorder_dst_memory_p->get_desc());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(cast,
MKLDNN,
paddle::platform::CPUPlace,
ops::CastMKLDNNKernel<float>,
ops::CastMKLDNNKernel<paddle::platform::bfloat16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace {
using paddle::framework::Tensor;
template <typename T>
class ClipMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const paddle::framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
paddle::platform::ActivationMKLDNNHandler<T> handler(
dnnl::algorithm::eltwise_clip_v2,
ctx,
mkldnn_engine,
ctx.GetPlace(),
x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_p->execute(
astream,
{{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
};
template <typename T>
class ClipGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx);
}
void RunKernel(const paddle::framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* dx = ctx.Output<Tensor>(paddle::framework::GradVarName("X"));
auto* dout = ctx.Input<Tensor>(paddle::framework::GradVarName("Out"));
paddle::platform::ActivationMKLDNNHandler<T> handler(
dnnl::algorithm::eltwise_clip_v2,
ctx,
mkldnn_engine,
ctx.GetPlace(),
x,
dout);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
dx->set_mem_desc(diff_dst_memory_p->get_desc());
}
};
} // anonymous namespace
REGISTER_OP_KERNEL(clip,
MKLDNN,
paddle::platform::CPUPlace,
ClipMKLDNNKernel<float>,
ClipMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(clip_grad,
MKLDNN,
paddle::platform::CPUPlace,
ClipGradMKLDNNKernel<float>,
ClipGradMKLDNNKernel<paddle::platform::bfloat16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/pooling.h"
namespace paddle {
namespace operators {
using dnnl::memory;
using dnnl::pooling_backward;
using dnnl::pooling_forward;
using dnnl::primitive;
using dnnl::reorder;
using dnnl::stream;
using framework::DataLayout;
using framework::Tensor;
using platform::to_void_cast;
template <typename T>
class PoolingMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T,
dnnl::pooling_forward,
dnnl::pooling_backward> {
public:
PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const dnnl::engine mkldnn_engine,
const Tensor* input,
Tensor* output)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::pooling_forward,
dnnl::pooling_backward>(
mkldnn_engine, ctx.GetPlace()) {
const std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
const bool global_pooling = ctx.Attr<bool>("global_pooling");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
// Only 2D pooling is supported now
PADDLE_ENFORCE_EQ(
ksize.size(),
2,
platform::errors::InvalidArgument(
"The ksize must be 2D, i.e. 2D pooling, but received %dD.",
ksize.size()));
PADDLE_ENFORCE_EQ(
pooling_type == "max" || pooling_type == "avg",
true,
platform::errors::InvalidArgument(
"The pooling_type must be 'max' or 'avg', but received %s.",
pooling_type));
PADDLE_ENFORCE_EQ(
input->dims().size(),
4,
platform::errors::InvalidArgument(
"Input dim must be with 4, i.e. NCHW, but received %d.",
input->dims().size()));
const auto input_dims = input->dims();
framework::DDim data_dims =
phi::slice_ddim(input_dims, 2, input_dims.size());
if (global_pooling) {
phi::funcs::UpdateKernelSize(&ksize, data_dims);
}
phi::funcs::UpdatePadding(&paddings,
global_pooling,
0,
padding_algorithm,
data_dims,
strides,
ksize);
const auto is_test = ctx.Attr<bool>("is_test");
const bool ceil_mode = ctx.Attr<bool>("ceil_mode");
const auto exclude_padding = ctx.Attr<bool>("exclusive");
auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
const auto dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype()));
const auto src_tz = phi::vectorize(input->dims());
const auto dst_tz = phi::vectorize(output->dims());
const auto dst_md =
platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
if (ceil_mode) {
CorrectOutputSize(
src_tz, dst_tz, ksize, paddings, strides, mkldnn_paddings[1]);
}
ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);
this->AcquireForwardPrimitiveDescriptor(
is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training,
pooling_type == "max"
? dnnl::algorithm::pooling_max
: (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding),
input->mem_desc(),
dst_md,
strides,
ksize,
mkldnn_paddings[0],
mkldnn_paddings[1]);
}
PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const dnnl::engine mkldnn_engine,
const Tensor* in_x,
const Tensor* out_grad,
Tensor* in_x_grad)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::pooling_forward,
dnnl::pooling_backward>(
mkldnn_engine, ctx.GetPlace()) {
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"),
false,
platform::errors::InvalidArgument(
"is_test attribute should be set to False in training phase."));
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
auto in_x_dims = in_x->dims();
framework::DDim data_dims = phi::slice_ddim(in_x_dims, 2, in_x_dims.size());
if (global_pooling) {
phi::funcs::UpdateKernelSize(&ksize, data_dims);
}
phi::funcs::UpdatePadding(&paddings,
global_pooling,
0,
padding_algorithm,
data_dims,
strides,
ksize);
auto src_tz = phi::vectorize<int64_t>(in_x->dims());
auto diff_src_tz = phi::vectorize<int64_t>(in_x_grad->dims());
auto diff_dst_tz = phi::vectorize<int64_t>(out_grad->dims());
const auto dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(in_x->dtype()));
auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any);
auto diff_src_md = dnnl::memory::desc(
diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
const bool ceil_mode = ctx.Attr<bool>("ceil_mode");
if (ceil_mode) {
CorrectOutputSize(
src_tz, diff_dst_tz, ksize, paddings, strides, mkldnn_paddings[1]);
}
ComputeAdaptivePoolParameters(ctx, diff_src_tz, &ksize, &strides);
const auto exclude_padding = ctx.Attr<bool>("exclusive");
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training,
pooling_type == "max"
? dnnl::algorithm::pooling_max
: (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding),
in_x->mem_desc(),
dst_md,
strides,
ksize,
mkldnn_paddings[0],
mkldnn_paddings[1]);
this->AcquireBackwardPrimitiveDescriptor(
pooling_type == "max"
? dnnl::algorithm::pooling_max
: (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding),
diff_src_md,
out_grad->mem_desc(),
strides,
ksize,
mkldnn_paddings[0],
mkldnn_paddings[1]);
}
std::shared_ptr<dnnl::memory> AcquireWorkspaceMemory(
const platform::MKLDNNDeviceContext& dev_ctx,
const std::string& unique_name) {
dnnl::memory::desc workspace_md = this->fwd_pd_->workspace_desc();
// Pooling Workspace has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
std::string workspace_key = platform::CreateKey(dev_ctx,
workspace_md.dims(),
workspace_md.data_type(),
unique_name,
"@wrk");
auto mem_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(workspace_key));
if (mem_p == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
mem_p = std::static_pointer_cast<dnnl::memory>(
dev_ctx.GetBlob(workspace_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<dnnl::memory>(workspace_md, this->engine_);
dev_ctx.SetBlob(workspace_key, mem_p);
}
}
return mem_p;
}
static void ComputeAdaptivePoolParameters(
const paddle::framework::ExecutionContext& ctx,
const std::vector<int64_t>& src_tz,
std::vector<int64_t>* ksize,
std::vector<int64_t>* strides) {
if (ctx.Attr<bool>("adaptive")) {
// https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling
auto IH = static_cast<double>(src_tz[src_tz.size() - 2]);
auto IW = static_cast<double>(src_tz[src_tz.size() - 1]);
auto OH = static_cast<double>(ksize->at(0));
auto OW = static_cast<double>(ksize->at(1));
strides->at(0) =
static_cast<int64_t>(floor((IH * 2.0) / OH) - floor(IH / OH));
strides->at(1) =
static_cast<int64_t>(floor((IW * 2.0) / OW) - floor(IW / OW));
ksize->at(0) =
static_cast<int64_t>(ceil((IH * 2.0) / OH) - floor(IH / OH));
ksize->at(1) =
static_cast<int64_t>(ceil((IW * 2.0) / OW) - floor(IW / OW));
}
}
private:
static inline int ComputeCeiledOutput(int input_size,
int kernel_size,
int padding,
int stride) {
return (input_size - kernel_size + 2 * padding) / stride + 1;
}
static inline void CorrectOutputSize(
const std::vector<int64_t>& src_tz,
const std::vector<int64_t>& dst_tz,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& paddings,
const std::vector<int64_t>& strides,
std::vector<int64_t>& right_bot_padding) { // NOLINT
for (size_t i = 0; i < right_bot_padding.size(); i++) {
int desired_size = ComputeCeiledOutput(
src_tz[i + 2], kernel_size[i], paddings[i], strides[i]);
if (desired_size != dst_tz[i + 2]) {
right_bot_padding[i] += strides[i] - 1;
}
}
}
};
template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Pool must use CPUPlace"));
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PoolingMKLDNNHandler<T> handler(ctx, dev_ctx.GetEngine(), input, output);
auto src_memory = handler.AcquireSrcMemory(input);
auto dst_memory = handler.AcquireDstMemory(output);
auto pool_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if ((ctx.Attr<bool>("is_test") == false) &&
(ctx.Attr<std::string>("pooling_type") == "max")) {
// Training
auto workspace_memory =
handler.AcquireWorkspaceMemory(dev_ctx, ctx.OutputName("Out"));
pool_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory},
{DNNL_ARG_WORKSPACE, *workspace_memory}});
} else {
// Inference
pool_p->execute(
astream, {{DNNL_ARG_SRC, *src_memory}, {DNNL_ARG_DST, *dst_memory}});
}
astream.wait();
output->set_mem_desc(dst_memory->get_desc());
}
};
template <typename T>
class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL PoolGrad must use CPUPlace"));
const Tensor* in_x = ctx.Input<Tensor>("X");
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
PoolingMKLDNNHandler<T> handler(
ctx, dev_ctx.GetEngine(), in_x, out_grad, in_x_grad);
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
auto pool_bwd_p = handler.AcquireBackwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (ctx.Attr<std::string>("pooling_type") == "max") {
// Max - pooling needs Workspace
auto workspace_memory =
handler.AcquireWorkspaceMemory(dev_ctx, ctx.InputName("Out"));
pool_bwd_p->execute(astream,
{{DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory},
{DNNL_ARG_WORKSPACE, *workspace_memory}});
} else {
// Average Pooling
pool_bwd_p->execute(astream,
{{DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory}});
}
astream.wait();
in_x_grad->set_mem_desc(diff_src_memory->get_desc());
} // Compute()
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(pool2d,
MKLDNN,
::paddle::platform::CPUPlace,
ops::PoolMKLDNNOpKernel<float>,
ops::PoolMKLDNNOpKernel<int8_t>,
ops::PoolMKLDNNOpKernel<uint8_t>,
ops::PoolMKLDNNOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(pool2d_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::PoolMKLDNNGradOpKernel<float>,
ops::PoolMKLDNNGradOpKernel<paddle::platform::bfloat16>);
......@@ -28,7 +28,7 @@
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(pool2d);
USE_OP_DEVICE_KERNEL(pool2d, MKLDNN);
PD_DECLARE_KERNEL(pool2d, OneDNN, ALL_LAYOUT);
USE_OP_ITSELF(relu);
PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT);
USE_OP_ITSELF(transpose);
......
......@@ -24,9 +24,12 @@ limitations under the License. */
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/onednn/onednn_helper.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
#include "paddle/phi/kernels/funcs/pooling.h"
namespace phi {
namespace funcs {
......@@ -947,5 +950,313 @@ class ReductionOneDNNHandler
algo, x->mem_desc(), out_md, p, eps);
}
};
template <typename T>
class ClipOneDNNHandler
: public OneDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward> {
public:
ClipOneDNNHandler(const Scalar& min,
const Scalar& max,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x)
: OneDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine, cpu_place) {
float alpha = min.to<float>();
float beta = max.to<float>();
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
dnnl::algorithm::eltwise_clip_v2,
x->mem_desc(),
alpha,
beta);
}
ClipOneDNNHandler(const Scalar& min,
const Scalar& max,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* x,
const DenseTensor* dout)
: OneDNNHandlerNoCachingT<T,
dnnl::eltwise_forward,
dnnl::eltwise_backward>(engine, cpu_place) {
float alpha = min.to<float>();
float beta = max.to<float>();
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
dnnl::algorithm::eltwise_clip_v2,
x->mem_desc(),
alpha,
beta);
this->AcquireBackwardPrimitiveDescriptor(dnnl::algorithm::eltwise_clip_v2,
dout->mem_desc(),
x->mem_desc(),
alpha,
beta);
}
std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
const DenseTensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
to_void_cast<T>(input_data));
}
};
template <typename T>
class PoolingOneDNNHandler
: public OneDNNHandlerNoCachingT<T,
dnnl::pooling_forward,
dnnl::pooling_backward> {
public:
PoolingOneDNNHandler(const std::string& pooling_type,
const IntArray& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool global_pooling,
const std::string& padding_algorithm,
bool ceil_mode,
bool exclusive,
bool adaptive,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* input,
DenseTensor* output)
: OneDNNHandlerNoCachingT<T,
dnnl::pooling_forward,
dnnl::pooling_backward>(engine, cpu_place) {
std::vector<int64_t> copied_kernel_size(kernel_size.GetData().begin(),
kernel_size.GetData().end());
std::vector<int64_t> copied_strides(strides.begin(), strides.end());
std::vector<int64_t> copied_paddings(paddings.begin(), paddings.end());
// Only 2D pooling is supported now
PADDLE_ENFORCE_EQ(
copied_kernel_size.size(),
2,
errors::InvalidArgument("The copied_kernel_size must be 2D, i.e. 2D "
"pooling, but received %dD.",
copied_kernel_size.size()));
PADDLE_ENFORCE_EQ(
pooling_type == "max" || pooling_type == "avg",
true,
errors::InvalidArgument(
"The pooling_type must be 'max' or 'avg', but received %s.",
pooling_type));
PADDLE_ENFORCE_EQ(
input->dims().size(),
4,
errors::InvalidArgument(
"Input dim must be with 4, i.e. NCHW, but received %d.",
input->dims().size()));
const auto input_dims = input->dims();
DDim data_dims = slice_ddim(input_dims, 2, input_dims.size());
if (global_pooling) {
UpdateKernelSize<int64_t>(&copied_kernel_size, data_dims);
}
UpdatePadding<int64_t>(&copied_paddings,
global_pooling,
0,
padding_algorithm,
data_dims,
copied_strides,
copied_kernel_size);
auto onednn_paddings = ToOneDNNPadding(copied_paddings);
const auto dt = ToOneDNNDataType(input->dtype());
const auto src_tz = vectorize(input->dims());
const auto dst_tz = vectorize(output->dims());
const auto dst_md = OneDNNMemDesc(dst_tz, dt, OneDNNMemoryFormat::any);
if (ceil_mode) {
CorrectOutputSize(src_tz,
dst_tz,
copied_kernel_size,
copied_paddings,
copied_strides,
onednn_paddings[1]);
}
if (adaptive) {
ComputeAdaptivePoolParameters(
src_tz, &copied_kernel_size, &copied_strides);
}
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training,
pooling_type == "max"
? dnnl::algorithm::pooling_max
: (exclusive ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding),
input->mem_desc(),
dst_md,
copied_strides,
copied_kernel_size,
onednn_paddings[0],
onednn_paddings[1]);
}
PoolingOneDNNHandler(const std::string& pooling_type,
const IntArray& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool global_pooling,
const std::string& padding_algorithm,
bool ceil_mode,
bool exclusive,
bool adaptive,
const dnnl::engine engine,
Place cpu_place,
const DenseTensor* in_x,
const DenseTensor* out_grad,
DenseTensor* in_x_grad)
: OneDNNHandlerNoCachingT<T,
dnnl::pooling_forward,
dnnl::pooling_backward>(engine, cpu_place) {
std::vector<int64_t> copied_kernel_size(kernel_size.GetData().begin(),
kernel_size.GetData().end());
std::vector<int64_t> copied_strides(strides.begin(), strides.end());
std::vector<int64_t> copied_paddings(paddings.begin(), paddings.end());
auto in_x_dims = in_x->dims();
DDim data_dims = slice_ddim(in_x_dims, 2, in_x_dims.size());
if (global_pooling) {
UpdateKernelSize<int64_t>(&copied_kernel_size, data_dims);
}
UpdatePadding<int64_t>(&copied_paddings,
global_pooling,
0,
padding_algorithm,
data_dims,
copied_strides,
copied_kernel_size);
auto src_tz = vectorize<int64_t>(in_x->dims());
auto diff_src_tz = vectorize<int64_t>(in_x_grad->dims());
auto diff_dst_tz = vectorize<int64_t>(out_grad->dims());
const auto dt = ToOneDNNDataType(in_x->dtype());
auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, OneDNNMemoryFormat::any);
auto diff_src_md = dnnl::memory::desc(
diff_src_tz, oneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
auto onednn_paddings = ToOneDNNPadding(copied_paddings);
if (ceil_mode) {
CorrectOutputSize(src_tz,
diff_dst_tz,
copied_kernel_size,
copied_paddings,
copied_strides,
onednn_paddings[1]);
}
if (adaptive) {
ComputeAdaptivePoolParameters(
diff_src_tz, &copied_kernel_size, &copied_strides);
}
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training,
pooling_type == "max"
? dnnl::algorithm::pooling_max
: (exclusive ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding),
in_x->mem_desc(),
dst_md,
copied_strides,
copied_kernel_size,
onednn_paddings[0],
onednn_paddings[1]);
this->AcquireBackwardPrimitiveDescriptor(
pooling_type == "max"
? dnnl::algorithm::pooling_max
: (exclusive ? dnnl::algorithm::pooling_avg_exclude_padding
: dnnl::algorithm::pooling_avg_include_padding),
diff_src_md,
out_grad->mem_desc(),
copied_strides,
copied_kernel_size,
onednn_paddings[0],
onednn_paddings[1]);
}
std::shared_ptr<dnnl::memory> AcquireWorkspaceMemory(
const OneDNNContext& dev_ctx, const std::string& unique_name) {
dnnl::memory::desc workspace_md = this->fwd_pd_->workspace_desc();
// Pooling Workspace has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
std::string workspace_key = CreateKey(dev_ctx,
workspace_md.dims(),
workspace_md.data_type(),
unique_name,
"@wrk");
auto mem_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(workspace_key));
if (mem_p == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
mem_p = std::static_pointer_cast<dnnl::memory>(
dev_ctx.GetBlob(workspace_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<dnnl::memory>(workspace_md, this->engine_);
dev_ctx.SetBlob(workspace_key, mem_p);
}
}
return mem_p;
}
static void ComputeAdaptivePoolParameters(const std::vector<int64_t>& src_tz,
std::vector<int64_t>* kernel_size,
std::vector<int64_t>* strides) {
// https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling
auto IH = static_cast<double>(src_tz[src_tz.size() - 2]);
auto IW = static_cast<double>(src_tz[src_tz.size() - 1]);
auto OH = static_cast<double>(kernel_size->at(0));
auto OW = static_cast<double>(kernel_size->at(1));
strides->at(0) =
static_cast<int64_t>(floor((IH * 2.0) / OH) - floor(IH / OH));
strides->at(1) =
static_cast<int64_t>(floor((IW * 2.0) / OW) - floor(IW / OW));
kernel_size->at(0) =
static_cast<int64_t>(ceil((IH * 2.0) / OH) - floor(IH / OH));
kernel_size->at(1) =
static_cast<int64_t>(ceil((IW * 2.0) / OW) - floor(IW / OW));
}
private:
static inline int ComputeCeiledOutput(int input_size,
int kernel_size,
int padding,
int stride) {
return (input_size - kernel_size + 2 * padding) / stride + 1;
}
static inline void CorrectOutputSize(
const std::vector<int64_t>& src_tz,
const std::vector<int64_t>& dst_tz,
const std::vector<int64_t>& kernel_size,
const std::vector<int64_t>& paddings,
const std::vector<int64_t>& strides,
std::vector<int64_t>& right_bot_padding) { // NOLINT
for (size_t i = 0; i < right_bot_padding.size(); i++) {
int desired_size = ComputeCeiledOutput(
src_tz[i + 2], kernel_size[i], paddings[i], strides[i]);
if (desired_size != dst_tz[i + 2]) {
right_bot_padding[i] += strides[i] - 1;
}
}
}
};
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
DataType in_dtype = x.dtype();
dnnl::memory::data_type in_dnnl_dtype = funcs::ToOneDNNDataType(in_dtype);
dnnl::memory::data_type out_dnnl_dtype = funcs::ToOneDNNDataType(out_dtype);
auto x_tz = phi::vectorize(x.dims());
funcs::ReorderOneDNNHandler reorder_handler(x_tz,
in_dtype,
in_dnnl_dtype,
out_dtype,
out_dnnl_dtype,
dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>()));
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x.mem_desc(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
out->set_layout(DataLayout::ONEDNN);
out->set_mem_desc(reorder_dst_memory_p->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(
cast, OneDNN, ALL_LAYOUT, phi::CastKernel, float, phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ClipGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& min,
const Scalar& max,
DenseTensor* x_grad) {
const auto& onednn_engine = dev_ctx.GetEngine();
funcs::ClipOneDNNHandler<T> handler(
min, max, onednn_engine, dev_ctx.GetPlace(), &x, &out_grad);
auto src_memory_p = handler.AcquireBackwardSrcMemory(&x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&out_grad);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(x_grad);
auto activation_backward_p = handler.AcquireBackwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
x_grad->set_mem_desc(diff_dst_memory_p->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(clip_grad,
OneDNN,
ALL_LAYOUT,
phi::ClipGradKernel,
float,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ClipKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& min,
const Scalar& max,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();
funcs::ClipOneDNNHandler<T> handler(
min, max, onednn_engine, dev_ctx.GetPlace(), &x);
auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
activation_p->execute(
astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(
clip, OneDNN, ALL_LAYOUT, phi::ClipKernel, float, phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pool_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void Pool2dGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
const IntArray& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool ceil_mode,
bool exclusive,
const std::string& data_format,
const std::string& pooling_type,
bool global_pooling,
bool adaptive,
const std::string& padding_algorithm,
DenseTensor* dx) {
funcs::PoolingOneDNNHandler<T> handler(pooling_type,
kernel_size,
strides,
paddings,
global_pooling,
padding_algorithm,
ceil_mode,
exclusive,
adaptive,
dev_ctx.GetEngine(),
dev_ctx.GetPlace(),
&x,
&dout,
dx);
auto diff_dst_memory = handler.AcquireDiffDstMemory(&dout);
auto diff_src_memory = handler.AcquireDiffSrcMemory(dx);
auto pool_bwd_p = handler.AcquireBackwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
if (pooling_type == "max") {
// Max - pooling needs Workspace
auto workspace_memory = handler.AcquireWorkspaceMemory(dev_ctx, "Out");
pool_bwd_p->execute(astream,
{{DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory},
{DNNL_ARG_WORKSPACE, *workspace_memory}});
} else {
// Average Pooling
pool_bwd_p->execute(astream,
{{DNNL_ARG_DIFF_SRC, *diff_src_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory}});
}
astream.wait();
dx->set_mem_desc(diff_src_memory->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(pool2d_grad,
OneDNN,
ALL_LAYOUT,
phi::Pool2dGradKernel,
float,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pool_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void Pool2dKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& kernel_size,
const std::vector<int>& strides,
const std::vector<int>& paddings,
bool ceil_mode,
bool exclusive,
const std::string& data_format,
const std::string& pooling_type,
bool global_pooling,
bool adaptive,
const std::string& padding_algorithm,
DenseTensor* out) {
funcs::PoolingOneDNNHandler<T> handler(pooling_type,
kernel_size,
strides,
paddings,
global_pooling,
padding_algorithm,
ceil_mode,
exclusive,
adaptive,
dev_ctx.GetEngine(),
dev_ctx.GetPlace(),
&x,
out);
auto src_memory = handler.AcquireSrcMemory(&x);
auto dst_memory = handler.AcquireDstMemory(out);
auto pool_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
if (pooling_type == "max") {
// Training
auto workspace_memory = handler.AcquireWorkspaceMemory(dev_ctx, "Out");
pool_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory},
{DNNL_ARG_WORKSPACE, *workspace_memory}});
} else {
// Inference
pool_p->execute(astream,
{{DNNL_ARG_SRC, *src_memory}, {DNNL_ARG_DST, *dst_memory}});
}
astream.wait();
out->set_mem_desc(dst_memory->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(pool2d,
OneDNN,
ALL_LAYOUT,
phi::Pool2dKernel,
float,
int8_t,
uint8_t,
phi::dtype::bfloat16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册