未验证 提交 22c4c189 编写于 作者: J jakpiase 提交者: GitHub

Added reshape, reshape2, squeeze and squeeze2 BF16/FP32 FWD/BWD kernels (#34219)

* test version of matmul_v2

* added matmul_v2 grad kernel

* minor changes

* minor changes

* minor change for CI approval

* CI fix

* CI fix

* added squeeze and squeeze2 kernels

* CI fix

* CI fix

* CI fix

* disabled tests when compiled with cuda

* added setting format_tag by strides

* added sigmoid BF16 FWD/BWD and gelu BF16 BWD

* changes after review

* Revert "added sigmoid BF16 FWD/BWD and gelu BF16 BWD"

This reverts commit 6e3f76720b545abfcff9f6052b46b73a1e745cae.

* Revert "Merge branch 'matmul_v2_grad' into squeeze2_op"

This reverts commit 06fcf67843a4a7884eccdf67a02a03575e1d4cb8, reversing
changes made to 6e3f76720b545abfcff9f6052b46b73a1e745cae.

* minor change

* added reshape1/2 kernels

* moved some functions into private block

* CI fix

* CI fix

* CI fix
上级 e6aacd1e
......@@ -2262,26 +2262,15 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"concat",
"conv2d",
"conv2d_transpose",
"elementwise_add",
"elementwise_mul",
"fc",
"fusion_gru",
"fusion_lstm",
"gelu",
"layer_norm",
"matmul",
"matmul_v2",
"pool2d",
"prelu",
"relu",
"reshape2",
"softmax",
"split",
"sum",
"transpose2"});
std::unordered_set<std::string>(
{"concat", "conv2d", "conv2d_transpose",
"elementwise_add", "elementwise_mul", "fc",
"fusion_gru", "fusion_lstm", "gelu",
"layer_norm", "matmul", "matmul_v2",
"pool2d", "prelu", "relu",
"reshape2", "softmax", "split",
"squeeze", "squeeze2", "sum",
"transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/squeeze_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using paddle::framework::LoDTensor;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
// oneDNN forward kernel shared by the reshape, reshape2, squeeze and squeeze2
// operators (see the REGISTER_OP_KERNEL calls in this file).
//
// The input X is reordered into a plain (non-blocked) layout using X's own
// dimensions. Afterwards Out is resized to the target shape and its format tag
// is derived from the reshaped destination memory descriptor.
template <typename T>
class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    RunKernel(ctx);
  }

 private:
  // Performs the reorder and fixes up the dims/layout/format of Out.
  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<LoDTensor>("X");
    auto* xshape = ctx.Output<LoDTensor>("XShape");
    auto* out = ctx.Output<LoDTensor>("Out");

    framework::DDim x_dims;
    // Op types without a "2" in their name (reshape, squeeze) take the input
    // dims straight from X; the "2" variants take them from XShape with its
    // leading entry dropped.
    if (ctx.Type().find("2") == std::string::npos) {
      x_dims = x->dims();
    } else {
      auto xshape_dims = xshape->dims();
      x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
    }

    auto x_vec_dims = framework::vectorize(x_dims);

    framework::DDim out_dims;
    if (ctx.Type() == "squeeze") {
      // Plain squeeze recomputes the output shape from the "axes" attribute.
      auto& axes = ctx.Attr<std::vector<int>>("axes");
      out_dims = GetOutputShape(axes, x_dims, true);
    } else {
      out_dims = out->dims();
    }

    // For reshape/reshape2 an optional "Shape" input tensor overrides the
    // output dims determined above.
    if (ctx.Type().find("reshape") != std::string::npos) {
      if (ctx.HasInput("Shape")) {
        auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
        auto* shape_data = shape_tensor->data<int>();

        auto shape =
            std::vector<int>(shape_data, shape_data + shape_tensor->numel());
        out_dims = ValidateShape(shape, x_dims);
      }
    }

    mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
    std::string key =
        platform::CreateKey(dev_ctx, x_vec_dims, x->format(), x_type);
    platform::ReorderMKLDNNHandler reorder_handler(
        x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key);

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        x->format(), platform::to_void_cast(x->data<T>()));
    out->Resize(x_dims);  // to match x numel, format is changed later
    // reorder is done into a plain tag to allow usage with blocked formats
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        out, getPlainFormatTag(x), ctx.GetPlace());
    auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
                                                    reorder_dst_memory_p);

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
    astream.wait();

    // Only the metadata changes from here on: the data written by the reorder
    // is reinterpreted with the target shape.
    out->Resize(out_dims);
    out->set_layout(framework::DataLayout::kMKLDNN);
    out->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
        framework::vectorize(out_dims))));
  }

 protected:
  // Returns the plain format tag (a, ab, ..., abcdef) whose rank equals the
  // rank of `tensor`. Ranks outside <1, 6> raise InvalidArgument.
  static mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) {
    auto tensor_dims_size = tensor->dims().size();
    // This helper is used by every reshape/squeeze forward and backward
    // kernel, so the message must not single out squeeze_grad.
    PADDLE_ENFORCE_EQ(
        tensor_dims_size <= 6 && tensor_dims_size >= 1, true,
        platform::errors::InvalidArgument(
            "Dims for reshape/squeeze oneDNN op must be in range <1, 6>"));

    switch (tensor_dims_size) {
      case 1:
        return mkldnn::memory::format_tag::a;
      case 2:
        return mkldnn::memory::format_tag::ab;
      case 3:
        return mkldnn::memory::format_tag::abc;
      case 4:
        return mkldnn::memory::format_tag::abcd;
      case 5:
        return mkldnn::memory::format_tag::abcde;
      default:
        return mkldnn::memory::format_tag::abcdef;
    }
  }

  // Resolves a requested `shape` against the input dims `in_dims`:
  //   * a value of 0 copies the input dimension at the same index,
  //   * a single value of -1 is inferred from the remaining capacity,
  //   * every other value must be positive.
  // Returns the resolved output dims; violations raise InvalidArgument.
  static framework::DDim ValidateShape(const std::vector<int>& shape,
                                       const framework::DDim& in_dims) {
    const int64_t in_size = framework::product(in_dims);
    auto in_dims_vec = framework::vectorize(in_dims);
    bool all_positive = std::all_of(in_dims_vec.cbegin(), in_dims_vec.cend(),
                                    [](int64_t i) { return i > 0; });
    // only one dimension can be set to -1, whose size will be automatically
    // inferred
    const int64_t unk_dim_val = -1;
    const int64_t copy_dim_val = 0;

    std::vector<int64_t> output_shape(shape.size(), 0);
    int64_t capacity = 1;
    int unk_dim_idx = -1;
    for (size_t i = 0; i < shape.size(); ++i) {
      if (shape[i] == unk_dim_val) {
        PADDLE_ENFORCE_EQ(
            unk_dim_idx, -1,
            platform::errors::InvalidArgument(
                "Only one dimension value of 'shape' in ReshapeOp can "
                "be -1. But received shape = [%s], shape[%d] is also -1.",
                framework::make_ddim(shape), i));
        unk_dim_idx = i;
      } else if (shape[i] == copy_dim_val) {
        PADDLE_ENFORCE_LT(
            static_cast<int>(i), in_dims.size(),
            platform::errors::InvalidArgument(
                "The index of 0 in `shape` must be less than "
                "the input tensor X's dimensions. "
                "But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
                "X's dimensions = %d.",
                framework::make_ddim(shape), i, in_dims, in_dims.size()));
      } else {
        PADDLE_ENFORCE_GT(
            shape[i], 0,
            platform::errors::InvalidArgument(
                "Each dimension value of 'shape' in ReshapeOp must not "
                "be negative except one unknown dimension. "
                "But received shape = [%s], shape[%d] = %d.",
                framework::make_ddim(shape), i, shape[i]));
      }

      // A -1 entry is multiplied in as-is, so `capacity` is negative whenever
      // an unknown dimension is present.
      capacity *= (shape[i] ? shape[i] : in_dims[i]);
      output_shape[i] =
          (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
    }

    if (unk_dim_idx != -1) {
      if (all_positive) {
        // The divisibility check is only meaningful when every input dim is
        // positive. With an undetermined input dim, e.g.
        // in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], we would get
        // capacity = -24, in_size = -8, output_shape[0] = 0 and the check
        // would fail, so that case is handled in the else branch instead.
        output_shape[unk_dim_idx] = -in_size / capacity;
        PADDLE_ENFORCE_EQ(
            output_shape[unk_dim_idx] * capacity, -in_size,
            platform::errors::InvalidArgument(
                "The 'shape' attribute in ReshapeOp is invalid. "
                "The input tensor X'size must be divisible by known "
                "capacity of 'shape'. "
                "But received X's shape = [%s], X's size = %d, "
                "'shape' is [%s], known capacity of 'shape' is %d.",
                in_dims, in_size, framework::make_ddim(shape), capacity));
      } else {
        output_shape[unk_dim_idx] = -1;
      }
    } else {
      if (all_positive) {
        PADDLE_ENFORCE_EQ(
            capacity, in_size,
            platform::errors::InvalidArgument(
                "The 'shape' in ReshapeOp is invalid. "
                "The input tensor X'size must be equal to the capacity of "
                "'shape'. "
                "But received X's shape = [%s], X's size = %d, 'shape' is "
                "[%s], the capacity of 'shape' is %d.",
                in_dims, in_size, framework::make_ddim(shape), capacity));
      }
    }
    return framework::make_ddim(output_shape);
  }
};
// oneDNN backward kernel shared by the reshape, reshape2, squeeze and squeeze2
// gradient operators (see the REGISTER_OP_KERNEL calls in this file).
//
// The incoming gradient Out@GRAD is reordered into a plain layout and written
// to X@GRAD, which is then resized to the forward input's dimensions.
template <typename T>
class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    RunKernel(ctx);
  }

 private:
  // Performs the reorder and fixes up the dims/layout/format of X@GRAD.
  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* dout = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));

    framework::DDim x_dims;
    // Op types without a "2" in their name take the target dims from X@GRAD;
    // the "2" variants take them from XShape with its leading entry dropped.
    if (ctx.Type().find("2") == std::string::npos) {
      x_dims = dx->dims();
    } else {
      auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
      x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
    }

    auto dout_vec_dims = framework::vectorize(dout->dims());

    mkldnn::memory::data_type dout_type =
        framework::ToMKLDNNDataType(dout->type());
    // The key handed to the reorder handler must describe the reorder that is
    // actually built below: its source uses dout's format and its destination
    // uses the plain tag of dout's rank. Keying on dx's format/rank instead
    // would not distinguish two reorders that differ only in dout's format.
    std::string key =
        platform::CreateKey(dev_ctx, dout_vec_dims, dout->format(),
                            this->getPlainFormatTag(dout), dout_type);
    platform::ReorderMKLDNNHandler reorder_handler(
        dout_vec_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key);

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        dout->format(), platform::to_void_cast(dout->data<T>()));
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        dx, this->getPlainFormatTag(dout), ctx.GetPlace());
    auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
                                                    reorder_dst_memory_p);

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
    astream.wait();

    // Only the metadata changes from here on: the data written by the reorder
    // is reinterpreted with the forward input's shape.
    dx->Resize(x_dims);
    dx->set_layout(framework::DataLayout::kMKLDNN);
    dx->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
        framework::vectorize(x_dims))));
  }
};
} // namespace operators
} // namespace paddle
// Kernel registration. All four forward ops share ReshapeMKLDNNKernel and all
// four gradient ops share ReshapeGradMKLDNNKernel; each is registered for
// float and bfloat16 on CPUPlace.
namespace ops = paddle::operators;
// squeeze / squeeze_grad
REGISTER_OP_KERNEL(squeeze, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(squeeze_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
// squeeze2 / squeeze2_grad
REGISTER_OP_KERNEL(squeeze2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(squeeze2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
// reshape / reshape_grad
REGISTER_OP_KERNEL(reshape, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(reshape_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
// reshape2 / reshape2_grad
REGISTER_OP_KERNEL(reshape2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(reshape2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
......@@ -228,9 +228,17 @@ class ReshapeOp : public framework::OperatorWithKernel {
protected:
// Chooses the kernel type from the data type of input "X".
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the oneDNN layout/library is requested; otherwise the plain kernel
// for the data type and place is returned.
// (The stale pre-change `return` that preceded this body made the oneDNN
// branch unreachable and has been dropped.)
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......@@ -269,6 +277,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
"It has the lowest priority compare with Input(Shape) and "
" Input(ShapeTensor).")
.SetDefault({});
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Reshape Operator.
......@@ -334,9 +345,17 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
protected:
// Chooses the gradient kernel type from the data type of input "X".
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the oneDNN layout/library is requested; otherwise the plain kernel
// for the data type and place is returned.
// (The stale pre-change `return` that preceded this body made the oneDNN
// branch unreachable and has been dropped.)
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -517,9 +536,17 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
protected:
// Chooses the gradient kernel type from the data type of input Out@GRAD.
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the oneDNN layout/library is requested; otherwise the plain kernel
// for the data type and place is returned.
// (The stale pre-change `return` that preceded this body made the oneDNN
// branch unreachable and has been dropped.)
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
      ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -110,9 +110,17 @@ class SqueezeOp : public framework::OperatorWithKernel {
protected:
// Chooses the kernel type from the data type of input "X".
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the oneDNN layout/library is requested; otherwise the plain kernel
// for the data type and place is returned.
// (The stale pre-change `return` that preceded this body made the oneDNN
// branch unreachable and has been dropped.)
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -129,9 +137,17 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
protected:
// Chooses the gradient kernel type from the data type of input Out@GRAD.
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the oneDNN layout/library is requested; otherwise the plain kernel
// for the data type and place is returned.
// (The stale pre-change `return` that preceded this body made the oneDNN
// branch unreachable and has been dropped.)
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
      ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -144,6 +160,14 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
"(std::vector<int>). List of integers,"
" indicating the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddComment(R"DOC(
Squeeze Operator.
......@@ -209,6 +233,21 @@ class Squeeze2Op : public framework::OperatorWithKernel {
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
// Picks the kernel type for squeeze2 based on the data type of input "X".
// With PADDLE_WITH_MKLDNN defined and CanMKLDNNBeUsed() returning true the
// oneDNN layout/library is requested; in every other case the plain kernel
// for the data type and place is returned.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  const auto x_data_type =
      framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
  const bool onednn_usable = this->CanMKLDNNBeUsed(ctx, x_data_type);
  if (onednn_usable) {
    framework::OpKernelType onednn_kernel_type(
        x_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN,
        framework::LibraryType::kMKLDNN);
    return onednn_kernel_type;
  }
#endif

  framework::OpKernelType plain_kernel_type(x_data_type, ctx.GetPlace());
  return plain_kernel_type;
}
};
template <typename T>
......@@ -243,9 +282,17 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
protected:
// Chooses the gradient kernel type from the data type of input Out@GRAD.
// When built with PADDLE_WITH_MKLDNN and CanMKLDNNBeUsed() accepts the
// context, the oneDNN layout/library is requested; otherwise the plain kernel
// for the data type and place is returned.
// (The stale pre-change `return` that preceded this body made the oneDNN
// branch unreachable and has been dropped.)
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
      ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
@OpTestTool.skip_if(core.is_compiled_with_cuda(),
                    "CUDA has to be skipped because it forces dygraph")
class TestReshape2OneDNNOp(OpTest):
    """Base case for the reshape2 op with `use_mkldnn` enabled.

    Subclasses customise the scenario by overriding the small hook methods
    (`init_data`, `set_op_type`, `set_inputs`, ...) that `setUp` calls in a
    fixed order.
    """

    def setUp(self):
        # Shapes first, then the op type, then the random input everything
        # else is derived from.
        self.init_data()
        self.set_op_type()
        self.x = np.random.random(self.ori_shape).astype("float32")
        self.set_inputs()
        self.set_additional_inputs()
        self.set_attrs()
        self.set_outputs()

    def init_data(self):
        self.ori_shape = (2, 60)
        self.new_shape = (12, 10)
        self.infered_shape = (12, 10)

    def set_op_type(self):
        self.op_type = "reshape2"

    def set_inputs(self):
        self.inputs = {"X": self.x}

    def set_additional_inputs(self):
        # Hook for subclasses that feed extra inputs; nothing to add here.
        pass

    def set_attrs(self):
        self.attrs = dict(shape=self.new_shape, use_mkldnn=True)

    def set_outputs(self):
        expected_out = self.inputs["X"].reshape(self.infered_shape)
        # XShape is excluded from the output check, so random content is
        # enough as a placeholder.
        xshape_placeholder = np.random.random(self.ori_shape).astype("float32")
        self.outputs = {"Out": expected_out, 'XShape': xshape_placeholder}

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
class TestReshape2OneDNNOpDimInfer1(TestReshape2OneDNNOp):
    """reshape2 with one dimension of the target shape given as -1."""

    def init_data(self):
        # Requested and expected shapes are the same tuple; numpy resolves the
        # -1 when the reference output is built.
        self.new_shape = self.infered_shape = (5, -1, 5)
        self.ori_shape = (5, 25)
class TestReshape2OneDNNOpDimInfer2(TestReshape2OneDNNOp):
    """reshape2 where the target shape also arrives through the "Shape" input.

    The `shape` attribute uses the 0 (copy) and -1 (infer) placeholders while
    the "Shape" tensor carries the fully resolved dimensions.
    """

    # The class used to define `init_data` twice; the first definition was
    # shadowed by the second and never ran, so only the effective one is kept.
    def init_data(self):
        self.ori_shape = (6, 20)
        self.new_shape = (0, -1, 20)
        self.actual_shape = (2, 3, 20)

    def set_additional_inputs(self):
        self.inputs["Shape"] = np.array(self.actual_shape, dtype="int32")

    def set_outputs(self):
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.actual_shape),
            # XShape is excluded from the output check; random placeholder.
            'XShape': np.random.random(self.ori_shape).astype("float32")
        }
class TestReshape2OneDNNOp_attr_OnlyShape(TestReshape2OneDNNOp):
    """reshape2 driven only by the "Shape" input (no `shape` attribute)."""

    def init_data(self):
        self.ori_shape = (4, 25)
        self.new_shape = self.infered_shape = (10, 10)

    def set_additional_inputs(self):
        shape_tensor = np.array(self.new_shape, dtype="int32")
        self.inputs["Shape"] = shape_tensor

    def set_attrs(self):
        self.attrs = dict(use_mkldnn=True)

    def set_outputs(self):
        expected_out = self.inputs["X"].reshape(self.infered_shape)
        # XShape is excluded from the output check; random placeholder.
        xshape_placeholder = np.random.random(self.ori_shape).astype("float32")
        self.outputs = {"Out": expected_out, 'XShape': xshape_placeholder}
class TestReshape2OneDNNOpDimInfer1_attr_OnlyShape(
        TestReshape2OneDNNOp_attr_OnlyShape):
    """"Shape"-input variant whose target shape contains a -1 dimension."""

    def init_data(self):
        self.new_shape = self.infered_shape = (5, -1, 10)
        self.ori_shape = (5, 20)
        # NOTE(review): `shape` is not read by any code visible in this file;
        # confirm whether the test framework uses it.
        self.shape = (5, -1, -1)
class TestReshapeOneDNNOp(TestReshape2OneDNNOp):
    """Same scenario as the reshape2 base case, run on the `reshape` op."""

    def set_op_type(self):
        self.op_type = "reshape"

    def set_outputs(self):
        # Only "Out" is declared here, so no XShape placeholder is created.
        expected_out = self.inputs["X"].reshape(self.infered_shape)
        self.outputs = {"Out": expected_out}

    def test_check_output(self):
        # Every declared output is checked (no `no_check_set`).
        self.check_output()
class TestReshapeOneDNNOpDimInfer1(TestReshapeOneDNNOp):
    """reshape with one dimension of the target shape given as -1."""

    def init_data(self):
        self.new_shape = self.infered_shape = (5, -1, 5)
        self.ori_shape = (5, 25)
class TestReshapeOneDNNOp_attr_OnlyShape(TestReshape2OneDNNOp_attr_OnlyShape):
    """"Shape"-input scenario run on the `reshape` op."""

    def set_op_type(self):
        self.op_type = "reshape"

    def set_outputs(self):
        # Only "Out" is declared here, so no XShape placeholder is created.
        expected_out = self.inputs["X"].reshape(self.infered_shape)
        self.outputs = {"Out": expected_out}

    def test_check_output(self):
        # Every declared output is checked (no `no_check_set`).
        self.check_output()
class TestReshapeOneDNNOpDimInfer1_attr_OnlyShape(
        TestReshapeOneDNNOp_attr_OnlyShape):
    """"Shape"-input `reshape` variant whose target shape contains a -1."""

    def init_data(self):
        self.new_shape = self.infered_shape = (5, -1, 10)
        self.ori_shape = (5, 20)
        # NOTE(review): `shape` is not read by any code visible in this file;
        # confirm whether the test framework uses it.
        self.shape = (5, -1, -1)
# BF16 TESTS
def create_reshape_bf16_test_classes(parent):
    """Generate bfloat16 reshape2 and reshape test classes derived from `parent`.

    Both generated classes are renamed to "<parent name>_<suffix>" and stored
    in this module's globals under that name.
    """

    def _publish(test_cls, suffix):
        # Rename the generated class and expose it at module level.
        cls_name = "{0}_{1}".format(parent.__name__, suffix)
        test_cls.__name__ = cls_name
        globals()[cls_name] = test_cls

    @OpTestTool.skip_if_not_cpu_bf16()
    class TestReshape2BF16OneDNNOp(parent):
        def set_inputs(self):
            self.dtype = np.uint16
            self.inputs = {"X": convert_float_to_uint16(self.x)}

        def calculate_grads(self):
            # The reference gradient is the output gradient reshaped back to
            # the input's shape.
            self.dout = self.outputs['Out']
            self.dx = np.reshape(self.dout, self.ori_shape)

        def test_check_output(self):
            self.check_output_with_place(
                core.CPUPlace(), no_check_set=["XShape"])

        def test_check_grad(self):
            self.calculate_grads()
            self.check_grad_with_place(
                core.CPUPlace(), ["X"],
                "Out",
                user_defined_grads=[self.dx],
                user_defined_grad_outputs=[self.dout])

    _publish(TestReshape2BF16OneDNNOp, "Reshape2_BF16")

    class TestReshapeBF16OneDNNOp(TestReshape2BF16OneDNNOp):
        def set_op_type(self):
            self.dtype = np.uint16
            self.op_type = "reshape"

        def set_outputs(self):
            # Built from the float32 source data rather than from the
            # converted uint16 input.
            self.outputs = {"Out": self.x.reshape(self.new_shape)}

        def test_check_output(self):
            self.check_output_with_place(core.CPUPlace())

        def test_check_grad(self):
            self.calculate_grads()
            dout_uint16 = convert_float_to_uint16(self.dout)
            self.check_grad_with_place(
                core.CPUPlace(), ["X"],
                "Out",
                user_defined_grads=[self.dx],
                user_defined_grad_outputs=[dout_uint16])

    _publish(TestReshapeBF16OneDNNOp, "Reshape_BF16")
# Generate the bfloat16 variants before unittest collects the module's tests.
create_reshape_bf16_test_classes(TestReshape2OneDNNOp)
create_reshape_bf16_test_classes(TestReshape2OneDNNOpDimInfer1)

if __name__ == "__main__":
    # Static graph mode is switched on before the tests run.
    paddle.enable_static()
    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
@OpTestTool.skip_if(core.is_compiled_with_cuda(),
                    "CUDA has to be skipped because it forces dygraph")
class TestSqueeze2OneDNNOp(OpTest):
    """Base case for the squeeze2 op with `use_mkldnn` enabled.

    Subclasses customise the scenario by overriding the small hook methods
    (`init_test_case`, `set_op_type`, `set_outputs`, ...) that `setUp` calls
    in a fixed order.
    """

    def setUp(self):
        # Op type and shapes first, then the random input everything else is
        # derived from.
        self.set_op_type()
        self.init_test_case()
        self.x = np.random.random(self.ori_shape).astype("float32")
        self.set_inputs()
        self.init_attrs()
        self.set_outputs()

    def set_op_type(self):
        self.op_type = "squeeze2"

    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)

    def set_inputs(self):
        self.inputs = {"X": self.x}

    def init_attrs(self):
        self.attrs = dict(axes=self.axes, use_mkldnn=True)

    def set_outputs(self):
        expected_out = self.x.reshape(self.new_shape)
        # XShape is excluded from the output check, so random content is
        # enough as a placeholder.
        xshape_placeholder = np.random.random(self.ori_shape).astype("float32")
        self.outputs = {"Out": expected_out, "XShape": xshape_placeholder}

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace(), no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
class TestSqueezeOneDNNOp(TestSqueeze2OneDNNOp):
    """Same scenario as the squeeze2 base case, run on the `squeeze` op."""

    def set_op_type(self):
        self.op_type = "squeeze"

    def set_outputs(self):
        # Only "Out" is declared here, so no XShape placeholder is created.
        expected_out = self.x.reshape(self.new_shape)
        self.outputs = {"Out": expected_out}

    def test_check_output(self):
        # Every declared output is checked (no `no_check_set`).
        self.check_output_with_place(core.CPUPlace())
class TestSqueeze2OneDNNOp1(TestSqueeze2OneDNNOp):
    """squeeze2 with a negative axis index in `axes`."""

    def init_test_case(self):
        self.axes = (0, -2)
        self.ori_shape = (1, 20, 1, 5)
        self.new_shape = (20, 5)
class TestSqueezeOneDNNOp1(TestSqueezeOneDNNOp):
    """squeeze with a negative axis index in `axes`."""

    def init_test_case(self):
        self.axes = (0, -2)
        self.ori_shape = (1, 20, 1, 5)
        self.new_shape = (20, 5)
class TestSqueeze2OneDNNOp2(TestSqueeze2OneDNNOp):
    """squeeze2 with an empty `axes` tuple; both size-1 dims are expected to go."""

    def init_test_case(self):
        self.axes = ()
        self.ori_shape = (1, 20, 1, 5)
        self.new_shape = (20, 5)
class TestSqueezeOneDNNOp2(TestSqueezeOneDNNOp):
    """squeeze with an empty `axes` tuple; both size-1 dims are expected to go."""

    def init_test_case(self):
        self.axes = ()
        self.ori_shape = (1, 20, 1, 5)
        self.new_shape = (20, 5)
class TestSqueeze2OneDNNOp3(TestSqueeze2OneDNNOp):
    """squeeze2 on a 5-D input where one size-1 dim is left in place."""

    def init_test_case(self):
        self.axes = (1, -1)
        self.ori_shape = (25, 1, 1, 4, 1)
        self.new_shape = (25, 1, 4)
class TestSqueezeOneDNNOp3(TestSqueezeOneDNNOp):
    """squeeze on a 5-D input where one size-1 dim is left in place."""

    def init_test_case(self):
        self.axes = (1, -1)
        self.ori_shape = (25, 1, 1, 4, 1)
        self.new_shape = (25, 1, 4)
# BF16 TESTS
def create_squeeze_bf16_test_classes(parent):
    """Generate bfloat16 squeeze2 and squeeze test classes derived from `parent`.

    Both generated classes are renamed to "<parent name>_<suffix>" and stored
    in this module's globals under that name.
    """

    def _publish(test_cls, suffix):
        # Rename the generated class and expose it at module level.
        cls_name = "{0}_{1}".format(parent.__name__, suffix)
        test_cls.__name__ = cls_name
        globals()[cls_name] = test_cls

    @OpTestTool.skip_if_not_cpu_bf16()
    class TestSqueeze2BF16OneDNNOp(parent):
        def set_inputs(self):
            self.dtype = np.uint16
            self.inputs = {"X": convert_float_to_uint16(self.x)}

        def calculate_grads(self):
            # The reference gradient is the output gradient reshaped back to
            # the input's shape.
            self.dout = self.outputs['Out']
            self.dx = np.reshape(self.dout, self.ori_shape)

        def test_check_grad(self):
            self.calculate_grads()
            self.check_grad_with_place(
                core.CPUPlace(), ["X"],
                "Out",
                user_defined_grads=[self.dx],
                user_defined_grad_outputs=[self.dout])

    _publish(TestSqueeze2BF16OneDNNOp, "Squeeze2_BF16")

    class TestSqueezeBF16OneDNNOp(TestSqueeze2BF16OneDNNOp):
        def set_op_type(self):
            self.dtype = np.uint16
            self.op_type = "squeeze"

        def set_outputs(self):
            # Only "Out" is declared here, so no XShape placeholder is created.
            self.outputs = {"Out": self.x.reshape(self.new_shape)}

        def test_check_output(self):
            self.check_output_with_place(core.CPUPlace())

    _publish(TestSqueezeBF16OneDNNOp, "Squeeze_BF16")
# Generate the bfloat16 variants before unittest collects the module's tests.
create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp)
create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp1)
create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp2)
create_squeeze_bf16_test_classes(TestSqueeze2OneDNNOp3)

if __name__ == "__main__":
    # Static graph mode is switched on before the tests run.
    paddle.enable_static()
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册