Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c078ed46
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c078ed46
编写于
3月 28, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enhance reshape_op by adding Input(Shape)
上级
b7e83d24
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
184 addition
and
108 deletion
+184
-108
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+21
-80
paddle/fluid/operators/reshape_op.h
paddle/fluid/operators/reshape_op.h
+104
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+37
-26
python/paddle/fluid/tests/unittests/test_reshape_op.py
python/paddle/fluid/tests/unittests/test_reshape_op.py
+22
-0
未找到文件。
paddle/fluid/operators/reshape_op.cc
浏览文件 @
c078ed46
...
@@ -17,88 +17,18 @@ limitations under the License. */
...
@@ -17,88 +17,18 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
ReshapeOp
:
public
framework
::
OperatorWithKernel
{
public:
ReshapeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE
(
!
shape
.
empty
(),
"The shape information must be set by Attr(shape)."
);
std
::
vector
<
int64_t
>
output_shape
;
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ValidateShape
(
shape
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
// NOTE: Reshape op cannot reshape an input sequence batch into an
// output sequence batch that has a different number of time steps. Here
// output always shares the LoD information with input. But if
// Attr(shape) contains 0 or -1, the actual output shape can only be
// determined during runtime. The check for wheather it is a valid
// output sequence batch is performed in runtime.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
private:
framework
::
DDim
ValidateShape
(
const
std
::
vector
<
int
>
shape
,
const
framework
::
DDim
&
in_dims
)
const
{
const
int64_t
in_size
=
framework
::
product
(
in_dims
);
// only one dimension canbe set to -1, whose size will be automatically
// infered.
const
int64_t
unk_dim_val
=
-
1
;
const
int64_t
copy_dim_val
=
0
;
std
::
vector
<
int64_t
>
output_shape
(
shape
.
size
(),
0
);
int64_t
capacity
=
1
;
int
unk_dim_idx
=
-
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
unk_dim_val
)
{
PADDLE_ENFORCE
(
unk_dim_idx
==
-
1
,
"Only one input dimension of Attr(shape) can be unknown."
);
unk_dim_idx
=
i
;
}
else
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
i
)
<
in_dims
.
size
(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape."
);
}
else
{
PADDLE_ENFORCE
(
shape
[
i
]
>
0
,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension."
);
}
capacity
*=
(
shape
[
i
]
?
shape
[
i
]
:
in_dims
[
i
]);
output_shape
[
i
]
=
(
shape
[
i
]
?
static_cast
<
int64_t
>
(
shape
[
i
])
:
in_dims
[
i
]);
}
if
(
unk_dim_idx
!=
-
1
)
{
output_shape
[
unk_dim_idx
]
=
-
in_size
/
capacity
;
PADDLE_ENFORCE_EQ
(
output_shape
[
unk_dim_idx
]
*
capacity
,
-
in_size
,
"Invalid shape is given."
);
}
else
{
PADDLE_ENFORCE_EQ
(
capacity
,
in_size
,
"Invalid shape is given."
);
}
return
framework
::
make_ddim
(
output_shape
);
}
};
class
ReshapeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
ReshapeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
ReshapeOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
ReshapeOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of reshape operator."
);
AddInput
(
"X"
,
"(Tensor). The input tensor of reshape operator."
);
AddOutput
(
"Out"
,
"The output tensor of reshape operator."
);
AddInput
(
"Shape"
,
"(Tensor<int32>, optional). If provided, reshape according to "
"this given shape. That is to say it has a higher priority than "
"the shape attribute, while the shape attribute still should be "
"set correctly to gurantee shape inference in compile time."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor). The output tensor of reshape operator."
);
AddAttr
<
std
::
vector
<
int
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"(std::vector<int>) Target shape of reshape operator."
);
"shape"
,
"(std::vector<int>) Target shape of reshape operator."
);
AddAttr
<
bool
>
(
"inplace"
,
AddAttr
<
bool
>
(
"inplace"
,
...
@@ -110,8 +40,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -110,8 +40,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Reshape Operator.
Reshape Operator.
Reshape Input(X) into the shape specified by Attr(shape)
. The data in Input(X)
Reshape Input(X) into the shape specified by Attr(shape)
or Input(Shape). The
are unchanged.
data in Input(X)
are unchanged.
Examples:
Examples:
...
@@ -141,6 +71,9 @@ Input(X) and remaining dimensions.
...
@@ -141,6 +71,9 @@ Input(X) and remaining dimensions.
dimension value will be copied from Input(X) at runtime. Note that the index of
dimension value will be copied from Input(X) at runtime. Note that the index of
0 can not exceed Rank(X). For example, Input(X) is a 3-D tensor with shape
0 can not exceed Rank(X). For example, Input(X) is a 3-D tensor with shape
[2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
[2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
1. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
Attr(shape) still should be set correctly to gurantee shape inference in
compile-time.
)DOC"
);
)DOC"
);
}
}
...
@@ -160,6 +93,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
...
@@ -160,6 +93,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null."
);
"Input(Out@GRAD) shouldn't be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/reshape_op.h
浏览文件 @
c078ed46
...
@@ -20,15 +20,115 @@ limitations under the License. */
...
@@ -20,15 +20,115 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
ReshapeOp
:
public
framework
::
OperatorWithKernel
{
public:
ReshapeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE
(
!
shape
.
empty
(),
"The shape information must be set by Attr(shape)."
);
if
(
ctx
->
HasInput
(
"Shape"
)
&&
ctx
->
IsRuntime
())
{
// If true, set the shape of Output(Out) according to Input(Shape) in
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
return
;
}
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ValidateShape
(
shape
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
if
(
x_dims
[
0
]
==
out_dims
[
0
])
{
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
static
framework
::
DDim
ValidateShape
(
const
std
::
vector
<
int
>
shape
,
const
framework
::
DDim
&
in_dims
)
{
const
int64_t
in_size
=
framework
::
product
(
in_dims
);
// only one dimension canbe set to -1, whose size will be automatically
// infered.
const
int64_t
unk_dim_val
=
-
1
;
const
int64_t
copy_dim_val
=
0
;
std
::
vector
<
int64_t
>
output_shape
(
shape
.
size
(),
0
);
int64_t
capacity
=
1
;
int
unk_dim_idx
=
-
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
// std::cout<< shape[i] << "haha";
if
(
shape
[
i
]
==
unk_dim_val
)
{
PADDLE_ENFORCE
(
unk_dim_idx
==
-
1
,
"Only one input dimension of Attr(shape) can be unknown."
);
unk_dim_idx
=
i
;
}
else
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
i
)
<
in_dims
.
size
(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape."
);
}
else
{
PADDLE_ENFORCE
(
shape
[
i
]
>
0
,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension."
);
}
capacity
*=
(
shape
[
i
]
?
shape
[
i
]
:
in_dims
[
i
]);
output_shape
[
i
]
=
(
shape
[
i
]
?
static_cast
<
int64_t
>
(
shape
[
i
])
:
in_dims
[
i
]);
}
if
(
unk_dim_idx
!=
-
1
)
{
output_shape
[
unk_dim_idx
]
=
-
in_size
/
capacity
;
PADDLE_ENFORCE_EQ
(
output_shape
[
unk_dim_idx
]
*
capacity
,
-
in_size
,
"Invalid shape is given."
);
}
else
{
PADDLE_ENFORCE_EQ
(
capacity
,
in_size
,
"Invalid shape is given."
);
}
return
framework
::
make_ddim
(
output_shape
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ReshapeKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ReshapeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
shape_tensor
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Shape"
);
auto
out_dims
=
out
->
dims
();
framework
::
DDim
out_dims
=
out
->
dims
();
if
(
shape_tensor
)
{
auto
*
shape_data
=
shape_tensor
->
data
<
int
>
();
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
framework
::
Tensor
cpu_shape_tensor
;
TensorCopy
(
*
shape_tensor
,
platform
::
CPUPlace
(),
ctx
.
device_context
(),
&
cpu_shape_tensor
);
shape_data
=
cpu_shape_tensor
.
data
<
int
>
();
}
auto
shape
=
std
::
vector
<
int
>
(
shape_data
,
shape_data
+
shape_tensor
->
numel
());
out_dims
=
ReshapeOp
::
ValidateShape
(
shape
,
in
->
dims
());
}
if
(
!
in
->
lod
().
empty
())
{
if
(
!
in
->
lod
().
empty
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
out_dims
[
0
],
in
->
dims
()[
0
],
out_dims
[
0
],
in
->
dims
()[
0
],
...
@@ -39,9 +139,11 @@ class ReshapeKernel : public framework::OpKernel<T> {
...
@@ -39,9 +139,11 @@ class ReshapeKernel : public framework::OpKernel<T> {
}
}
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
out
->
Resize
(
out_dims
);
if
(
!
inplace
)
{
if
(
!
inplace
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
TensorCopy
(
*
in
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
out
);
framework
::
TensorCopy
(
*
in
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
out
);
// TensorCopy will resize to in_dims.
out
->
Resize
(
out_dims
);
out
->
Resize
(
out_dims
);
}
else
{
}
else
{
out
->
ShareDataWith
(
*
in
);
out
->
ShareDataWith
(
*
in
);
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
c078ed46
...
@@ -3320,42 +3320,54 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
...
@@ -3320,42 +3320,54 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
return
counter
return
counter
def
reshape
(
x
,
shape
,
act
=
None
,
inplace
=
True
,
name
=
None
):
def
reshape
(
x
,
shape
,
act
ual_shape
=
None
,
act
=
None
,
inplace
=
True
,
name
=
None
):
"""
"""
Gives a new shape to the input Tensor without changing its data.
Gives a new shape to the input Tensor without changing its data.
This layer takes a tensor and the attribute shape which specifies the
The target shape can be given by :attr:`shape` or :attr:`actual_shape`.
new shape as its inputs. The shape attribute must be given. It cannot be
:attr:`shape` is a list of integer while :attr:`actual_shape` is a tensor
empty. One and only one dimension of shape can be -1. More than one
variable. :attr:`actual_shape` has a higher priority than :attr:`shape`
dimension of shape can be 0.
if it is provided, while :attr:`shape` still should be set correctly to
gurantee shape inference in compile-time.
-1 means the value of this dimension is inferred from the total element
Some tricks exist when specifying the target shape.
number of x and remaining dimensions.
0 means the actual dimension value is going to be copied from the
1. -1 means the value of this dimension is inferred from the total element
corresponding dimension of x.
number of x and remaining dimensions. Thus one and only one dimension can
be set -1.
1. 0 means the actual dimension value is going to be copied from the
corresponding dimension of x. The indice of 0s in shape can not exceed
Rank(X).
Here are some examples to explain it.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified by Attr(shape) is [6, 8], the reshape operator will transform x
is [6, 8], the reshape operator will transform x into a 2-D tensor with
into a 2-D tensor with
shape [6, 8] and leaving x's data unchanged.
shape [6, 8] and leaving x's data unchanged.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified
by Attr(shape) is [2, 3, -1, 2], the reshape operator will
specified
is [2, 3, -1, 2], the reshape operator will transform x into a
transform x into a 4-D tensor with shape [2, 3, 4, 2] and leaving x's data
4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this
unchanged. In this case, one and only dimension of Attr(shape) can be set
case, one dimension of the target shape is set to -1, the value of this
to -1, the value of this dimension is inferred from the total element number
dimension is inferred from the total element number of x and remaining
of x and remaining
dimensions.
dimensions.
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
specified by Attr(shape) is [-1, 0, 3, 2], the reshape operator will
is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor
transform x into a 4-D tensor with shape [2, 4, 3, 2] and leaving x's data
with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case,
unchanged. In this case, besides -1, 0 means the actual dimension value is
besides -1, 0 means the actual dimension value is going to be copied from
going to be copied from the corresponding dimension of x during runtime
.
the corresponding dimension of x
.
Args:
Args:
input(variable): The input tensor.
input(variable): The input tensor.
shape(list): The new shape. At most one dimension of the new shape can
shape(list): The new shape. At most one dimension of the new shape can
be -1.
be -1.
actual_shape(variable): An optional input. If provided, reshape
according to this given shape rather than
:attr:`shape` specifying shape. That is to
say :attr:`actual_shape` has a higher priority
than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable.
act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, a new output tensor is created
inplace(bool): If this flag is set true, a new output tensor is created
whose data is copied from input x, otherwise the output
whose data is copied from input x, otherwise the output
...
@@ -3366,12 +3378,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
...
@@ -3366,12 +3378,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
Examples:
Examples:
.. code-block:: python
.. code-block:: python
data = fluid.layers.data(
data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32'
name='data', shape=[2, 4, 6], dtype='float32')
)
reshaped = fluid.layers.reshape(
reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True)
)
"""
"""
if
not
(
isinstance
(
shape
,
list
)
or
isinstance
(
shape
,
tuple
)):
if
not
(
isinstance
(
shape
,
list
)
or
isinstance
(
shape
,
tuple
)):
...
@@ -3396,7 +3405,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
...
@@ -3396,7 +3405,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
reshaped
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
reshaped
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"reshape"
,
type
=
"reshape"
,
inputs
=
{
"X"
:
x
},
inputs
=
{
"X"
:
x
,
"Shape"
:
actual_shape
}
if
isinstance
(
actual_shape
,
Variable
)
else
{
"X"
:
x
},
attrs
=
{
"shape"
:
shape
,
attrs
=
{
"shape"
:
shape
,
"inplace"
:
inplace
},
"inplace"
:
inplace
},
outputs
=
{
"Out"
:
reshaped
})
outputs
=
{
"Out"
:
reshaped
})
...
...
python/paddle/fluid/tests/unittests/test_reshape_op.py
浏览文件 @
c078ed46
...
@@ -122,5 +122,27 @@ class TestReshapeOpDimInferInplace2(OpTest):
...
@@ -122,5 +122,27 @@ class TestReshapeOpDimInferInplace2(OpTest):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOpWithInputShape
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
6
,
5
)
new_shape
=
(
0
,
-
1
,
5
)
actual_shape
=
(
2
,
3
,
5
)
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
),
"Shape"
:
np
.
array
(
actual_shape
,
dtype
=
"int32"
)
}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
actual_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录