未验证 提交 d47690b2 编写于 作者: J Jacek Czaja 提交者: GitHub

shape mkldnn kernel adapted to NHWC (#42548)

* - shape mkldnn adapted to NHWC

- NHWC shape mkldnn ut

- fixes to UT

- Fix to UT

- Fixes to UT

- Fix of compilation

* - lint candidate fix
上级 08158e93
cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op activation_op pooling transpose_op scope device_context enforce executor) cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op shape_op activation_op pooling transpose_op scope device_context enforce executor)
...@@ -32,6 +32,16 @@ class ShapeMKLDNNKernel : public framework::OpKernel<T> { ...@@ -32,6 +32,16 @@ class ShapeMKLDNNKernel : public framework::OpKernel<T> {
in_dims = in_var->Get<phi::SelectedRows>().value().dims(); in_dims = in_var->Get<phi::SelectedRows>().value().dims();
} else { } else {
in_dims = in_var->Get<LoDTensor>().dims(); in_dims = in_var->Get<LoDTensor>().dims();
// Output of shape op is often fed as input to fill_constant ops
// and we need to rotate a shape otherwise Tensors of wrong shape may be
// allocated
if (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
framework::DataLayout::kNHWC &&
in_dims.size() >= 3) {
auto rdims = phi::vectorize<int>(in_dims);
std::rotate(rdims.begin() + 1, rdims.begin() + 2, rdims.end());
in_dims = phi::make_ddim(rdims);
}
} }
auto* out_t = ctx.Output<Tensor>("Out"); auto* out_t = ctx.Output<Tensor>("Out");
out_t->Resize({in_dims.size()}); out_t->Resize({in_dims.size()});
......
...@@ -32,9 +32,12 @@ USE_OP_ITSELF(relu); ...@@ -32,9 +32,12 @@ USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP_ITSELF(transpose); USE_OP_ITSELF(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN); USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
USE_OP_ITSELF(shape);
USE_OP_DEVICE_KERNEL(shape, MKLDNN);
PD_DECLARE_KERNEL(pool2d, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(pool2d, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(relu, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(relu, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(shape, CPU, ALL_LAYOUT);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -154,5 +157,59 @@ TEST(test_pool2d_relu_relu_nhwc, cpu_place) { ...@@ -154,5 +157,59 @@ TEST(test_pool2d_relu_relu_nhwc, cpu_place) {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Computed shape does not match expected shape")); "Computed shape does not match expected shape"));
} }
TEST(test_pool2d_shape_nhwc, cpu_place) {
framework::DDim dims({1, 4, 8, 512}); // NHWC shape
std::vector<int32_t> expected_dims{1, 3, 7, 512}; // NHWC expected shape
platform::CPUPlace p;
framework::Scope scope;
InputVars input_name = {"x",
scope.Var("x")->GetMutable<framework::LoDTensor>()};
// Initialize input data
std::uniform_real_distribution<float> dist(static_cast<float>(10.0),
static_cast<float>(20.0));
std::mt19937 engine;
size_t numel = static_cast<size_t>(phi::product(dims));
input_name.tensor->Resize(dims);
auto data_ptr = input_name.tensor->mutable_data<float>(p);
for (size_t i = 0; i < numel; ++i) {
data_ptr[i] = dist(engine);
}
scope.Var("y")->GetMutable<framework::LoDTensor>();
auto *z = scope.Var("z")->GetMutable<framework::LoDTensor>();
auto &pool = platform::DeviceContextPool::Instance();
// Make pool2d followed by shape. shape for NHWC should return
// as output tensor not-rotated shape of Pool (
auto ksize = std::vector<int>(2, 2);
auto op_pool = framework::OpRegistry::CreateOp(
"pool2d", {{"X", {"x"}}}, {{"Out", {"y"}}},
{{"pooling_type", {std::string("max")}},
{"ksize", {ksize}},
{"data_format", {std::string("NHWC")}},
{"use_mkldnn", {true}}});
auto op_shape = framework::OpRegistry::CreateOp(
"shape", {{"Input", {"y"}}}, {{"Out", {"z"}}}, {{"use_mkldnn", {true}}});
op_pool->Run(scope, p);
op_shape->Run(scope, p);
pool.Get(p)->Wait();
// repack tensor data into vector for easy comparison
auto *zdata = z->data<int32_t>();
std::vector<int32_t> vzdata(zdata, zdata + z->numel());
// Verify shape of output
PADDLE_ENFORCE_EQ(vzdata, expected_dims,
platform::errors::InvalidArgument(
"Computed shape does not match expected shape"));
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册