Commit b0b27ff6 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Conv grad and Batch Norm grad NHWC support (#22088)

Parent 1ce6ab9c
@@ -1215,10 +1215,8 @@ Scope* OperatorWithKernel::PrepareData(
       // The reason is that if a gpu tensor is the input of a cpu kernel,
       // we will create a new cpu tensor in new scope.
       // However, if enable_cache_runtime_context_, we get the cpu tensor each
-      // time, not the gpu tensor.
-      // Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()`
-      // in
-      // RunImpl().
+      // time, not the gpu tensor. Thus, we set pre_scope_ = nullptr
+      // to trigger `new RuntimeContext()` in RunImpl().
       if (enable_cache_runtime_context_) {
         pre_scope_ = nullptr;
       }
...
@@ -186,9 +186,8 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
     // Some models may have intentionally set "AnyLayout" for pool
     // op. Treat this as NCHW (default data_format value)
     if (dl != framework::DataLayout::kAnyLayout) {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(),
-          framework::StringToDataLayout(data_layout));
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), dl);
     }
   }
 #endif
@@ -465,8 +464,11 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
   const auto x_dims = ctx->GetInputDim("X");
   const DataLayout data_layout = framework::StringToDataLayout(
       ctx->Attrs().Get<std::string>("data_layout"));
-  const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
-                                                  : x_dims[x_dims.size() - 1]);
+
+  const int C =
+      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+           ? x_dims[1]
+           : x_dims[x_dims.size() - 1]);
 
   ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
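To make the channel-selection rule above concrete: an MKL-DNN kernel keeps tensor dims in NCHW order even when the data_layout attribute says NHWC, so the channel count has to come from dims[1] whenever IsMKLDNNType() holds. A minimal standalone sketch of that rule (hypothetical helper, not Paddle's API):

#include <cassert>
#include <cstdint>
#include <vector>

enum class DataLayout { kNCHW, kNHWC };

// Hypothetical helper mirroring the selection above: an MKL-DNN tensor
// stores its dims in NCHW order regardless of the data_layout attribute.
int64_t ChannelCount(const std::vector<int64_t>& dims, DataLayout layout,
                     bool is_mkldnn_type) {
  return (is_mkldnn_type || layout == DataLayout::kNCHW) ? dims[1]
                                                         : dims.back();
}

int main() {
  assert(ChannelCount({8, 3, 224, 224}, DataLayout::kNCHW, false) == 3);
  assert(ChannelCount({8, 224, 224, 3}, DataLayout::kNHWC, false) == 3);
  // NHWC attribute but MKL-DNN kernel: dims are still NCHW, so dims[1].
  assert(ChannelCount({8, 3, 224, 224}, DataLayout::kNHWC, true) == 3);
  return 0;
}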
@@ -499,12 +501,6 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
 #ifdef PADDLE_WITH_MKLDNN
   if (library == framework::LibraryType::kPlain &&
       platform::CanMKLDNNBeUsed(ctx)) {
-    // TODO(jczaja): Add support for NHWC
-    const std::string data_layout = ctx.Attr<std::string>("data_layout");
-    PADDLE_ENFORCE_NE(
-        data_layout, "NHWC",
-        platform::errors::Unimplemented(
-            "Batch Norm MKLDNN grad does not support NHWC data format yet"));
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -515,6 +511,31 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
       library);
 }
 
+framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
+    const std::string &var_name, const Tensor &tensor,
+    const framework::OpKernelType &expected_kernel_type) const {
+#ifdef PADDLE_WITH_MKLDNN
+  // Only the input requires reshaping; weights and
+  // bias keep their shape in NCHW order
+  if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) &&
+      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+    auto attrs = Attrs();
+    auto ar = paddle::framework::AttrReader(attrs);
+    const std::string data_layout = ar.Get<std::string>("data_layout");
+    auto dl = framework::StringToDataLayout(data_layout);
+    // Some models may have intentionally set "AnyLayout" for pool
+    // op. Treat this as NCHW (default data_format value)
+    if (dl != framework::DataLayout::kAnyLayout) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), dl);
+    }
+  }
+#endif
+  return framework::OpKernelType(expected_kernel_type.data_type_,
+                                 tensor.place(), tensor.layout());
+}
+
 template <typename T>
 class BatchNormGradKernel<platform::CPUDeviceContext, T>
     : public framework::OpKernel<T> {
...
@@ -148,6 +148,10 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override;
 };
 
 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -60,8 +60,13 @@ class FetchOp : public framework::OperatorBase {
     // Conversion from MKL-DNN to Paddle
     if (src_item.layout() == framework::DataLayout::kMKLDNN) {
       framework::Tensor out;
+      // Convert to the desired Paddle layout, except for gradients of Filter:
+      // parameters are not subject to Paddle's data_format
       framework::innerTransDataLayoutFromMKLDNN(
-          src_item.layout(), paddle::platform::get_cur_paddle_data_layout(),
+          src_item.layout(),
+          fetch_var_name == framework::GradVarName("Filter")
+              ? framework::DataLayout::kNCHW
+              : paddle::platform::get_cur_paddle_data_layout(),
           src_item, &out, platform::CPUPlace());
       TensorCopySync(out, platform::CPUPlace(), &dst_item);
     } else {
...
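The ternary above encodes the rule from the new comment: Filter gradients are parameter gradients, and parameters ignore data_format, so they always convert back to NCHW, while other fetched tensors follow the current Paddle layout. A minimal sketch with hypothetical stand-ins for Paddle's helpers:

#include <cassert>
#include <string>

enum class DataLayout { kNCHW, kNHWC };

// Hypothetical stand-ins for illustration; not Paddle's real helpers.
DataLayout get_cur_paddle_data_layout() { return DataLayout::kNHWC; }
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

// Filter gradients are parameter gradients and are not subject to
// data_format, so they always convert back to NCHW.
DataLayout FetchTargetLayout(const std::string& fetch_var_name) {
  return fetch_var_name == GradVarName("Filter")
             ? DataLayout::kNCHW
             : get_cur_paddle_data_layout();
}

int main() {
  assert(FetchTargetLayout(GradVarName("Filter")) == DataLayout::kNCHW);
  assert(FetchTargetLayout("Out") == DataLayout::kNHWC);
  return 0;
}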
@@ -208,9 +208,8 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar(
     // Some models may have intentionally set "AnyLayout" for pool
     // op. Treat this as NCHW (default data_format value)
     if (dl != framework::DataLayout::kAnyLayout) {
-      return framework::OpKernelType(
-          expected_kernel_type.data_type_, tensor.place(),
-          framework::StringToDataLayout(data_format));
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), dl);
     }
   }
 #endif
@@ -554,16 +553,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
       platform::CanMKLDNNBeUsed(ctx)) {
-    // TODO(jczaja): Add support for NHWC
     const std::string data_format = ctx.Attr<std::string>("data_format");
-    PADDLE_ENFORCE_NE(
-        data_format, "NHWC",
-        platform::errors::Unimplemented(
-            "Conv MKLDNN grad does not support NHWC data format yet"));
-    PADDLE_ENFORCE_NE(
-        data_format, "NDHWC",
-        platform::errors::Unimplemented(
-            "Conv MKLDNN Grad does not support NDHWC data format yet"));
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
     customized_type_value = kConvMKLDNNFP32;
@@ -591,6 +581,32 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   return type;
 }
 
+framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
+    const std::string& var_name, const Tensor& tensor,
+    const framework::OpKernelType& expected_kernel_type) const {
+#ifdef PADDLE_WITH_MKLDNN
+  // Only the input requires reshaping; weights and
+  // bias keep their shape in NCHW order
+  if (((var_name == "Input") ||
+       (var_name == framework::GradVarName("Output"))) &&
+      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+    auto attrs = Attrs();
+    auto ar = paddle::framework::AttrReader(attrs);
+    const std::string data_format = ar.Get<std::string>("data_format");
+    auto dl = framework::StringToDataLayout(data_format);
+    // Some models may have intentionally set "AnyLayout" for pool
+    // op. Treat this as NCHW (default data_format value)
+    if (dl != framework::DataLayout::kAnyLayout) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), dl);
+    }
+  }
+#endif
+  return framework::OpKernelType(expected_kernel_type.data_type_,
+                                 tensor.place(), tensor.layout());
+}
+
 template <typename T>
 class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
  public:
...
@@ -272,6 +272,10 @@ class ConvOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override;
 };
 
 class ConvOpDoubleGrad : public framework::OperatorWithKernel {
...
@@ -74,6 +74,11 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
 inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                                framework::DataLayout from,
                                framework::DataLayout to) {
+  // Changing the shape only makes sense for tensors with 3+ dims
+  if (tensor_in->dims().size() < 3) {
+    return;
+  }
+
   switch (from) {
     case framework::DataLayout::kMKLDNN:
       if (to == framework::DataLayout::kNHWC) {
...
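The early return added above guards the dims rotation that MatchShapeToLayout performs: 1-D and 2-D tensors (e.g. scale and bias) have no spatial axes to rotate. A standalone sketch of such a rotation, assuming MKL-DNN tensors keep their dims in NCHW order (illustrative, not Paddle's implementation):

#include <cstdint>
#include <cstdio>
#include <vector>

enum class DataLayout { kMKLDNN, kNCHW, kNHWC };

// Rotate dims between NCHW-ordered (MKL-DNN) and NHWC shapes; tensors with
// fewer than 3 dims are returned unchanged, matching the added early return.
std::vector<int64_t> MatchShapeToLayout(std::vector<int64_t> dims,
                                        DataLayout from, DataLayout to) {
  if (dims.size() < 3) return dims;
  if (from == DataLayout::kMKLDNN && to == DataLayout::kNHWC) {
    // NCHW -> NHWC: move the channel axis to the back.
    dims.push_back(dims[1]);
    dims.erase(dims.begin() + 1);
  } else if (from == DataLayout::kNHWC && to == DataLayout::kMKLDNN) {
    // NHWC -> NCHW: move the channel axis right after the batch axis.
    int64_t channels = dims.back();
    dims.pop_back();
    dims.insert(dims.begin() + 1, channels);
  }
  return dims;
}

int main() {
  auto nhwc = MatchShapeToLayout({8, 3, 224, 224}, DataLayout::kMKLDNN,
                                 DataLayout::kNHWC);
  for (int64_t d : nhwc) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // prints: 8 224 224 3
  return 0;
}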
@@ -34,6 +34,10 @@ class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
     def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
                              epsilon, momentum, shape, data_layout):
+        if data_layout != "NCHW" and data_layout != "NHWC":
+            raise ValueError("Unknown data order.")
+
         # run forward
         y, saved_mean, saved_variance = _reference_training(
             x, scale, bias, epsilon, data_layout)
@@ -46,6 +50,12 @@ class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
         return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
 
 
+class TestMKLDNNBatchNormOpTraining_NHWC(TestMKLDNNBatchNormOpTraining):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_formats = ["NHWC"]
+
+
 class TestMKLDNNBatchNormOpExistedPrimitives(TestMKLDNNBatchNormOpTraining):
     def init_test_case(self):
         TestMKLDNNBatchNormOpTraining.init_test_case(self)
...
@@ -208,18 +208,6 @@ class TestConv2dOp_Valid_NHWC_MKLDNN(TestConv2dOp_Valid_MKLDNN):
         N, C, H, W = self.input_size
         self.input_size = [N, H, W, C]
 
-    #TODO(jczaja): Enable once GRAD op is adjusted
-    def test_check_grad(self):
-        pass
-
-    #TODO(jczaja): Enable once GRAD op is adjusted
-    def test_check_grad_no_filter(self):
-        pass
-
-    #TODO(jczaja): Enable once GRAD op is adjusted
-    def test_check_grad_no_input(self):
-        pass
-
 
 class TestConv2dOp_Same_NHWC_MKLDNN(TestConv2dOp_Valid_NHWC_MKLDNN):
     def init_paddings(self):
...