Commit cd43c444 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] LRN and Pool2d (FWD) NHWC support (#21375)

Parent add62acf
@@ -127,13 +127,17 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
       "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
       "non-MKLDNN");
-  innerTransDataLayoutFromMKLDNN(in_layout, out_layout, in, out, place);
+#ifdef PADDLE_WITH_MKLDNN
+  innerTransDataLayoutFromMKLDNN(in_layout,
+                                 paddle::platform::get_cur_paddle_data_layout(),
+                                 in, out, place);
+#endif
 }
 
-#ifdef PADDLE_WITH_MKLDNN
 void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
                                     const Tensor& in, Tensor* out,
                                     platform::Place place) {
+#ifdef PADDLE_WITH_MKLDNN
   PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::format_undef,
                     platform::errors::InvalidArgument(
                         "Input tensor format is invalid. Input tensor should "
@@ -185,11 +189,17 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
   } else {
     out->ShareDataWith(in);
   }
+  // For an expected NHWC data format we need to reshape the output tensor,
+  // as the MKL-DNN description was in NCHW while Paddle expects NHWC
+  if (out_layout == DataLayout::kNHWC) {
+    std::rotate(out_tz.begin() + 1, out_tz.begin() + 2, out_tz.end());
+    out->Resize(framework::make_ddim(out_tz));
+  }
   out->set_layout(out_layout);
   // reset format since the out tensor will be fed to a non-MKLDNN OP kernel
   out->set_format(MKLDNNMemoryFormat::format_undef);
+#endif
 }
-#endif
 
 }  // namespace framework
 }  // namespace paddle
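The std::rotate above is the entire NCHW-to-NHWC shape fix: MKL-DNN always describes dims in NCHW order, so rotating the channel entry past the spatial ones converts the shape to what Paddle's NHWC consumers expect. A minimal standalone sketch (the 2x3x7x7 shape is a made-up example, not from the patch):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Dims as MKL-DNN reports them: N, C, H, W.
  std::vector<int> out_tz = {2, 3, 7, 7};
  // Move C past H and W, yielding N, H, W, C (the call in the hunk above).
  std::rotate(out_tz.begin() + 1, out_tz.begin() + 2, out_tz.end());
  assert((out_tz == std::vector<int>{2, 7, 7, 3}));
  return 0;
}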
@@ -66,16 +66,15 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
   return MKLDNNDataType::data_undef;
 }
 
-void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
-                                    const Tensor& in, Tensor* out,
-                                    platform::Place place);
 #endif
 
 void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                                const OpKernelType& expected_kernel_type,
                                const Tensor& in, Tensor* out);
 
+void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
+                                    const Tensor& in, Tensor* out,
+                                    platform::Place place);
+
 std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
 
 void TransDataLayout(const OpKernelType& kernel_type_for_var,
......
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type_transform.h"
 
 #ifdef PADDLE_WITH_MKLDNN
+#include <algorithm>
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -54,8 +55,16 @@ void TransformData(const OpKernelType &expected_kernel_type,
         auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
                                                         ToMKLDNNFormat(lin));
         out.ShareDataWith(input_tensor);
+        // For NHWC data we need to reshape the tensor, since MKL-DNN
+        // expects its dims described in NCHW order
+        if (lin == DataLayout::kNHWC) {
+          auto nchw_dims = paddle::framework::vectorize<int>(out.dims());
+          std::rotate(nchw_dims.begin() + 1, nchw_dims.end() - 1,
+                      nchw_dims.end());
+          out.Resize(framework::make_ddim(nchw_dims));
+          paddle::platform::set_cur_paddle_data_layout(lin);
+        }
         out.set_layout(DataLayout::kMKLDNN);
         out.set_format(out_format);
 #endif
......
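TransformData performs the inverse move when data enters an MKL-DNN kernel: dims declared as NHWC have the trailing C rotated back next to N, and the thread-local layout is recorded so outputs can be converted back later. A standalone sketch of just the dims step (same made-up 2x7x7x3 shape):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Dims as Paddle declared them: N, H, W, C.
  std::vector<int> nchw_dims = {2, 7, 7, 3};
  // Rotate the last element (C) into position 1, yielding N, C, H, W.
  std::rotate(nchw_dims.begin() + 1, nchw_dims.end() - 1, nchw_dims.end());
  assert((nchw_dims == std::vector<int>{2, 3, 7, 7}));
  return 0;
}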
@@ -103,6 +103,7 @@ Executor::~Executor() {
     platform::MKLDNNDeviceContext* dev_ctx =
         (platform::MKLDNNDeviceContext*)pool.Get(place_);
     dev_ctx->ResetBlobMap();
+    platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
   }
 #endif
 }
......
@@ -470,6 +470,11 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }
 
+  bool IsMKLDNNType() const {
+    return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
+                                     framework::DataLayout::kMKLDNN));
+  }
+
   bool SupportGPU() const override {
     auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
     return std::any_of(op_kernels.begin(), op_kernels.end(),
......
@@ -56,16 +56,20 @@ class FetchOp : public framework::OperatorBase {
     // FIXME(yuyang18): Should we assume the fetch operator always generate
     // CPU outputs?
     if (src_item.IsInitialized() && src_item.numel() > 0) {
+#ifdef PADDLE_WITH_MKLDNN
       // Conversion from MKL-DNN to Paddle
       if (src_item.layout() == framework::DataLayout::kMKLDNN) {
         framework::Tensor out;
         framework::innerTransDataLayoutFromMKLDNN(
-            src_item.layout(), framework::DataLayout::kNCHW, src_item, &out,
-            platform::CPUPlace());
+            src_item.layout(), paddle::platform::get_cur_paddle_data_layout(),
+            src_item, &out, platform::CPUPlace());
         TensorCopySync(out, platform::CPUPlace(), &dst_item);
       } else {
         TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
       }
+#else
+      TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
+#endif
     } else {
       // Not copy, if the src tensor is empty.
       dst_item.clear();
......
@@ -193,12 +193,6 @@ class LRNOp : public framework::OperatorWithKernel {
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
         platform::CanMKLDNNBeUsed(ctx)) {
-      // TODO(jczaja): Add support for NHWC
-      const std::string data_format = ctx.Attr<std::string>("data_format");
-      PADDLE_ENFORCE_NE(
-          data_format, "NHWC",
-          platform::errors::Unimplemented(
-              "LRN MKLDNN does not support NHWC data format yet"));
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
@@ -207,6 +201,28 @@ class LRNOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
         layout_, library_);
   }
 
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+#ifdef PADDLE_WITH_MKLDNN
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+      auto attrs = Attrs();
+      auto ar = paddle::framework::AttrReader(attrs);
+      const std::string data_format = ar.Get<std::string>("data_format");
+      auto dl = framework::StringToDataLayout(data_format);
+      // Some models may have intentionally set "AnyLayout" for the lrn
+      // op. Treat this as NCHW (the default data_format value)
+      if (dl != framework::DataLayout::kAnyLayout) {
+        return framework::OpKernelType(expected_kernel_type.data_type_,
+                                       tensor.place(), dl);
+      }
+    }
+#endif
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
 };
 
 template <typename T>
......
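The GetKernelTypeForVar override added here (and repeated for PoolOp further down) is easiest to read stripped of the framework types: when an MKL-DNN kernel is about to consume a plain Paddle tensor, the op's declared data_format wins over the tensor's recorded layout, except for AnyLayout, which keeps the default behaviour. A condensed, framework-free sketch; all names below are illustrative stand-ins, not the Paddle API:

#include <string>

enum class Layout { kNHWC, kNCHW, kAnyLayout, kMKLDNN };

// Stand-in for framework::StringToDataLayout.
Layout FromString(const std::string& s) {
  if (s == "NHWC") return Layout::kNHWC;
  if (s == "NCHW") return Layout::kNCHW;
  return Layout::kAnyLayout;
}

// Which layout should the data-transform machinery assume for an input?
Layout LayoutForVar(Layout expected, Layout tensor_layout,
                    const std::string& data_format) {
  if (expected == Layout::kMKLDNN && tensor_layout != Layout::kMKLDNN) {
    Layout dl = FromString(data_format);
    // "AnyLayout" falls through and keeps the default NCHW path.
    if (dl != Layout::kAnyLayout) return dl;
  }
  return tensor_layout;
}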
@@ -102,6 +102,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     pipeline.push_back(*reorder_p);
     stream(stream::kind::eager).submit(pipeline).wait();
 
+    output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory));
   }
 };
......
@@ -62,6 +62,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::lrn_forward> lrn_p;
     if (is_test == false) {
       workspace_memory = handler.AcquireWorkspaceMemory(mid);
+      mid->set_layout(framework::DataLayout::kMKLDNN);
+      mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
       lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory,
                                               *dst_memory);
     } else {
......
@@ -88,7 +88,10 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
       ksize.size(), strides.size(), framework::make_ddim(ksize),
       framework::make_ddim(strides));
 
-  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+  // MKL-DNN kernels describe dims in NCHW order,
+  // so we ignore data_format for the MKL-DNN kernel
+  const bool channel_last = (this->IsMKLDNNType() == false) &&
+                            (data_format == "NHWC" || data_format == "NDHWC");
 
   // update paddings if "SAME" or global_pooling
   framework::DDim data_dims;
@@ -146,12 +149,6 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
       platform::CanMKLDNNBeUsed(ctx)) {
-    // TODO(jczaja): Add support for NHWC
-    const std::string data_format = ctx.Attr<std::string>("data_format");
-    PADDLE_ENFORCE_NE(
-        data_format, "NHWC",
-        platform::errors::Unimplemented(
-            "Pool MKLDNN grad does not support NHWC data format yet"));
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
   }
@@ -162,6 +159,28 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
       layout_, library_);
 }
 
+framework::OpKernelType PoolOp::GetKernelTypeForVar(
+    const std::string& var_name, const Tensor& tensor,
+    const framework::OpKernelType& expected_kernel_type) const {
+#ifdef PADDLE_WITH_MKLDNN
+  if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+    auto attrs = Attrs();
+    auto ar = paddle::framework::AttrReader(attrs);
+    const std::string data_format = ar.Get<std::string>("data_format");
+    auto dl = framework::StringToDataLayout(data_format);
+    // Some models may have intentionally set "AnyLayout" for the pool
+    // op. Treat this as NCHW (the default data_format value)
+    if (dl != framework::DataLayout::kAnyLayout) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), dl);
+    }
+  }
+#endif
+  return framework::OpKernelType(expected_kernel_type.data_type_,
+                                 tensor.place(), tensor.layout());
+}
+
 void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null.");
   PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
......
@@ -35,6 +35,10 @@ class PoolOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
 
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const;
 };
 
 class PoolOpGrad : public framework::OperatorWithKernel {
......
@@ -397,6 +397,10 @@ thread_local std::string cur_input_shape_str = "";
 // the cache capacity of different input shapes for MKLDNN.
 // Default 1 means fixed input shape, not dynamic shape.
 thread_local int cur_input_shape_cache_capacity = 1;
+// Most recently registered data_format. This is needed when
+// converting an MKL-DNN Tensor back to a non-MKL-DNN one
+thread_local paddle::framework::DataLayout cur_paddle_data_layout =
+    paddle::framework::DataLayout::kNCHW;
 }  // namespace
 
 void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
@@ -408,6 +412,14 @@ void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
   cur_input_shape_cache_capacity = input_shape_cache_capacity;
 }
 
+void set_cur_paddle_data_layout(framework::DataLayout dl) {
+  cur_paddle_data_layout = dl;
+}
+
+framework::DataLayout get_cur_paddle_data_layout(void) {
+  return cur_paddle_data_layout;
+}
+
 void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
 
 size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
......
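cur_paddle_data_layout follows the same thread-local pattern as the existing cur_mkldnn_session_id and cur_input_shape_str slots: one value per executor thread, defaulting to kNCHW, and reset to kNCHW in the Executor destructor hunk above. A self-contained sketch of the pattern (the enum and the demo program are illustrative, not Paddle code):

#include <iostream>
#include <thread>

enum class DataLayout { kNHWC, kNCHW, kAnyLayout, kMKLDNN };

namespace {
// One slot per thread, defaulting to Paddle's canonical NCHW.
thread_local DataLayout cur_paddle_data_layout = DataLayout::kNCHW;
}  // namespace

void set_cur_paddle_data_layout(DataLayout dl) { cur_paddle_data_layout = dl; }
DataLayout get_cur_paddle_data_layout() { return cur_paddle_data_layout; }

int main() {
  set_cur_paddle_data_layout(DataLayout::kNHWC);
  // Another thread still sees the default, so concurrently running
  // executors with different data_formats cannot clobber each other.
  std::thread t([] {
    std::cout << (get_cur_paddle_data_layout() == DataLayout::kNCHW)  // 1
              << "\n";
  });
  t.join();
  return 0;
}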
@@ -30,6 +30,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MKLDNN
 #include "mkldnn.hpp"
+#include "paddle/fluid/framework/data_layout.h"
 #endif
 
 #include <map>
@@ -290,6 +291,8 @@ void set_cur_mkldnn_session_id(size_t);
 size_t get_cur_mkldnn_session_id(void);
 void set_cur_input_shape_str(std::string input_shape_str);
 void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
+void set_cur_paddle_data_layout(framework::DataLayout);
+framework::DataLayout get_cur_paddle_data_layout(void);
 
 class MKLDNNDeviceContext : public CPUDeviceContext {
  public:
......
@@ -502,7 +502,7 @@ class LRNMKLDNNHandler
   std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
       framework::Tensor* workspace) {
     T* ptr = workspace->mutable_data<T>(
-        this->place_, this->fwd_pd_->dst_primitive_desc().get_size());
+        this->place_, this->fwd_pd_->workspace_primitive_desc().get_size());
     return this->AcquireMemoryFromPrimitive(
         this->fwd_pd_->workspace_primitive_desc(), ptr, "@wrk_mem_p");
   }
......
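The fix above sizes the LRN workspace from workspace_primitive_desc() instead of dst_primitive_desc(). The point generalizes: a buffer must be sized from its own descriptor, because the workspace may need more bytes than the output tensor, in which case the old code under-allocates. A minimal MKL-DNN-free sketch of the corrected pattern (the MemoryDesc type and names are hypothetical):

#include <cstddef>
#include <vector>

// Hypothetical stand-in for a primitive's memory descriptor.
struct MemoryDesc {
  std::size_t size;
};

// Size the workspace from the workspace descriptor itself, never from
// the descriptor of some other (e.g. destination) memory.
std::vector<char> AcquireWorkspace(const MemoryDesc& workspace_md) {
  return std::vector<char>(workspace_md.size);
}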
@@ -55,16 +55,11 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
         self.assertRaises(AttributeError, check_raise_is_test)
 
-# TODO(jczaja): Once mkl-dnn integration support NHWC input
-# then those tests should be changed to actual functional positive tests
 class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
     def init_test_case(self):
         self.data_format = 'NHWC'
 
-    def test_check_output(self):
-        pass
+    #TODO(jczaja): Add grad support
 
-    # Grad tests both FWD and BWD ops kernels creation
     def test_check_grad_normal(self):
         with self.assertRaises(fluid.core_avx.EnforceNotMet):
             self.check_grad(['X'], 'Out', max_relative_error=0.01)
......
@@ -141,9 +141,6 @@ class TestAsymPadValid(TestAsymPad):
         self.padding_algorithm = "VALID"
 
-# Designed to Fail
-# TODO(jczaja): Once mkl-dnn integration support NHWC input
-# then those tests should be changed to actual functional positive tests
 class TestAsymPadValidNHWC(TestAsymPadValid):
     def init_data_format(self):
         self.data_format = "NHWC"
@@ -151,12 +148,7 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
     def init_shape(self):
         self.shape = [2, 7, 7, 3]
 
-    def test_check_output(self):
-        pass
+    #TODO(jczaja): Add Grad NHWC support
 
-    # Grad tests both FWD and BWD ops kernels creation
-    # GetExpectedKernelType should throw an exception on lack of support
-    # to NHWC inputs in pool mkldnn kernel
     def test_check_grad(self):
         with self.assertRaises(fluid.core_avx.EnforceNotMet):
             super(TestAsymPadValidNHWC, self).test_check_grad()
......