提交 2d955275 编写于 作者: T Tomasz Patejko

Removing WITHIN_CHANNEL algorithm for lrn. CPU lrn operator works only with ACROSS_CHANNELS

上级 c51c4462
......@@ -22,18 +22,6 @@ namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
namespace {
mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) {
mkldnn::algorithm algorithm = mkldnn::lrn_across_channels;
std::string algorithm_str = ctx.Attr<std::string>("algorithm");
if (algorithm_str == "WITHIN_CHANNEL") {
algorithm = mkldnn::lrn_within_channel;
}
return algorithm;
}
} // namespace
template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -64,8 +52,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
auto algorithm = LRNAlgorithm(ctx);
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
......@@ -77,8 +63,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto forward_desc = mkldnn::lrn_forward::desc{
mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k};
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
src_md,
n,
alpha,
beta,
k};
auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
forward_desc, mkldnn_engine);
......@@ -154,10 +145,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
static_cast<void*>(x_grad_data)};
auto algorithm = LRNAlgorithm(ctx);
auto backward_desc = mkldnn::lrn_backward::desc{
algorithm, src_md, diff_src_md, n, alpha, beta, k};
mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
auto forward_pd = dev_ctx.GetBlob(key_pd);
......
......@@ -120,24 +120,24 @@ template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
namespace {
framework::OpKernelType GetExpectedLRNKernel(
const framework::ExecutionContext& ctx) {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::OpKernelType GetExpectedLRNKernel(
const framework::ExecutionContext& ctx) {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_);
}
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_);
}
} // namespace
class LRNOp : public framework::OperatorWithKernel {
public:
......@@ -214,11 +214,6 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<std::string>("algorithm",
"(string default ACROSS_CHANNELS"
"An optional string: \"ACROSS_CHANNELS\", "
"\"WITHIN_CHANNEL\". Used by MKLDNN library")
.SetDefault("ACROSS_CHANNELS");
AddComment(R"DOC(
Local Response Normalization Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册