提交 03dda317 编写于 作者: B bingyanghuang 提交者: Tao Luo

[cherry-pick] Refactor mkldnn eletwise_mul and error message for NHWC in mkldnn (#21361)

上级 93c7f058
...@@ -148,6 +148,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -148,6 +148,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(data_format, "NHWC",
platform::errors::Unimplemented(
"Conv MKLDNN does not support NHWC data format yet"));
PADDLE_ENFORCE_NE(
data_format, "NDHWC",
platform::errors::Unimplemented(
"Conv MKLDNN does not support NDHWC data format yet"));
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
customized_type_value = customized_type_value =
...@@ -521,6 +530,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -521,6 +530,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Conv MKLDNN grad does not support NHWC data format yet"));
PADDLE_ENFORCE_NE(
data_format, "NDHWC",
platform::errors::Unimplemented(
"Conv MKLDNN Grad does not support NDHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32; customized_type_value = kConvMKLDNNFP32;
...@@ -695,14 +714,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( ...@@ -695,14 +714,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
}
#endif #endif
auto type = framework::OpKernelType( auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
......
...@@ -127,6 +127,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -127,6 +127,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
"Conv Transpose MKLDNN does not support NHWC data format yet");
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -4,3 +4,4 @@ register_operators() ...@@ -4,3 +4,4 @@ register_operators()
cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor) cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor) cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor) cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
cc_test(test_elementwise_mul_op_correct_dims SRCS test_elementwise_mul_op_dim.cc DEPS op_registry elementwise_mul_op scope device_context enforce executor)
...@@ -116,7 +116,7 @@ class ElementwiseMulDoubleGradDescMaker ...@@ -116,7 +116,7 @@ class ElementwiseMulDoubleGradDescMaker
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp, REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseMulOp,
ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseMulOpGradDescMaker); ops::ElementwiseMulOpGradDescMaker);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad, REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad,
......
...@@ -13,14 +13,59 @@ See the License for the specific language governing permissions and ...@@ -13,14 +13,59 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Operator class for elementwise_mul. It inherits everything from
// ElementwiseOp and only overrides kernel selection: the MKL-DNN kernel is
// requested when MKL-DNN can be used, AVX512F is available and the inputs pass
// AreDimsAndFormatCorrect; otherwise the plain kernel type is returned.
class ElementwiseMulOp : public ElementwiseOp {
 public:
  using Tensor = framework::Tensor;
  using ElementwiseOp::ElementwiseOp;

#ifdef PADDLE_WITH_MKLDNN
  // Checks whether inputs "X" and "Y" of `ctx` satisfy all of:
  //   * X has at least two dimensions and dimension 1 is a multiple of
  //     `simd_width`,
  //   * X is stored in the MKL-DNN memory format `x_format`,
  //   * Y has exactly two dimensions.
  // Returns true only if every condition holds.
  static bool AreDimsAndFormatCorrect(const framework::ExecutionContext& ctx,
                                      int simd_width,
                                      mkldnn::memory::format x_format) {
    using paddle::framework::vectorize;
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    const auto x_dims = vectorize(x->dims());
    // Indexing x_dims[1] on a tensor of rank < 2 would read past the end of
    // the vector, so such inputs are rejected up front.
    if (x_dims.size() < 2) {
      return false;
    }
    const bool is_channel_divisible = (x_dims[1] % simd_width) == 0;
    const bool is_x_format_correct = x->format() == x_format;
    const bool is_y_two_dimensional = vectorize(y->dims()).size() == 2;
    return is_channel_divisible && is_x_format_correct &&
           is_y_two_dimensional;
  }
#endif

  // Selects the kernel type from the data type of input "X" and the place of
  // `ctx`. With MKL-DNN compiled in, the MKL-DNN layout/library is requested
  // only when the checks described on the class pass with a SIMD width of 16
  // and the nChw16c format.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
    using mkldnn::memory;
    if (platform::CanMKLDNNBeUsed(ctx)) {
      // MayIUse is evaluated first so the tensor checks are skipped on CPUs
      // without AVX512F.
      bool can_use_avx512_kernel =
          platform::MayIUse(platform::avx512f) &&
          AreDimsAndFormatCorrect(ctx, 16, memory::format::nChw16c);
      if (can_use_avx512_kernel) {
        return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                       framework::DataLayout::kMKLDNN,
                                       framework::LibraryType::kMKLDNN);
      }
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_mul(const framework::ExecutionContext& ctx, void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
......
...@@ -120,21 +120,10 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -120,21 +120,10 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.") AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>("x_data_format", "This parameter is no longer used.")
"x_data_format",
"(string, default NCHW) Only used in mkldnn"
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
"Defaults to \"\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault(""); .SetDefault("");
AddAttr<std::string>( AddAttr<std::string>("y_data_format", "This parameter is no longer used.")
"y_data_format",
"(string, default \"\") Only used in mkldnn"
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
"Defaults to \"\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault(""); .SetDefault("");
AddOpComment(); AddOpComment();
} }
......
...@@ -32,38 +32,28 @@ using framework::DataLayout; ...@@ -32,38 +32,28 @@ using framework::DataLayout;
using mkldnn::memory; using mkldnn::memory;
using platform::StringToMKLDNNFormat; using platform::StringToMKLDNNFormat;
static void UpdateDataFormat(const framework::ExecutionContext& ctx, template <typename T>
framework::Tensor* tensor, const char* attribute) { static void ComputeBroadcastedMultiply(const T* x_data, const T* y_data,
if (ctx.op().HasAttr(attribute)) { T* z_data, int64_t n, int64_t c,
auto format_as_string = ctx.Attr<std::string>(attribute); int64_t h, int64_t w, int simd_width,
auto format = StringToMKLDNNFormat(&format_as_string); void (*multiply)(const T*, const T*, T*,
if (format != MKLDNNMemoryFormat::any) { int, int)) {
tensor->set_format(format); const int64_t C = c / simd_width;
#pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
auto ptr_x =
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
multiply(ptr_x, ptr_y, ptr_z, h, w);
} }
} }
} }
template <typename T>
static void ReorderInput(framework::Tensor* tensor,
const platform::Place& place,
const mkldnn::engine& engine, bool isFourDim) {
using platform::to_void_cast;
auto dims = paddle::framework::vectorize<int>(tensor->dims());
framework::Tensor out_tensor;
out_tensor.Resize(tensor->dims());
out_tensor.set_format(isFourDim ? MKLDNNMemoryFormat::nchw
: MKLDNNMemoryFormat::nc);
out_tensor.set_layout(tensor->layout());
mkldnn::memory input_memory = {
{{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
to_void_cast<T>(tensor->data<T>())};
mkldnn::memory output_memory = {
{{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
to_void_cast<T>(out_tensor.mutable_data<T>(place))};
platform::Reorder(input_memory, output_memory);
tensor->ShareDataWith(out_tensor);
}
template <typename T> template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
public: public:
...@@ -82,103 +72,26 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -82,103 +72,26 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto y_dims_untrimmed = y->dims(); auto y_dims_untrimmed = y->dims();
auto x_int_dims = paddle::framework::vectorize<int>(x_dims); auto x_int_dims = paddle::framework::vectorize<int>(x_dims);
UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format"); int pre, num, post, is_run_common_broadcast;
UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format"); get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &num, &post,
&is_run_common_broadcast);
const bool is_avx512_enabled = platform::MayIUse(platform::avx512f);
const bool are_dims_divisable = !(x_int_dims[1] % 16); if (post == 1) PADDLE_THROW("Not implemented when post is 1");
const bool is_x_format_correct = x->format() == MKLDNNMemoryFormat::nChw16c;
const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc; const int64_t n = x_dims[0];
if (is_x_format_correct && is_y_format_correct && are_dims_divisable && const int64_t c = x_dims[1];
is_avx512_enabled) { const int64_t h = x_dims[2];
int pre, n, post; const int64_t w = x_dims[3];
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
const int simd_width = 16;
if (post == 1) { auto multiply =
PADDLE_THROW("Not implemented when post is 1"); jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>, platform::CPUPlace>::Cache()
} else { .At(0);
// Just check whether it works for RE-Resnext. ComputeBroadcastedMultiply(x_data, y_data, z_data, n, c, h, w, simd_width,
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); multiply);
int n = x_dims[0]; z->set_layout(DataLayout::kMKLDNN);
int c = x_dims[1]; z->set_format(x->format());
int h = x_dims[2];
int w = x_dims[3];
PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
"Y should be in nc format");
constexpr int simd_width = 16;
int C = c / simd_width;
auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
platform::CPUPlace>::Cache()
.At(0);
#pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
auto ptr_x =
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
multiply(ptr_x, ptr_y, ptr_z, h, w);
}
}
}
z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
} else {
// Fallback to naive version:
const bool are_inputs_in_same_format = x->format() == y->format();
const bool is_x_nchw = x->format() == MKLDNNMemoryFormat::nchw;
const bool is_x_nc = x->format() == MKLDNNMemoryFormat::nc;
const bool is_x_x = x->format() == MKLDNNMemoryFormat::x;
const bool is_y_nchw = y->format() == MKLDNNMemoryFormat::nchw;
const bool is_y_nc = y->format() == MKLDNNMemoryFormat::nc;
const bool is_y_x = y->format() == MKLDNNMemoryFormat::x;
if (!are_inputs_in_same_format) {
using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
if (!(is_x_nchw || is_x_nc || is_x_x))
ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
x->dims().size() == 4);
if (!(is_y_nchw || is_y_nc || is_y_x))
ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
y->dims().size() == 4);
}
auto mul_func = [](T a, T b) -> T { return a * b; };
TransformFunctor<decltype(mul_func), T,
paddle::platform::CPUDeviceContext, T>
functor(
x, y, z,
ctx.template device_context<paddle::platform::CPUDeviceContext>(),
mul_func);
axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
if (post == 1) {
functor.RunRowWise(n, pre);
} else {
functor.RunMidWise(n, pre, post);
}
z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
}
} }
}; };
} // namespace operators } // namespace operators
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
USE_OP(elementwise_mul);
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_MKLDNN
using framework::Scope;
using framework::LoDTensor;
using framework::OpRegistry;
using framework::OperatorBase;
using framework::RuntimeContext;
using framework::ExecutionContext;
// Parameters describing one AreDimsAndFormatCorrect test case.
struct TestData {
  // Size of dimension 1 of the X tensor; MainTest resizes X to
  // {1, channel_num, 3, 3}.
  int64_t channel_num;
  // MKL-DNN memory format assigned to the X tensor.
  MKLDNNMemoryFormat format;
  // Dimensions assigned to the Y tensor.
  framework::DDim y_dims;
  // True when AreDimsAndFormatCorrect is expected to return false.
  bool supposed_to_fail;
};
void MainTest(const TestData& test_data) {
auto place = platform::CPUPlace();
Scope scope;
auto* x = scope.Var("x")->GetMutable<LoDTensor>();
auto* y = scope.Var("y")->GetMutable<LoDTensor>();
scope.Var("out")->GetMutable<LoDTensor>();
x->Resize({1, test_data.channel_num, 3, 3});
y->Resize(test_data.y_dims);
x->set_format(test_data.format);
y->set_format(MKLDNNMemoryFormat::nc);
std::unique_ptr<OperatorBase> op = OpRegistry::CreateOp(
"elementwise_mul", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {});
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
RuntimeContext runtime_ctx =
RuntimeContext(op->Inputs(), op->Outputs(), scope);
ExecutionContext ctx =
ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr);
bool result = ElementwiseMulOp::AreDimsAndFormatCorrect(
ctx, 16, MKLDNNMemoryFormat::nChw16c);
if (test_data.supposed_to_fail)
ASSERT_FALSE(result);
else
ASSERT_TRUE(result);
}
// Positive case: 16 channels, X in nChw16c and a two-dimensional Y must make
// AreDimsAndFormatCorrect return true.
TEST(ElementwiseMulOpTester, correct_dims) {
  const int64_t channels = 16;
  const TestData test_data{channels, MKLDNNMemoryFormat::nChw16c,
                           {1, channels}, /*supposed_to_fail=*/false};
  MainTest(test_data);
}
// Negative case: a channel count that is not divisible by 16 must make
// AreDimsAndFormatCorrect return false.
TEST(ElementwiseMulOpTester, incorrect_channel_num) {
  const int64_t channels = 17;
  const TestData test_data{channels, MKLDNNMemoryFormat::nChw16c,
                           {1, channels}, /*supposed_to_fail=*/true};
  MainTest(test_data);
}
// Negative case: an X format other than nChw16c (here plain nchw) must make
// AreDimsAndFormatCorrect return false.
TEST(ElementwiseMulOpTester, incorrect_format) {
  const int64_t channels = 16;
  const TestData test_data{channels, MKLDNNMemoryFormat::nchw,
                           {1, channels}, /*supposed_to_fail=*/true};
  MainTest(test_data);
}
// Negative case: a Y tensor that is not two-dimensional must make
// AreDimsAndFormatCorrect return false.
TEST(ElementwiseMulOpTester, incorrect_y_dims) {
  const int64_t channels = 16;
  const TestData test_data{channels, MKLDNNMemoryFormat::nChw16c,
                           {1, channels, 1}, /*supposed_to_fail=*/true};
  MainTest(test_data);
}
#endif
} // namespace operators
} // namespace paddle
...@@ -193,6 +193,12 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -193,6 +193,12 @@ class LRNOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"LRN MKLDNN does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -311,6 +317,12 @@ class LRNOpGrad : public framework::OperatorWithKernel { ...@@ -311,6 +317,12 @@ class LRNOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"LRN MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -129,6 +129,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( ...@@ -129,6 +129,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Pool MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -160,6 +166,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( ...@@ -160,6 +166,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
platform::errors::Unimplemented(
"Pool MKLDNN grad does not support NHWC data format yet"));
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -40,7 +40,7 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): ...@@ -40,7 +40,7 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp):
self.filter_size2 = [1, 16, 2, 2] self.filter_size2 = [1, 16, 2, 2]
self.dilations = [1, 1] self.dilations = [1, 1]
self.use_cudnn = False self.use_cudnn = False
self.data_format = "NCHW" self.data_format = "ANYLAYOUT"
self.input = np.random.random(self.input_size).astype(self.dtype) self.input = np.random.random(self.input_size).astype(self.dtype)
self.filter = np.random.random(self.filter_size).astype(self.dtype) self.filter = np.random.random(self.filter_size).astype(self.dtype)
self.filter2 = np.random.random(self.filter_size2).astype(self.dtype) self.filter2 = np.random.random(self.filter_size2).astype(self.dtype)
...@@ -97,7 +97,8 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): ...@@ -97,7 +97,8 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp):
'groups': self.groups, 'groups': self.groups,
'dilations': self.dilations, 'dilations': self.dilations,
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn 'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
}) })
elementwise_mul_op = block.append_op( elementwise_mul_op = block.append_op(
type="elementwise_mul", type="elementwise_mul",
...@@ -152,179 +153,5 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): ...@@ -152,179 +153,5 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp):
pass pass
class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp):
def init_input_output(self):
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.y = np.random.rand(1, 16).astype(self.dtype)
self.out = self.x * self.y.reshape(1, 16, 1, 1)
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
def init_input_output(self):
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
self.out = x * y
def setUp(self):
super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
self.attrs["x_data_format"] = "nchw16c"
self.attrs["y_data_format"] = "nchw16c"
self._cpu_only = True
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
def init_input_output(self):
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
self.out = x * y
def setUp(self):
super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp()
self.attrs["x_data_format"] = "nchw16c"
self.attrs["y_data_format"] = "nchw16c"
self._cpu_only = True
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
def init_input_output(self):
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
self.out = self.x * y
def setUp(self):
super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp()
self.attrs["x_data_format"] = "nchw"
self.attrs["y_data_format"] = "nchw16c"
self._cpu_only = True
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
def init_input_output(self):
self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
self.out = x * self.y
def setUp(self):
super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp()
self.attrs["x_data_format"] = "nchw16c"
self.attrs["y_data_format"] = "nchw"
self._cpu_only = True
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
def init_input_output(self):
self.x = np.random.rand(1, 16).astype(self.dtype)
self.y = np.random.rand(1, 16).astype(self.dtype)
self.out = self.x * self.y
def setUp(self):
super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp()
self.attrs["x_data_format"] = "nc"
self.attrs["y_data_format"] = "nc"
self._cpu_only = True
def init_kernel_type(self):
self.use_mkldnn = True
def init_axis(self):
self.axis = 0
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp from paddle.fluid.tests.unittests.test_lrn_op import TestLRNOp
import paddle.fluid as fluid
class TestLRNMKLDNNOp(TestLRNOp): class TestLRNMKLDNNOp(TestLRNOp):
...@@ -47,5 +48,20 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): ...@@ -47,5 +48,20 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
self.assertRaises(AttributeError, check_raise_is_test) self.assertRaises(AttributeError, check_raise_is_test)
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
    """Negative test: the LRN MKL-DNN op is expected to reject NHWC input."""

    def init_test_case(self):
        self.data_format = 'NHWC'

    def test_check_output(self):
        # Output check is intentionally disabled for this negative test.
        pass

    # Grad tests both FWD and BWD ops kernels creation
    def test_check_grad_normal(self):
        # Use fluid.core rather than fluid.core_avx so the exception type is
        # resolved regardless of which core binary (AVX / no-AVX) was loaded.
        with self.assertRaises(fluid.core.EnforceNotMet):
            self.check_grad(['X'], 'Out', max_relative_error=0.01)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -141,5 +141,26 @@ class TestAsymPadValid(TestAsymPad): ...@@ -141,5 +141,26 @@ class TestAsymPadValid(TestAsymPad):
self.padding_algorithm = "VALID" self.padding_algorithm = "VALID"
# Designed to Fail
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class TestAsymPadValidNHWC(TestAsymPadValid):
    """Negative test: the pool MKL-DNN op is expected to reject NHWC input."""

    def init_data_format(self):
        self.data_format = "NHWC"

    def init_shape(self):
        self.shape = [2, 7, 7, 3]

    def test_check_output(self):
        # Output check is intentionally disabled for this negative test.
        pass

    # Grad tests both FWD and BWD ops kernels creation
    # GetExpectedKernelType should throw an exception on lack of support
    # to NHWC inputs in pool mkldnn kernel
    def test_check_grad(self):
        # Use fluid.core rather than fluid.core_avx so the exception type is
        # resolved regardless of which core binary (AVX / no-AVX) was loaded.
        with self.assertRaises(fluid.core.EnforceNotMet):
            super(TestAsymPadValidNHWC, self).test_check_grad()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册