未验证 提交 36cdb6e2 编写于 作者: J jakpiase 提交者: GitHub

Fix for reshape2 oneDNN op (#35455)

* fix for reshape2

* added reviewers sugestions
上级 ed97be09
......@@ -22,6 +22,25 @@ using paddle::framework::LoDTensor;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
static std::vector<int> extract_shape(
const std::vector<const Tensor*>& list_new_shape_tensor) {
std::vector<int> vec_new_shape;
vec_new_shape.reserve(list_new_shape_tensor.size());
for (const auto& tensor : list_new_shape_tensor) {
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor->dims()));
vec_new_shape.emplace_back(*tensor->data<int32_t>());
}
return vec_new_shape;
}
template <typename T>
class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
public:
......@@ -59,7 +78,11 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
}
if (ctx.Type().find("reshape") != std::string::npos) {
if (ctx.HasInput("Shape")) {
auto list_new_shape_tensor = ctx.MultiInput<Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) {
auto new_shape = extract_shape(list_new_shape_tensor);
out_dims = ValidateShape(new_shape, x_dims);
} else if (ctx.HasInput("Shape")) {
auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
auto* shape_data = shape_tensor->data<int>();
......
......@@ -72,9 +72,9 @@ class TestReshape2OneDNNOpDimInfer1(TestReshape2OneDNNOp):
class TestReshape2OneDNNOpDimInfer2(TestReshape2OneDNNOp):
def init_data(self):
self.ori_shape = (10, 2, 6)
self.new_shape = (10, 0, 3, -1)
self.infered_shape = (10, 2, 3, -1)
self.ori_shape = (6, 20)
self.new_shape = (0, -1, 20)
self.actual_shape = (2, 3, 20)
def set_additional_inputs(self):
self.inputs["Shape"] = np.array(self.actual_shape, dtype="int32")
......@@ -85,11 +85,6 @@ class TestReshape2OneDNNOpDimInfer2(TestReshape2OneDNNOp):
'XShape': np.random.random(self.ori_shape).astype("float32")
}
def init_data1(self):
self.ori_shape = (6, 20)
self.new_shape = (0, -1, 20)
self.actual_shape = (2, 3, 20)
class TestReshape2OneDNNOp_attr_OnlyShape(TestReshape2OneDNNOp):
def set_additional_inputs(self):
......@@ -119,6 +114,34 @@ class TestReshape2OneDNNOpDimInfer1_attr_OnlyShape(
self.shape = (5, -1, -1)
class TestReshape2OneDNNOpDimInfer1_attr_ShapeTensor(TestReshape2OneDNNOp):
def set_additional_inputs(self):
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs["ShapeTensor"] = shape_tensor
def init_data(self):
self.ori_shape = (5, 20)
self.new_shape = (5, -1, 10)
self.infered_shape = (5, -1, 10)
self.shape = (5, -1, -1)
class TestReshape2OneDNNOpDimInfer1_attr_ShapeTensorAndShape(
TestReshape2OneDNNOpDimInfer1_attr_ShapeTensor):
def set_additional_inputs(self):
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs["Shape"] = np.array((1, 2, 3, 4), dtype="int32")
self.inputs["ShapeTensor"] = shape_tensor
class TestReshapeOneDNNOp(TestReshape2OneDNNOp):
def set_op_type(self):
self.op_type = "reshape"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册