Commit b0b27ff6 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Conv grad and Batch Norm grad NHWC support (#22088)

Parent 1ce6ab9c
......@@ -1215,10 +1215,8 @@ Scope* OperatorWithKernel::PrepareData(
// The reason is that if a gpu tensor is the input of a cpu kernel,
// we will create a new cpu tensor in new scope.
// However, if enable_cache_runtime_context_, we get the cpu tensor each
// time, not the gpu tensor.
// Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()`
// in
// RunImpl().
// time, not the gpu tensor. Thus, we set pre_scope_ = nullptr
// to trigger `new RuntimeContext()` in RunImpl().
if (enable_cache_runtime_context_) {
pre_scope_ = nullptr;
}
......
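The comment in this hunk captures an invalidation rule worth spelling out: once a data transfer creates tensors in a fresh scope, a cached RuntimeContext would keep resolving to the transferred (e.g. CPU) tensor, so the cache must be dropped. A minimal self-contained sketch of that pattern, using hypothetical stand-in types rather than Paddle's actual classes:

```cpp
#include <memory>

// Hypothetical stand-ins for Scope / RuntimeContext, for illustration only.
struct Scope {};
struct RuntimeContext {};

class KernelRunner {
 public:
  void PrepareData(Scope* transfer_scope, bool enable_cache_runtime_context) {
    // A transfer scope was created (e.g. a GPU input copied into a CPU
    // tensor). A cached RuntimeContext would keep returning the transferred
    // CPU tensor, so force RunImpl() to rebuild it by clearing pre_scope_.
    if (enable_cache_runtime_context) {
      pre_scope_ = nullptr;
    }
    (void)transfer_scope;
  }

  void RunImpl(Scope* scope) {
    if (scope != pre_scope_ || runtime_ctx_ == nullptr) {
      runtime_ctx_ = std::make_unique<RuntimeContext>();  // `new RuntimeContext()`
      pre_scope_ = scope;
    }
    // ... launch the kernel with *runtime_ctx_ ...
  }

 private:
  Scope* pre_scope_ = nullptr;
  std::unique_ptr<RuntimeContext> runtime_ctx_;
};
```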
......@@ -186,9 +186,8 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(),
framework::StringToDataLayout(data_layout));
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
......@@ -465,8 +464,11 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
const auto x_dims = ctx->GetInputDim("X");
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
// has_scale_grad == has_bias_grad, so checking has_scale_grad is enough
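The new `IsMKLDNNType()` branch above encodes the key invariant of this commit: an MKL-DNN kernel keeps the tensor dims in NCHW order even when the model declares NHWC, so the channel count must come from dims[1] in that case. A standalone sketch of that selection, assuming a plain dims vector in place of Paddle's DDim:

```cpp
#include <cassert>
#include <vector>

enum class DataLayout { kNCHW, kNHWC };

// Pick the channel count from a dims vector; an MKL-DNN kernel keeps the
// dims NCHW-ordered regardless of the declared data_layout.
int ChannelCount(const std::vector<int>& dims, DataLayout layout,
                 bool is_mkldnn_type) {
  return (is_mkldnn_type || layout == DataLayout::kNCHW)
             ? dims[1]
             : dims[dims.size() - 1];
}

int main() {
  // Plain NHWC tensor: channels sit on the last axis.
  std::vector<int> nhwc_dims = {8, 224, 224, 3};
  assert(ChannelCount(nhwc_dims, DataLayout::kNHWC, /*is_mkldnn_type=*/false) == 3);
  // MKL-DNN kernel with declared NHWC: dims stay NCHW-ordered, so dims[1].
  std::vector<int> nchw_dims = {8, 3, 224, 224};
  assert(ChannelCount(nchw_dims, DataLayout::kNHWC, /*is_mkldnn_type=*/true) == 3);
  return 0;
}
```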
......@@ -499,12 +501,6 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_layout = ctx.Attr<std::string>("data_layout");
PADDLE_ENFORCE_NE(
data_layout, "NHWC",
platform::errors::Unimplemented(
"Batch Norm MKLDNN grad does not support NHWC data format yet"));
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......@@ -515,6 +511,31 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
library);
}
framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only the input requires reshaping; weights and
// bias already have their shape in NCHW order
if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) &&
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_layout = ar.Get<std::string>("data_layout");
auto dl = framework::StringToDataLayout(data_layout);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
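This per-variable hook follows one pattern for both ops in the commit (batch norm here, conv grad further down): when the op itself expects kMKLDNN but the incoming tensor is still a plain Paddle tensor, report the model-declared layout so the transfer machinery reorders the dims correctly. A condensed sketch of that decision, with hypothetical stand-in types:

```cpp
enum class DataLayout { kAnyLayout, kNHWC, kNCHW, kMKLDNN };

// Given the layout the kernel expects, the layout the tensor currently has,
// and the op's declared data_layout attribute, pick the layout to report.
// Only reorderable activations qualify ("X"/"Y@GRAD" here, or
// "Input"/"Output@GRAD" for conv grad); weights stay NCHW and are skipped.
DataLayout LayoutForVar(DataLayout expected, DataLayout actual,
                        DataLayout declared, bool var_is_reorderable) {
  if (var_is_reorderable && expected == DataLayout::kMKLDNN &&
      actual != DataLayout::kMKLDNN && declared != DataLayout::kAnyLayout) {
    // Tell the transfer machinery what the plain tensor's dims mean
    // (e.g. NHWC), so the MKL-DNN reorder permutes them correctly.
    return declared;
  }
  return actual;  // otherwise keep the tensor's own layout
}
```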
template <typename T>
class BatchNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
......
......@@ -148,6 +148,10 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -60,8 +60,13 @@ class FetchOp : public framework::OperatorBase {
// Conversion from MKL-DNN to Paddle
if (src_item.layout() == framework::DataLayout::kMKLDNN) {
framework::Tensor out;
// Convert to the desired Paddle layout, except for filter grads,
// as params are not subject to Paddle's data_format
framework::innerTransDataLayoutFromMKLDNN(
src_item.layout(), paddle::platform::get_cur_paddle_data_layout(),
src_item.layout(),
fetch_var_name == framework::GradVarName("Filter")
? framework::DataLayout::kNCHW
: paddle::platform::get_cur_paddle_data_layout(),
src_item, &out, platform::CPUPlace());
TensorCopySync(out, platform::CPUPlace(), &dst_item);
} else {
......
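The ternary added to the fetch path singles out filter gradients: they belong to a parameter, so they ignore the session-wide layout and always convert back to NCHW. A hedged sketch of that choice, assuming the `@GRAD` suffix convention behind framework::GradVarName:

```cpp
#include <string>

enum class DataLayout { kNCHW, kNHWC };

// Assumes Paddle's "@GRAD" suffix convention for gradient variable names.
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

// Filter grads are parameters, not activations, so they bypass the
// current Paddle data layout and always come back as NCHW.
DataLayout FetchTargetLayout(const std::string& fetch_var_name,
                             DataLayout cur_paddle_layout) {
  return fetch_var_name == GradVarName("Filter") ? DataLayout::kNCHW
                                                 : cur_paddle_layout;
}
```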
......@@ -208,9 +208,8 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar(
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(),
framework::StringToDataLayout(data_format));
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
......@@ -554,16 +553,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Conv MKLDNN grad does not support NHWC data format yet"));
PADDLE_ENFORCE_NE(
data_format, "NDHWC",
platform::errors::Unimplemented(
"Conv MKLDNN Grad does not support NDHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
......@@ -591,6 +581,32 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
return type;
}
framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only the input requires reshaping; weights and
// bias already have their shape in NCHW order
if (((var_name == "Input") ||
(var_name == framework::GradVarName("Output"))) &&
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
template <typename T>
class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
public:
......
......@@ -272,6 +272,10 @@ class ConvOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
};
class ConvOpDoubleGrad : public framework::OperatorWithKernel {
......
......@@ -74,6 +74,11 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
inline void MatchShapeToLayout(framework::Tensor* tensor_in,
framework::DataLayout from,
framework::DataLayout to) {
// Shape changing only makes sense for Tensors with 3+ dims
if (tensor_in->dims().size() < 3) {
return;
}
switch (from) {
case framework::DataLayout::kMKLDNN:
if (to == framework::DataLayout::kNHWC) {
......
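The new early return guards the dims permutation this helper performs: a 1D or 2D tensor has no channel axis to move, so it is left untouched. A self-contained sketch of the NCHW-order ↔ NHWC dims rotation, under the assumption (as in the hunk above) that MKL-DNN tensors carry NCHW-ordered dims:

```cpp
#include <algorithm>
#include <vector>

enum class DataLayout { kMKLDNN, kNCHW, kNHWC };

// Rotate dims between NCHW order (kept internally for MKL-DNN tensors)
// and NHWC, e.g. {N, C, H, W} <-> {N, H, W, C}.
void MatchShapeToLayout(std::vector<int>* dims, DataLayout from,
                        DataLayout to) {
  // Shape changing only makes sense for tensors with 3+ dims.
  if (dims->size() < 3) return;
  if (from == DataLayout::kMKLDNN && to == DataLayout::kNHWC) {
    // NCHW -> NHWC: move the channel axis to the back.
    std::rotate(dims->begin() + 1, dims->begin() + 2, dims->end());
  } else if (from == DataLayout::kNHWC && to == DataLayout::kMKLDNN) {
    // NHWC -> NCHW: move the channel axis just behind the batch axis.
    std::rotate(dims->begin() + 1, dims->end() - 1, dims->end());
  }
}
```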
......@@ -34,6 +34,10 @@ class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
epsilon, momentum, shape, data_layout):
if data_layout != "NCHW" and data_layout != "NHWC":
raise ValueError("Unknown data order.")
# run forward
y, saved_mean, saved_variance = _reference_training(
x, scale, bias, epsilon, data_layout)
......@@ -46,6 +50,12 @@ class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
class TestMKLDNNBatchNormOpTraining_NHWC(TestMKLDNNBatchNormOpTraining):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_formats = ["NHWC"]
class TestMKLDNNBatchNormOpExistedPrimitives(TestMKLDNNBatchNormOpTraining):
def init_test_case(self):
TestMKLDNNBatchNormOpTraining.init_test_case(self)
......
......@@ -208,18 +208,6 @@ class TestConv2dOp_Valid_NHWC_MKLDNN(TestConv2dOp_Valid_MKLDNN):
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
#TODO(jczaja): Enable once GRAD op is adjusted
def test_check_grad(self):
pass
#TODO(jczaja): Enable once GRAD op is adjusted
def test_check_grad_no_filter(self):
pass
#TODO(jczaja): Enable once GRAD op is adjusted
def test_check_grad_no_input(self):
pass
class TestConv2dOp_Same_NHWC_MKLDNN(TestConv2dOp_Valid_NHWC_MKLDNN):
def init_paddings(self):
......