提交 cf081851 编写于 作者: C caoying03

fix bugs and complete codes.

上级 a8cdd97e
...@@ -25,39 +25,28 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -25,39 +25,28 @@ class ReshapeOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReshapeOp should not be null."); "Input(X) of ReshapeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null."); "Output(Out) of ReshapeOp should not be null.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape"); const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"), PADDLE_ENFORCE(!shape.empty(),
"The shape information can only be set by Attr(shape) or " "The shape information must be set by Attr(shape).");
"by Input(Shape). Attr(shape) and Input(Shape) cannot be "
"set at the same time.");
std::vector<int64_t> output_shape;
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
bool need_copy_dim = ValidateShape(shape, x_dims, output_shape);
if (ctx->HasInput("Shape")) { if (need_copy_dim) {
// The shape information in given by Input(Shape). // Some dimensions can only be determined during runtime. Here temporarily
auto shape_dims = ctx->GetInputDim("Shape"); // set output tensor's shape the same as that of the input tensor.
PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL,
"The Input(Label) should be a 2-D tensor with the 1st "
"dimensions fixed to 1 (a row vector).");
// The actual output shape will be set at runtime, here temporially set
// the shape of output the same as the shape of input.
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
} else { } else {
// The shape information in given by Attr(shape). ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
std::vector<int64_t> output_shape;
ValidateShape(shape, framework::product(x_dims), output_shape);
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
// FIXME(caoying): When shape of the output tensor is determined during
// runtime, LoD information of X will not passed to the output.
if (shape[0] == x_dims[0]) { if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X) // Only pass LoD when the first dimension of output and Input(X)
// are the same. // are the same.
...@@ -67,41 +56,51 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -67,41 +56,51 @@ class ReshapeOp : public framework::OperatorWithKernel {
} }
private: private:
void ValidateShape(const std::vector<int> &shape, const int64_t in_size, bool ValidateShape(const std::vector<int> &shape,
const framework::DDim &input_dim,
std::vector<int64_t> &output_shape) const { std::vector<int64_t> &output_shape) const {
std::vector<size_t> neg_dims_idx; // only one dimension canbe set to -1, whose size will be automatically
const int unknown_index = -1; // only one dimension canbe set to -1, whose // infered.
// size will be automatically infered. const int64_t unknown_index = -1;
const auto in_size = framework::product(input_dim);
const auto x_rank = input_dim.size();
bool need_dim_copy = false;
std::vector<size_t> neg_dims_idx;
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 1 || shape[i] == unknown_index, PADDLE_ENFORCE(shape[i] >= 0 || shape[i] == unknown_index,
"Each input dimension of Attr(shape) must be positive, or " "Each input dimension of Attr(shape) must be positive, or "
"only one input dimension can be -1."); "only one input dimension can be -1.");
if (shape[i] == unknown_index) neg_dims_idx.push_back(i); if (shape[i] == unknown_index) {
neg_dims_idx.push_back(i);
} else if (shape[i] == 0) {
PADDLE_ENFORCE_LT(
i, x_rank,
"Only dimension less than rank of Input(X) can be set to 0.");
need_dim_copy = true;
}
} }
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
neg_dims_idx.size(), 1, neg_dims_idx.size(), 1,
"Only one input dimension of Attr(shape) may be unknown."); "Only one input dimension of Attr(shape) may be unknown.");
output_shape.resize(shape.size(), 0);
std::transform(shape.begin(), shape.end(), output_shape.begin(),
[](int a) { return static_cast<int64_t>(a); });
// some dimension can only be determinted during runtime.
if (need_dim_copy) return need_dim_copy;
int64_t inferred_dim = 0; int64_t inferred_dim = 0;
if (neg_dims_idx.size()) { if (neg_dims_idx.size()) {
int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1, int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>()); std::multiplies<int>());
inferred_dim = in_size / (-capacity); inferred_dim = in_size / (-capacity);
PADDLE_ENFORCE_EQ(inferred_dim * (-capacity), in_size,
"Invalid shape is given.");
output_shape[neg_dims_idx[0]] = inferred_dim;
} }
return false;
output_shape.resize(shape.size(), 0);
std::transform(shape.begin(), shape.end(), output_shape.begin(),
[](int a) { return static_cast<int64_t>(a); });
if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim;
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
} }
}; };
...@@ -110,14 +109,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,14 +109,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator."); AddInput("X", "The input tensor of reshape operator.");
AddInput(
"Shape",
"Tensor<int64_t>, a 1-D tensor that provides the shape information.")
.AsDispensable();
AddOutput("Out", "The output tensor of reshape operator."); AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"shape", "(std::vector<int>) Target shape of reshape operator.") "shape", "(std::vector<int>) Target shape of reshape operator.");
.SetDefault(std::vector<int>());
AddAttr<bool>("inplace", AddAttr<bool>("inplace",
"Change the source tensor's shape without copy memory.") "Change the source tensor's shape without copy memory.")
.SetDefault(true); .SetDefault(true);
...@@ -153,14 +147,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -153,14 +147,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -27,17 +27,8 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -27,17 +27,8 @@ class ReshapeKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto* shape = ctx.Input<framework::Tensor>("Shape"); auto out_dims =
framework::DDim out_dims; ValidateShape(ctx.Attr<std::vector<int>>("shape"), in->dims());
if (shape) {
std::vector<int64_t> output_shape;
ValidateShape(*shape, framework::product(in->dims()), output_shape);
out_dims = framework::make_ddim(output_shape);
} else {
out_dims = out->dims();
}
bool inplace = ctx.Attr<bool>("inplace"); bool inplace = ctx.Attr<bool>("inplace");
if (!inplace) { if (!inplace) {
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
...@@ -50,35 +41,31 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -50,35 +41,31 @@ class ReshapeKernel : public framework::OpKernel<T> {
} }
private: private:
void ValidateShape(const framework::Tensor& shape, const int64_t in_size, framework::DDim ValidateShape(const std::vector<int> shape_attr,
std::vector<int64_t>& output_shape) const { const framework::DDim& in_dims) const {
std::vector<size_t> neg_dims_idx; const int64_t in_size = framework::product(in_dims);
const int unknown_index = -1; // only one dimension canbe set to -1, whose // only one dimension canbe set to -1, whose size will be automatically
// size will be automatically infered. // infered.
const int64_t unknown_index = -1;
const int64_t dimension = shape.dims()[1];
std::cout << "dimension =" << dimension << std::endl; std::vector<int64_t> output_shape(shape_attr.size(), 0);
const T* shape_data = shape.data<T>();
for (int64_t i = 0; i < dimension; ++i) {
PADDLE_ENFORCE(shape_data[i] > 1 || shape_data[i] == unknown_index,
"Each input dimension of Attr(shape) must be positive, or "
"only one input dimension can be -1.");
if (shape_data[i] == unknown_index) neg_dims_idx.push_back(i);
}
PADDLE_ENFORCE_LE(
neg_dims_idx.size(), 1,
"Only one input dimension of Attr(shape) can be unknown.");
int64_t capacity = 1; int64_t capacity = 1;
output_shape.resize(dimension, 0); int neg_dim_idx = -1;
for (int64_t i = 0; i < dimension; ++i) { for (size_t i = 0; i < shape_attr.size(); ++i) {
capacity *= shape_data[i]; if (shape_attr[i] == unknown_index) neg_dim_idx = i;
output_shape[i] = static_cast<int64_t>(shape_data[i]); capacity *= (shape_attr[i] ? shape_attr[i] : in_dims[i]);
output_shape[i] =
(shape_attr[i] ? static_cast<int64_t>(shape_attr[i]) : in_dims[i]);
} }
if (neg_dims_idx.size()) if (neg_dim_idx != -1) {
output_shape[neg_dims_idx[0]] = in_size / (-capacity); output_shape[neg_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[neg_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
}
return framework::make_ddim(output_shape);
} }
}; };
......
...@@ -19,7 +19,6 @@ from layer_function_generator import generate_layer_fn ...@@ -19,7 +19,6 @@ from layer_function_generator import generate_layer_fn
from layer_function_generator import autodoc from layer_function_generator import autodoc
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
import tensor import tensor
import ops
import nn import nn
import math import math
...@@ -58,7 +57,7 @@ def detection_output(loc, ...@@ -58,7 +57,7 @@ def detection_output(loc,
This operation is to get the detection results by performing following This operation is to get the detection results by performing following
two steps: two steps:
1. Decode input bounding box predictions according to the prior boxes. 1. Decode input bounding box predictions according to the prior boxes.
2. Get the final detection results by applying multi-class non maximum 2. Get the final detection results by applying multi-class non maximum
suppression (NMS). suppression (NMS).
...@@ -458,7 +457,7 @@ def ssd_loss(location, ...@@ -458,7 +457,7 @@ def ssd_loss(location,
num, num_prior, num_class = confidence.shape num, num_prior, num_class = confidence.shape
def __reshape_to_2d(var): def __reshape_to_2d(var):
return ops.reshape(x=var, shape=[-1, var.shape[-1]]) return nn.reshape(x=var, shape=[-1, var.shape[-1]])
# 1. Find matched boundding box by prior box. # 1. Find matched boundding box by prior box.
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
...@@ -469,7 +468,7 @@ def ssd_loss(location, ...@@ -469,7 +468,7 @@ def ssd_loss(location,
# 2. Compute confidence for mining hard examples # 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices # 2.1. Get the target label based on matched indices
gt_label = ops.reshape(x=gt_label, shape=gt_label.shape + (1, )) gt_label = nn.reshape(x=gt_label, shape=gt_label.shape + (1, ))
target_label, _ = target_assign( target_label, _ = target_assign(
gt_label, matched_indices, mismatch_value=background_label) gt_label, matched_indices, mismatch_value=background_label)
# 2.2. Compute confidence loss. # 2.2. Compute confidence loss.
...@@ -480,7 +479,7 @@ def ssd_loss(location, ...@@ -480,7 +479,7 @@ def ssd_loss(location,
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
# 3. Mining hard examples # 3. Mining hard examples
conf_loss = ops.reshape(x=conf_loss, shape=(num, num_prior)) conf_loss = nn.reshape(x=conf_loss, shape=(num, num_prior))
neg_indices = helper.create_tmp_variable(dtype='int32') neg_indices = helper.create_tmp_variable(dtype='int32')
dtype = matched_indices.dtype dtype = matched_indices.dtype
updated_matched_indices = helper.create_tmp_variable(dtype=dtype) updated_matched_indices = helper.create_tmp_variable(dtype=dtype)
...@@ -548,7 +547,7 @@ def ssd_loss(location, ...@@ -548,7 +547,7 @@ def ssd_loss(location,
# 5.3 Compute overall weighted loss. # 5.3 Compute overall weighted loss.
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
# reshape to [N, Np], N is the batch size and Np is the prior box number. # reshape to [N, Np], N is the batch size and Np is the prior box number.
loss = ops.reshape(x=loss, shape=[-1, num_prior]) loss = nn.reshape(x=loss, shape=[-1, num_prior])
loss = nn.reduce_sum(loss, dim=1, keep_dim=True) loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
if normalize: if normalize:
normalizer = nn.reduce_sum(target_loc_weight) normalizer = nn.reduce_sum(target_loc_weight)
...@@ -696,7 +695,7 @@ def multi_box_head(inputs, ...@@ -696,7 +695,7 @@ def multi_box_head(inputs,
new_shape = [ new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)]) -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
] ]
out = ops.reshape(x=input, shape=new_shape) out = nn.reshape(x=input, shape=new_shape)
return out return out
def _is_list_or_tuple_(data): def _is_list_or_tuple_(data):
...@@ -793,7 +792,7 @@ def multi_box_head(inputs, ...@@ -793,7 +792,7 @@ def multi_box_head(inputs,
mbox_loc.shape[0], mbox_loc.shape[0],
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4 mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
] ]
mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape) mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten) mbox_locs.append(mbox_loc_flatten)
# get conf_loc # get conf_loc
...@@ -809,7 +808,7 @@ def multi_box_head(inputs, ...@@ -809,7 +808,7 @@ def multi_box_head(inputs,
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[3] / num_classes, num_classes conf_loc.shape[3] / num_classes, num_classes
] ]
conf_loc_flatten = ops.reshape(conf_loc, shape=new_shape) conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
mbox_confs.append(conf_loc_flatten) mbox_confs.append(conf_loc_flatten)
if len(box_results) == 1: if len(box_results) == 1:
......
...@@ -70,6 +70,7 @@ __all__ = [ ...@@ -70,6 +70,7 @@ __all__ = [
'smooth_l1', 'smooth_l1',
'one_hot', 'one_hot',
'autoincreased_step_counter', 'autoincreased_step_counter',
'reshape',
] ]
...@@ -3184,6 +3185,8 @@ def one_hot(input, depth): ...@@ -3184,6 +3185,8 @@ def one_hot(input, depth):
The one-hot tensor or LodTensor, same as input. The one-hot tensor or LodTensor, same as input.
Examples: Examples:
.. code-block:: python
X is a LoDTensor: X is a LoDTensor:
X.lod = [[0, 1, 4]] X.lod = [[0, 1, 4]]
X.shape = [4, 1] X.shape = [4, 1]
...@@ -3236,3 +3239,56 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): ...@@ -3236,3 +3239,56 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
counter.stop_gradient = True counter.stop_gradient = True
return counter return counter
def reshape(x, shape, act=None, inplace=True, name=None):
"""
Gives a new shape to Tensor without changing its data.
This layer takes a tensor as input and the attribute shape specifying the
new shape. The shape attribute must be specified. At most one dimension of
the new shape can be -1. In this case, the value is inferred from the size
of the tensor and the remaining dimensions. A dimension could also be 0,
in which case the actual dimension value is going to be copied from the
input tensor.
Args:
input(variable): The input tensor.
shape(list): The new shape. At most one dimension of the new shape can
be -1.
act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, a new output tensor is created
whose data is copied from input x, otherwise the output
shares data with input without copying.
Returns(variable): The output tensor.
Examples:
.. code-block:: python
Given a 2-D tensor X with shape [2 x 2], and the new shape: [1, 4].
The reshape layer will change tensor X into a 2-D tensor with
shape [1 x 4] with its data unchanged.
Given a 3-D tensor x with shape [2, 3, 4] and the new shape: [3, -1].
The reshape layer will change tensor X into a 2-D tensor with shape:
[3 x 8] with its data unchanged.
Given a 3-D tensor x with shape [2, 3, 8] and the new shape:
[-1, 0, 2, 2]. The reshape layer will change tensor X into a 4-D tensor
with shape [4, 3, 2, 2] with its data unchanged.
"""
if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python lsit or tuple.")
helper = LayerHelper("reshape", **locals())
reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reshape",
inputs={"X": x},
attrs={"shape": shape,
"inplace": inplace},
outputs={"Out": reshaped})
return helper.append_activation(reshaped)
...@@ -47,7 +47,6 @@ __activations__ = [ ...@@ -47,7 +47,6 @@ __activations__ = [
__all__ = [ __all__ = [
'mean', 'mean',
'mul', 'mul',
'reshape',
'scale', 'scale',
'sigmoid_cross_entropy_with_logits', 'sigmoid_cross_entropy_with_logits',
'elementwise_add', 'elementwise_add',
......
...@@ -14,53 +14,88 @@ ...@@ -14,53 +14,88 @@
import unittest import unittest
import numpy as np import numpy as np
import pdb
from op_test import OpTest from op_test import OpTest
# class TestReshapeOp1(OpTest):
# def setUp(self): class TestReshapeOp(OpTest):
# ori_shape = (2, 25) def setUp(self):
# new_shape = [5, 10] ori_shape = (2, 25)
# new_shape = (5, 10)
# self.op_type = "reshape"
# self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.op_type = "reshape"
# self.attrs = {"shape": new_shape} self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
# self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} self.attrs = {"shape": new_shape, "inplace": False}
# self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
# def test_check_output(self):
# self.check_output() def test_check_output(self):
# self.check_output()
# def test_check_grad(self):
# self.check_grad(["X"], "Out") def test_check_grad(self):
# self.check_grad(["X"], "Out")
#
# class TestReshapeOpDimInfer1(OpTest):
# def setUp(self): class TestReshapeOpDimInfer1(OpTest):
# self.op_type = "reshape" def setUp(self):
# self.inputs = {"X": np.random.random((5, 10)).astype("float32")} ori_shape = (5, 10)
# self.attrs = {"shape": [5, -1, 5]} new_shape = (5, -1, 5)
# self.outputs = {"Out": self.inputs["X"].reshape(self.attrs["shape"])}
# self.op_type = "reshape"
# def test_check_output(self): self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
# self.check_output() self.attrs = {"shape": new_shape, "inplace": False}
# self.outputs = {"Out": self.inputs["X"].reshape(self.attrs["shape"])}
# def test_check_grad(self):
# self.check_grad(["X"], "Out") def test_check_output(self):
self.check_output()
class TestReshapeOp2(OpTest): def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer2(OpTest):
def setUp(self):
ori_shape = (2, 2, 6)
new_shape = (2, 0, 3, -1)
infered_shape = (2, 2, 3, -1)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape, "inplace": False}
self.outputs = {"Out": self.inputs["X"].reshape(infered_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpInplace(OpTest):
def setUp(self): def setUp(self):
ori_shape = (2, 25) ori_shape = (2, 25)
new_shape = ([5, 10], ) new_shape = (5, 10)
self.op_type = "reshape"
self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
self.attrs = {"shape": new_shape}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInferInplace1(OpTest):
def setUp(self):
ori_shape = (5, 10)
new_shape = (5, -1, 5)
self.op_type = "reshape" self.op_type = "reshape"
self.inputs = { self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
"X": np.random.random(ori_shape).astype("float32"), self.attrs = {"shape": new_shape}
"Shape": np.array( self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}
new_shape, dtype="int64")
}
self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -69,32 +104,23 @@ class TestReshapeOp2(OpTest): ...@@ -69,32 +104,23 @@ class TestReshapeOp2(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
# class TestReshapeOpInplace(OpTest): class TestReshapeOpDimInferInplace2(OpTest):
# def setUp(self): def setUp(self):
# self.op_type = "reshape" ori_shape = (2, 2, 6)
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")} new_shape = (2, 0, 3, -1)
# self.attrs = {'shape': [10 * 20], 'inplace': True} infered_shape = (2, 2, 3, -1)
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
# self.op_type = "reshape"
# def test_check_output(self): self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
# self.check_output() self.attrs = {"shape": new_shape}
# self.outputs = {"Out": self.inputs["X"].reshape(infered_shape)}
# def test_check_grad(self):
# self.check_grad(["X"], "Out") def test_check_output(self):
# self.check_output()
#
# class TestReshapeOpDimInferInplace(OpTest): def test_check_grad(self):
# def setUp(self): self.check_grad(["X"], "Out")
# self.op_type = "reshape"
# self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
# self.attrs = {'shape': [4, -1, 5], 'inplace': True}
# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
#
# def test_check_output(self):
# self.check_output()
#
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册