Commit ad8a9cb8 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Pool & LRN Grad Ops NHWC support (#21747)

Parent e1d666fb
......@@ -185,10 +185,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
}
// For the expected NHWC data format we need to reshape the output tensor,
// as the MKL-DNN description was in NCHW and paddle is expecting NHWC
if (out_layout == DataLayout::kNHWC) {
std::rotate(out_tz.begin() + 1, out_tz.begin() + 2, out_tz.end());
out->Resize(framework::make_ddim(out_tz));
}
platform::MatchShapeToLayout(out, in_layout, out_layout);
out->set_layout(out_layout);
// reset format since the out tensor will be fed to a non-MKLDNN op kernel
out->set_format(MKLDNNMemoryFormat::undef);
......
......@@ -58,13 +58,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
out.ShareDataWith(input_tensor);
// For NHWC data we need to reshape the tensor, as MKL-DNN
// expects dims in the NCHW description order
if (lin == DataLayout::kNHWC) {
auto nchw_dims = paddle::framework::vectorize<int>(out.dims());
std::rotate(nchw_dims.begin() + 1, nchw_dims.end() - 1,
nchw_dims.end());
out.Resize(framework::make_ddim(nchw_dims));
paddle::platform::set_cur_paddle_data_layout(lin);
}
platform::MatchShapeToLayout(&out, lin, lout);
paddle::platform::set_cur_paddle_data_layout(lin);
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
} else {
......
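For reference, a minimal standalone sketch (plain C++, not Paddle code) of the two dim rotations that the new platform::MatchShapeToLayout call performs in the hunks above:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // MKL-DNN describes shapes in NCHW order.
  std::vector<int> dims = {8, 3, 32, 64};  // N, C, H, W

  // kMKLDNN -> kNHWC: move C behind W.
  std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
  // dims is now {8, 32, 64, 3}  (N, H, W, C)

  // kNHWC -> kMKLDNN (NCHW description): move C back after N.
  std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
  // dims is now {8, 3, 32, 64} again.

  for (int d : dims) std::printf("%d ", d);  // prints: 8 3 32 64
  return 0;
}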
......@@ -39,6 +39,5 @@ void TransformData(const OpKernelType &expected_kernel_type,
*/
void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
Variable *out_var);
} // namespace framework
} // namespace paddle
......@@ -33,6 +33,10 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(benchmark);
DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check);
......@@ -1102,11 +1106,8 @@ Scope* OperatorWithKernel::PrepareData(
}
for (auto& var_name_item : Inputs()) {
if (no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0) {
VLOG(7) << "Skip scanning input " << var_name_item.first
<< " in Operator " << type_;
continue;
}
bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0;
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
......@@ -1120,6 +1121,44 @@ Scope* OperatorWithKernel::PrepareData(
}
auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
// When no_buffer_ins is set, checking Tensor::holder_ is
// not thread safe, and for the InferShape scenario the checks
// to be omitted are not really needed
if (should_skip_input == true) {
#ifdef PADDLE_WITH_MKLDNN
// A Var without a buffer may still be needed
// in some situations, e.g. InferShape().
// In this situation we cannot skip Var analysis, as
// the MKL-DNN shape of the Var may differ from its kNHWC shape.
// In such a case a correspondingly resized Var
// has to be created and registered
if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
(var->IsType<LoDTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
(paddle::platform::get_cur_paddle_data_layout() ==
DataLayout::kNHWC)) {
// Mixed execution of MKL-DNN and GPU is not supported!
if (!new_scope) {
new_scope = &scope.NewScope();
}
auto* trans_var = new_scope->Var(var_name);
input_vars[i] = trans_var;
auto out = trans_var->GetMutable<LoDTensor>();
out->Resize(tensor_in->dims());
platform::MatchShapeToLayout(out, tensor_in->layout(),
DataLayout::kNHWC);
VLOG(7) << "Created reshaped dummy input based on MKL-DNN Tensor , "
"but kNHWC layout"
<< var_name_item.first << " in Operator " << type_;
} else {
VLOG(7) << "Skip scanning input " << var_name_item.first
<< " in Operator " << type_;
}
#endif
continue;
}
if (!tensor_in->IsInitialized()) {
continue;
}
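Stripped of framework types, the condition guarding that dummy-Var creation can be modelled like this (toy enums and names, for illustration only):

enum class Layout { kAnyLayout, kNHWC, kNCHW, kMKLDNN };

// Toy predicate mirroring the guard above: a shape-only dummy Var is needed
// when an MKL-DNN-described LoDTensor input is about to feed a non-MKL-DNN
// kernel while the session-wide paddle layout is NHWC.
bool NeedsShapeOnlyDummy(Layout tensor_layout, bool is_lod_tensor,
                         Layout expected_kernel_layout,
                         Layout cur_paddle_layout) {
  return tensor_layout == Layout::kMKLDNN && is_lod_tensor &&
         expected_kernel_layout != Layout::kMKLDNN &&
         cur_paddle_layout == Layout::kNHWC;
}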
......@@ -1143,14 +1182,17 @@ Scope* OperatorWithKernel::PrepareData(
// In the inference scenario, the scopes will be reused across the
// batches, so the `new_scope` here will result in GPU memory explosion
// over the running of operators.
// We use a thread_local cache to fix that issue: the key in the cache is
// the combination of the `scope` argument, from_kernel_type,
// target_kernel_type.
// Have a discussion with @Superjomn or the inference developers if some
// changes to this logic for this macro might not be tested in the other
// scenarios.
// If this op is not called by an Executor or ParallelExecutor, it should
// be called by a NaiveExecutor; the NaiveExecutor will cache the scopes
// and variables, and that behavior is quite different.
//
// To solve issue #15032, have a discussion with @Luotao for cpu
......@@ -1174,15 +1216,14 @@ Scope* OperatorWithKernel::PrepareData(
// we will create a new cpu tensor in new scope.
// However, if enable_cache_runtime_context_, we get the cpu tensor each
// time, not the gpu tensor.
// Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()` in
// RunImpl().
if (enable_cache_runtime_context_) {
pre_scope_ = nullptr;
}
auto* trans_var = new_scope->Var(var_name);
input_vars[i] = trans_var;
Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
SetTensorToVariable(*var, out, trans_var);
......
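As a rough illustration of the thread_local cache described in the comments above (the names and key construction here are assumptions, not Paddle's actual implementation):

#include <cstdint>
#include <string>
#include <unordered_map>

struct Scope;  // opaque stand-in for framework::Scope

// Hypothetical sketch: one cache per thread, keyed by the combination of
// the scope argument and the from/to kernel types, so repeated inference
// batches reuse one transfer scope instead of growing memory per batch.
inline Scope*& CachedTransferScope(const Scope* scope,
                                   const std::string& from_kernel_type,
                                   const std::string& to_kernel_type) {
  thread_local std::unordered_map<std::string, Scope*> cache;
  std::string key = std::to_string(reinterpret_cast<std::uintptr_t>(scope)) +
                    "#" + from_kernel_type + "#" + to_kernel_type;
  return cache[key];  // nullptr until a transfer scope is installed
}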
......@@ -334,12 +334,6 @@ class LRNOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"LRN MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -348,6 +342,28 @@ class LRNOpGrad : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for the lrn
// op. Treat this as NCHW (the default data_format value).
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
template <typename T>
......
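Stripped of framework types, the decision this new override encodes (and which PoolOpGrad::GetKernelTypeForVar below repeats, minus the AnyLayout check) is roughly the following sketch; the enum and function names are illustrative only:

#include <string>

enum class Layout { kAnyLayout, kNHWC, kNCHW, kMKLDNN };

// Toy stand-in for framework::StringToDataLayout.
Layout StringToLayoutToy(const std::string& s) {
  if (s == "NHWC") return Layout::kNHWC;
  if (s == "NCHW") return Layout::kNCHW;
  return Layout::kAnyLayout;
}

// When an MKL-DNN kernel was chosen but the incoming tensor still carries a
// paddle layout, report the layout named by the op's data_format attribute
// so the framework applies the right transform; "AnyLayout" intentionally
// falls back to the tensor's own layout.
Layout KernelLayoutForVar(Layout expected, Layout tensor_layout,
                          const std::string& data_format_attr) {
  if (expected == Layout::kMKLDNN && tensor_layout != Layout::kMKLDNN) {
    Layout dl = StringToLayoutToy(data_format_attr);
    if (dl != Layout::kAnyLayout) return dl;
  }
  return tensor_layout;
}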
......@@ -202,12 +202,6 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Pool MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -222,6 +216,24 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
library_);
}
framework::OpKernelType PoolOpGrad::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::StringToDataLayout(data_format));
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
void Pool2dOpMaker::Make() {
AddInput(
"X",
......
......@@ -38,7 +38,7 @@ class PoolOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const;
const framework::OpKernelType& expected_kernel_type) const override;
};
class PoolOpGrad : public framework::OperatorWithKernel {
......@@ -50,6 +50,10 @@ class PoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
......
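A side note on the `const;` -> `const override;` change to PoolOp::GetKernelTypeForVar above: the specifier costs nothing and turns a silent signature mismatch into a compile error. A minimal illustration:

struct Base {
  virtual int KernelTypeForVar(int layout) const { return layout; }
  virtual ~Base() = default;
};

struct Derived : Base {
  // With `override`, any drift in the signature (a dropped const, a changed
  // parameter type) fails to compile instead of silently declaring a new,
  // never-dispatched virtual function.
  int KernelTypeForVar(int layout) const override { return layout + 1; }
};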
......@@ -71,6 +71,29 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
return tf_pd<Type>(desc, e, p);
}
inline void MatchShapeToLayout(framework::Tensor* tensor_in,
framework::DataLayout from,
framework::DataLayout to) {
switch (from) {
case framework::DataLayout::kMKLDNN:
if (to == framework::DataLayout::kNHWC) {
auto dims = framework::vectorize<int>(tensor_in->dims());
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
tensor_in->Resize(framework::make_ddim(dims));
}
break;
case framework::DataLayout::kNHWC:
if (to == framework::DataLayout::kMKLDNN) {
auto dims = framework::vectorize<int>(tensor_in->dims());
std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
tensor_in->Resize(framework::make_ddim(dims));
}
break;
default:
break;
}
}
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
mkldnn::memory::data_type data_type,
MKLDNNMemoryFormat format) {
......
......@@ -59,11 +59,6 @@ class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
def init_test_case(self):
self.data_format = 'NHWC'
#TODO(jczaja): Add grad support
def test_check_grad_normal(self):
with self.assertRaises(fluid.core_avx.EnforceNotMet):
self.check_grad(['X'], 'Out')
if __name__ == "__main__":
unittest.main()
......@@ -157,11 +157,6 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
def init_shape(self):
self.shape = [2, 7, 7, 3]
#TODO(jczaja): Add Grad NHWC support
def test_check_grad(self):
with self.assertRaises(fluid.core_avx.EnforceNotMet):
super(TestAsymPadValidNHWC, self).test_check_grad()
if __name__ == '__main__':
unittest.main()