提交 cd43c444 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] LRN and Pool2d (FWD) NHWC support (#21375)

上级 add62acf
......@@ -127,13 +127,17 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN");
innerTransDataLayoutFromMKLDNN(in_layout, out_layout, in, out, place);
#ifdef PADDLE_WITH_MKLDNN
innerTransDataLayoutFromMKLDNN(in_layout,
paddle::platform::get_cur_paddle_data_layout(),
in, out, place);
#endif
}
#ifdef PADDLE_WITH_MKLDNN
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place) {
#ifdef PADDLE_WITH_MKLDNN
PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::format_undef,
platform::errors::InvalidArgument(
"Input tensor format is invalid. Input tensor should "
......@@ -185,11 +189,17 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
} else {
out->ShareDataWith(in);
}
// For exepected NHWC data format we need to reshape the Output tensor
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
if (out_layout == DataLayout::kNHWC) {
std::rotate(out_tz.begin() + 1, out_tz.begin() + 2, out_tz.end());
out->Resize(framework::make_ddim(out_tz));
}
out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(MKLDNNMemoryFormat::format_undef);
#endif
}
#endif
} // namespace framework
} // namespace paddle
......@@ -66,16 +66,15 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
return MKLDNNDataType::data_undef;
}
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place);
#endif
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out);
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out,
platform::Place place);
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const OpKernelType& kernel_type_for_var,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#ifdef PADDLE_WITH_MKLDNN
#include <algorithm>
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -54,8 +55,16 @@ void TransformData(const OpKernelType &expected_kernel_type,
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
ToMKLDNNFormat(lin));
out.ShareDataWith(input_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
if (lin == DataLayout::kNHWC) {
auto nchw_dims = paddle::framework::vectorize<int>(out.dims());
std::rotate(nchw_dims.begin() + 1, nchw_dims.end() - 1,
nchw_dims.end());
out.Resize(framework::make_ddim(nchw_dims));
paddle::platform::set_cur_paddle_data_layout(lin);
}
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
#endif
......
......@@ -103,6 +103,7 @@ Executor::~Executor() {
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
}
#endif
}
......
......@@ -470,6 +470,11 @@ class OperatorWithKernel : public OperatorBase {
return g_all_op_kernels;
}
bool IsMKLDNNType() const {
return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
framework::DataLayout::kMKLDNN));
}
bool SupportGPU() const override {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
......
......@@ -56,16 +56,20 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
if (src_item.IsInitialized() && src_item.numel() > 0) {
#ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle
if (src_item.layout() == framework::DataLayout::kMKLDNN) {
framework::Tensor out;
framework::innerTransDataLayoutFromMKLDNN(
src_item.layout(), framework::DataLayout::kNCHW, src_item, &out,
platform::CPUPlace());
src_item.layout(), paddle::platform::get_cur_paddle_data_layout(),
src_item, &out, platform::CPUPlace());
TensorCopySync(out, platform::CPUPlace(), &dst_item);
} else {
TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
}
#else
TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
#endif
} else {
// Not copy, if the src tensor is empty.
dst_item.clear();
......
......@@ -193,12 +193,6 @@ class LRNOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"LRN MKLDNN does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -207,6 +201,28 @@ class LRNOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
template <typename T>
......
......@@ -102,6 +102,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
pipeline.push_back(*reorder_p);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory));
}
};
......
......@@ -62,6 +62,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::lrn_forward> lrn_p;
if (is_test == false) {
workspace_memory = handler.AcquireWorkspaceMemory(mid);
mid->set_layout(framework::DataLayout::kMKLDNN);
mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory,
*dst_memory);
} else {
......
......@@ -88,7 +88,10 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
ksize.size(), strides.size(), framework::make_ddim(ksize),
framework::make_ddim(strides));
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (this->IsMKLDNNType() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
// update paddings if "SAME" or global_pooling
framework::DDim data_dims;
......@@ -146,12 +149,6 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Pool MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -162,6 +159,28 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
layout_, library_);
}
framework::OpKernelType PoolOp::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
......
......@@ -35,6 +35,10 @@ class PoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const;
};
class PoolOpGrad : public framework::OperatorWithKernel {
......
......@@ -397,6 +397,10 @@ thread_local std::string cur_input_shape_str = "";
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
thread_local int cur_input_shape_cache_capacity = 1;
// Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN
thread_local paddle::framework::DataLayout cur_paddle_data_layout =
paddle::framework::DataLayout::kNCHW;
} // namespace
void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
......@@ -408,6 +412,14 @@ void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
void set_cur_paddle_data_layout(framework::DataLayout dl) {
cur_paddle_data_layout = dl;
}
framework::DataLayout get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
......
......@@ -30,6 +30,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN
#include "mkldnn.hpp"
#include "paddle/fluid/framework/data_layout.h"
#endif
#include <map>
......@@ -290,6 +291,8 @@ void set_cur_mkldnn_session_id(size_t);
size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
void set_cur_paddle_data_layout(framework::DataLayout);
framework::DataLayout get_cur_paddle_data_layout(void);
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
......
......@@ -502,7 +502,7 @@ class LRNMKLDNNHandler
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
framework::Tensor* workspace) {
T* ptr = workspace->mutable_data<T>(
this->place_, this->fwd_pd_->dst_primitive_desc().get_size());
this->place_, this->fwd_pd_->workspace_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_primitive_desc(), ptr, "@wrk_mem_p");
}
......
......@@ -55,16 +55,11 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
self.assertRaises(AttributeError, check_raise_is_test)
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
def init_test_case(self):
self.data_format = 'NHWC'
def test_check_output(self):
pass
# Grad tests both FWD and BWD ops kernels creation
#TODO(jczaja): Add grad support
def test_check_grad_normal(self):
with self.assertRaises(fluid.core_avx.EnforceNotMet):
self.check_grad(['X'], 'Out', max_relative_error=0.01)
......
......@@ -141,9 +141,6 @@ class TestAsymPadValid(TestAsymPad):
self.padding_algorithm = "VALID"
# Designed to Fail
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class TestAsymPadValidNHWC(TestAsymPadValid):
def init_data_format(self):
self.data_format = "NHWC"
......@@ -151,12 +148,7 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
def init_shape(self):
self.shape = [2, 7, 7, 3]
def test_check_output(self):
pass
# Grad tests both FWD and BWD ops kernels creation
# GetExpectedKernelType should throw an exception on lack of support
# to NHWC inputs in pool mkldnn kernel
#TODO(jczaja): Add Grad NHWC support
def test_check_grad(self):
with self.assertRaises(fluid.core_avx.EnforceNotMet):
super(TestAsymPadValidNHWC, self).test_check_grad()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册