Commit 2d955275 authored by: Tomasz Patejko

Removing WITHIN_CHANNEL algorithm for lrn. The CPU lrn operator works only with ACROSS_CHANNELS.

Parent: c51c4462
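For context on the distinction the commit message draws: ACROSS_CHANNELS LRN normalizes each value over a window of adjacent channels at the same spatial position, while WITHIN_CHANNEL LRN normalizes over a spatial neighbourhood inside a single channel. The following is a minimal illustrative sketch of the across-channels formula, not code from this patch; note that whether alpha is pre-divided by the window size varies between frameworks.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference across-channels LRN on an NCHW tensor:
//   mid = k + alpha * sum of x^2 over a window of `size` adjacent channels
//   out = x / mid^beta
// (WITHIN_CHANNEL would instead sum x^2 over a size-by-size spatial
// neighbourhood inside one channel.)
void LRNAcrossChannels(const std::vector<float>& x, std::vector<float>* out,
                       int N, int C, int H, int W, int size, float alpha,
                       float beta, float k) {
  auto at = [&](int n, int c, int h, int w) {
    return ((n * C + c) * H + h) * W + w;
  };
  const int half = size / 2;
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w) {
          float sum = 0.0f;
          // Sum squares over neighbouring channels, clamped at the edges.
          for (int cc = std::max(0, c - half);
               cc <= std::min(C - 1, c + half); ++cc) {
            const float v = x[at(n, cc, h, w)];
            sum += v * v;
          }
          (*out)[at(n, c, h, w)] =
              x[at(n, c, h, w)] / std::pow(k + alpha * sum, beta);
        }
}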
@@ -22,18 +22,6 @@ namespace operators {
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
-namespace {
-mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) {
-  mkldnn::algorithm algorithm = mkldnn::lrn_across_channels;
-
-  std::string algorithm_str = ctx.Attr<std::string>("algorithm");
-  if (algorithm_str == "WITHIN_CHANNEL") {
-    algorithm = mkldnn::lrn_within_channel;
-  }
-
-  return algorithm;
-}
-}  // namespace
-
 template <typename T>
 class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -64,8 +52,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
 
-    auto algorithm = LRNAlgorithm(ctx);
-
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
     e_mid = e_mid.constant(k);
@@ -77,8 +63,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dst_md = paddle::platform::MKLDNNMemDesc(
         dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
 
-    auto forward_desc = mkldnn::lrn_forward::desc{
-        mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k};
+    auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
+                                                  mkldnn::lrn_across_channels,
+                                                  src_md,
+                                                  n,
+                                                  alpha,
+                                                  beta,
+                                                  k};
 
     auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
         forward_desc, mkldnn_engine);
@@ -154,10 +145,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
                                           static_cast<void*>(x_grad_data)};
 
-    auto algorithm = LRNAlgorithm(ctx);
-
     auto backward_desc = mkldnn::lrn_backward::desc{
-        algorithm, src_md, diff_src_md, n, alpha, beta, k};
+        mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
 
     auto forward_pd = dev_ctx.GetBlob(key_pd);
...
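The hunks above hardcode mkldnn::lrn_across_channels where the removed LRNAlgorithm helper used to be consulted. For readers unfamiliar with the MKL-DNN 0.x C++ API used here, below is a hypothetical standalone sketch of building and running the same forward primitive; the dimensions and hyperparameter values are invented for illustration, and inference mode is used so no workspace memory is needed.

#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::cpu, 0);

  // One 8-channel 4x4 image in NCHW layout.
  memory::dims dims = {1, 8, 4, 4};
  auto md = memory::desc(dims, memory::data_type::f32, memory::format::nchw);

  std::vector<float> src(1 * 8 * 4 * 4, 1.0f), dst(src.size());
  auto src_mem = memory({md, cpu_engine}, src.data());
  auto dst_mem = memory({md, cpu_engine}, dst.data());

  const int n = 5;  // channel window size
  const float alpha = 1e-4f, beta = 0.75f, k = 1.0f;

  // The algorithm is fixed to across-channels, as in the patch.
  auto desc = lrn_forward::desc(prop_kind::forward_inference,
                                lrn_across_channels, md, n, alpha, beta, k);
  auto pd = lrn_forward::primitive_desc(desc, cpu_engine);

  auto lrn = lrn_forward(pd, src_mem, dst_mem);
  stream(stream::kind::eager).submit({lrn}).wait();
  return 0;
}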
@@ -120,24 +120,24 @@ template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
 template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
 
 namespace {
 framework::OpKernelType GetExpectedLRNKernel(
     const framework::ExecutionContext& ctx) {
   framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
       platform::CanMKLDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kMKLDNN;
   }
 #endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
       layout_, library_);
 }
-}
+}  // namespace
 
 class LRNOp : public framework::OperatorWithKernel {
  public:
@@ -214,11 +214,6 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
         "Defaults to \"NHWC\". Specify the data format of the output data, "
         "the input will be transformed automatically. ")
         .SetDefault("AnyLayout");
-    AddAttr<std::string>("algorithm",
-                         "(string default ACROSS_CHANNELS"
-                         "An optional string: \"ACROSS_CHANNELS\", "
-                         "\"WITHIN_CHANNEL\". Used by MKLDNN library")
-        .SetDefault("ACROSS_CHANNELS");
     AddComment(R"DOC(
 Local Response Normalization Operator.
...
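One detail worth noting in the grad kernel above: mkldnn::lrn_backward must be validated against the primitive descriptor of the matching forward op, which the forward kernel stores in the device context and the backward kernel fetches with dev_ctx.GetBlob(key_pd). Below is a minimal, hypothetical sketch of that type-erased blob cache; the class and method names are invented for illustration, not PaddlePaddle's actual implementation.

#include <map>
#include <memory>
#include <string>

// Sketch of the blob-cache pattern used by the MKLDNN kernels: the forward
// pass stores its primitive_desc under a key derived from the op's inputs,
// and the backward pass fetches it to construct lrn_backward.
class DeviceContextBlobMap {
 public:
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

// Usage corresponding to the diff (illustrative):
//   forward:  cache.SetBlob(key_pd, forward_pd);  // shared_ptr to the
//                                                 // lrn_forward::primitive_desc
//   backward: auto pd = std::static_pointer_cast<
//                 mkldnn::lrn_forward::primitive_desc>(cache.GetBlob(key_pd));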