Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d0f719f7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d0f719f7
编写于
10月 06, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into develop
上级
eebe9b15
f8b5d54c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
210 addition
and
24 deletion
+210
-24
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/block_desc.cc
paddle/framework/block_desc.cc
+4
-0
paddle/framework/block_desc.h
paddle/framework/block_desc.h
+2
-0
paddle/framework/operator.h
paddle/framework/operator.h
+115
-16
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+3
-0
paddle/framework/tensor_array.cc
paddle/framework/tensor_array.cc
+7
-7
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+15
-0
python/paddle/v2/framework/tests/test_infer_shape.py
python/paddle/v2/framework/tests/test_infer_shape.py
+63
-0
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
d0f719f7
...
@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
...
@@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope
proto_desc
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc
)
...
...
paddle/framework/block_desc.cc
浏览文件 @
d0f719f7
...
@@ -34,6 +34,10 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
...
@@ -34,6 +34,10 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const {
return
it
->
second
.
get
();
return
it
->
second
.
get
();
}
}
bool
BlockDescBind
::
HasVar
(
const
std
::
string
&
name
)
const
{
return
vars_
.
find
(
name
)
!=
vars_
.
end
();
}
std
::
vector
<
VarDescBind
*>
BlockDescBind
::
AllVars
()
const
{
std
::
vector
<
VarDescBind
*>
BlockDescBind
::
AllVars
()
const
{
std
::
vector
<
VarDescBind
*>
res
;
std
::
vector
<
VarDescBind
*>
res
;
for
(
const
auto
&
p
:
vars_
)
{
for
(
const
auto
&
p
:
vars_
)
{
...
...
paddle/framework/block_desc.h
浏览文件 @
d0f719f7
...
@@ -43,6 +43,8 @@ class BlockDescBind {
...
@@ -43,6 +43,8 @@ class BlockDescBind {
VarDescBind
*
Var
(
const
std
::
string
&
name_bytes
)
const
;
VarDescBind
*
Var
(
const
std
::
string
&
name_bytes
)
const
;
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
std
::
vector
<
VarDescBind
*>
AllVars
()
const
;
std
::
vector
<
VarDescBind
*>
AllVars
()
const
;
BlockDescBind
*
ParentBlock
()
const
;
BlockDescBind
*
ParentBlock
()
const
;
...
...
paddle/framework/operator.h
浏览文件 @
d0f719f7
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include "op_info.h"
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor.h"
...
@@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext {
...
@@ -317,26 +318,122 @@ class ExecutionContext : public InferShapeContext {
const
platform
::
DeviceContext
&
device_context_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
class
CompileTimeInferShapeContext
:
public
InferShapeContextBase
{
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
input_names
=
op_
.
Input
(
name
);
auto
length
=
input_names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input(%s) should have only one value, "
"but it have %d now"
,
name
,
length
);
return
block_
.
HasVar
(
input_names
[
0
]);
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
output_names
=
op_
.
Output
(
name
);
auto
length
=
output_names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output(%s) should have only one value, "
"but it have %d now"
,
name
,
length
);
return
block_
.
HasVar
(
output_names
[
0
]);
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
input_names
=
op_
.
Input
(
name
);
PADDLE_ENFORCE
(
!
input_names
.
empty
(),
"Inputs(%s) length is 0"
,
name
);
for
(
auto
&
input
:
input_names
)
{
if
(
!
block_
.
HasVar
(
input
))
return
false
;
}
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
output_names
=
op_
.
Output
(
name
);
PADDLE_ENFORCE
(
!
output_names
.
empty
(),
"Inputs(%s) length is 0"
,
name
);
for
(
auto
&
output
:
output_names
)
{
if
(
!
block_
.
HasVar
(
output
))
return
false
;
}
return
true
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
DDim
>
ddims
=
GetInputsDim
(
name
);
auto
length
=
ddims
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input(%s) should have 1 value, "
"but it has %d now"
,
name
,
length
);
return
ddims
[
0
];
}
void
SetInputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetInputsDim
(
name
,
{
dim
});
}
DDim
GetOutputDim
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
DDim
>
ddims
=
GetOutputsDim
(
name
);
auto
length
=
ddims
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output(%s) should have 1 value, "
"but it has %d now"
,
name
,
length
);
return
ddims
[
0
];
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetOutputsDim
(
name
,
{
dim
});
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
GetAttrMap
());
}
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Input
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Output
(
name
);
}
private:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
return
framework
::
make_ddim
(
block_
.
Var
(
name
)
->
Shape
());
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
block_
.
Var
(
name
)
->
SetShape
(
framework
::
vectorize
(
dim
));
}
const
OpDescBind
&
op_
;
const
BlockDescBind
&
block_
;
};
class
RuntimeInferShapeContext
:
public
InferShapeContextBase
{
class
RuntimeInferShapeContext
:
public
InferShapeContextBase
{
public:
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
:
op_
(
op
),
scope_
(
scope
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
{
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
auto
ipt
=
op_
.
Input
(
name
);
auto
ipt
=
op_
.
Input
(
name
);
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
return
var
!=
nullptr
;
}
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
{
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
auto
ipt
=
op_
.
Output
(
name
);
auto
ipt
=
op_
.
Output
(
name
);
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
return
var
!=
nullptr
;
}
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
{
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
auto
inputs
=
op_
.
Inputs
(
name
);
auto
inputs
=
op_
.
Inputs
(
name
);
if
(
inputs
.
size
()
==
0UL
)
{
if
(
inputs
.
empty
()
)
{
return
false
;
return
false
;
}
}
for
(
auto
&
input
:
inputs
)
{
for
(
auto
&
input
:
inputs
)
{
...
@@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
...
@@ -347,9 +444,9 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return
true
;
return
true
;
}
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
{
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
auto
outputs
=
op_
.
Outputs
(
name
);
auto
outputs
=
op_
.
Outputs
(
name
);
if
(
outputs
.
size
()
==
0UL
)
{
if
(
outputs
.
empty
()
)
{
return
false
;
return
false
;
}
}
for
(
auto
&
output
:
outputs
)
{
for
(
auto
&
output
:
outputs
)
{
...
@@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
...
@@ -360,29 +457,31 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return
true
;
return
true
;
}
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
{
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
return
GetDim
(
op_
.
Input
(
name
));
return
GetDim
(
op_
.
Input
(
name
));
}
}
void
SetInputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
void
SetInputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetDim
(
op_
.
Input
(
name
),
dim
);
SetDim
(
op_
.
Input
(
name
),
dim
);
}
}
DDim
GetOutputDim
(
const
std
::
string
&
name
)
const
{
DDim
GetOutputDim
(
const
std
::
string
&
name
)
const
override
{
return
GetDim
(
op_
.
Output
(
name
));
return
GetDim
(
op_
.
Output
(
name
));
}
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetDim
(
op_
.
Output
(
name
),
dim
);
SetDim
(
op_
.
Output
(
name
),
dim
);
}
}
AttrReader
Attrs
()
const
{
return
AttrReader
(
op_
.
Attrs
());
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Inputs
(
name
);
return
op_
.
Inputs
(
name
);
}
}
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Outputs
(
name
);
return
op_
.
Outputs
(
name
);
}
}
...
@@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
...
@@ -403,11 +502,11 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
return
t
;
return
t
;
}
}
DDim
GetDim
(
const
std
::
string
&
name
)
const
{
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
return
GetTensor
<
false
>
(
name
)
->
dims
();
return
GetTensor
<
false
>
(
name
)
->
dims
();
}
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
GetTensor
<
true
>
(
name
)
->
Resize
(
dim
);
GetTensor
<
true
>
(
name
)
->
Resize
(
dim
);
}
}
...
@@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -513,9 +612,9 @@ class OperatorWithKernel : public OperatorBase {
});
});
}
}
protected:
virtual
void
InferShape
(
InferShapeContextBase
*
ctx
)
const
=
0
;
virtual
void
InferShape
(
InferShapeContextBase
*
ctx
)
const
=
0
;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
// same.
virtual
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
virtual
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
...
...
paddle/framework/shape_inference.h
浏览文件 @
d0f719f7
...
@@ -19,6 +19,9 @@ limitations under the License. */
...
@@ -19,6 +19,9 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into
// InferShapeContext so to replace the current InferShapeContext.
class
InferShapeContextBase
{
class
InferShapeContextBase
{
public:
public:
virtual
~
InferShapeContextBase
()
{}
virtual
~
InferShapeContextBase
()
{}
...
...
paddle/framework/tensor_array.cc
浏览文件 @
d0f719f7
...
@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
...
@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
// collect indice need to copy to the batch
// collect indice need to copy to the batch
std
::
vector
<
size_t
>
indice
;
std
::
vector
<
size_t
>
indice
;
for
(
size_t
seq_id
=
0
;
seq_id
<
meta
.
size
();
seq_id
++
)
{
for
(
const
auto
&
seq
:
meta
)
{
const
auto
&
seq_meta
=
meta
[
seq_id
]
;
size_t
id
=
seq
.
begin
+
index
;
if
(
i
ndex
>=
seq_meta
.
end
)
break
;
if
(
i
d
>=
seq
.
end
)
break
;
indice
.
push_back
(
seq_meta
.
begin
+
index
);
indice
.
push_back
(
id
);
}
}
PADDLE_ENFORCE
(
!
indice
.
empty
(),
"invalid batch at %d"
,
index
);
PADDLE_ENFORCE
(
!
indice
.
empty
(),
"invalid batch at %d"
,
index
);
// copy the indice of records in LoDTensor
// copy the indice of records in LoDTensor
...
@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
...
@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
result
.
Resize
(
make_ddim
(
record_dims_vec
));
result
.
Resize
(
make_ddim
(
record_dims_vec
));
result
.
mutable_data
<
value_type
>
(
platform
::
CPUPlace
());
result
.
mutable_data
<
value_type
>
(
platform
::
CPUPlace
());
for
(
size_t
i
=
0
;
i
<
indice
.
size
()
-
1
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
indice
.
size
();
i
++
)
{
auto
index
=
indice
[
i
];
auto
index
=
indice
[
i
];
auto
target
=
result
.
Slice
<
value_type
>
(
i
,
i
+
1
);
auto
target
=
result
.
Slice
<
value_type
>
(
i
,
i
+
1
);
auto
source_
=
source
->
Slice
<
value_type
>
(
index
,
index
+
1
);
auto
source_
=
source
->
Slice
<
value_type
>
(
index
,
index
+
1
);
target
.
CopyFrom
<
value_type
>
(
source_
,
platform
::
CPUPlace
());
target
.
CopyFrom
<
value_type
>
(
source_
,
platform
::
CPUPlace
());
}
}
return
result
;
return
result
;
}
}
// TODO(supejom) to cache lod if reasonable
LoDTensor
PackDynamicBatch
(
const
std
::
vector
<
LoDTensor
>&
source
,
LoDTensor
PackDynamicBatch
(
const
std
::
vector
<
LoDTensor
>&
source
,
const
std
::
vector
<
DySeqMeta
>&
meta
,
const
LoD
&
lod
,
const
std
::
vector
<
DySeqMeta
>&
meta
,
const
LoD
&
lod
,
size_t
level
)
{
size_t
level
)
{
...
@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
...
@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
}
}
result
.
set_lod
(
lod
);
result
.
set_lod
(
lod
);
return
result
;
return
result
;
}
}
...
...
paddle/pybind/pybind.cc
浏览文件 @
d0f719f7
...
@@ -230,6 +230,21 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -230,6 +230,21 @@ All parameter, weight, gradient are variables in Paddle.
desc
.
InitializationErrorString
());
desc
.
InitializationErrorString
());
return
OpRegistry
::
CreateOp
(
desc
);
return
OpRegistry
::
CreateOp
(
desc
);
})
})
.
def_static
(
"infer_shape"
,
[](
OpDescBind
&
op_desc
,
BlockDescBind
&
block
)
{
auto
op
=
OpRegistry
::
CreateOp
(
*
op_desc
.
Proto
());
auto
*
op_with_kernel
=
dynamic_cast
<
OperatorWithKernel
*>
(
op
.
get
());
if
(
op_with_kernel
!=
nullptr
)
{
auto
ctx
=
CompileTimeInferShapeContext
(
op_desc
,
block
);
op_with_kernel
->
InferShape
(
&
ctx
);
}
else
{
PADDLE_THROW
(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function"
,
op_desc
.
Type
());
}
})
.
def
(
"backward"
,
.
def
(
"backward"
,
[](
const
OperatorBase
&
forwardOp
,
[](
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
...
...
python/paddle/v2/framework/tests/test_infer_shape.py
0 → 100644
浏览文件 @
d0f719f7
import
unittest
import
paddle.v2.framework.core
as
core
from
paddle.v2.framework.op
import
Operator
class
TestInferShape
(
unittest
.
TestCase
):
def
test_sum_op
(
self
):
prog
=
core
.
ProgramDesc
.
__create_program_desc__
()
self
.
assertIsNotNone
(
prog
)
block
=
prog
.
block
(
0
)
self
.
assertIsNotNone
(
block
)
shape
=
[
10
,
20
]
# prepare input/output
x1
=
block
.
new_var
(
"x1"
)
x1
.
set_shape
(
shape
)
x2
=
block
.
new_var
(
"x2"
)
x2
.
set_shape
(
shape
)
out
=
block
.
new_var
(
"out"
)
# prepare the operator
sum_op_desc
=
block
.
append_op
()
sum_op_desc
.
set_type
(
"sum"
)
sum_op_desc
.
set_input
(
"X"
,
[
"x1"
,
"x2"
])
sum_op_desc
.
set_output
(
"Out"
,
[
"out"
])
core
.
Operator
.
infer_shape
(
sum_op_desc
,
block
)
self
.
assertEqual
(
out
.
shape
(),
shape
)
def
test_mul_op
(
self
):
prog
=
core
.
ProgramDesc
.
__create_program_desc__
()
self
.
assertIsNotNone
(
prog
)
block
=
prog
.
block
(
0
)
self
.
assertIsNotNone
(
block
)
x_shape
=
[
10
,
20
]
y_shape
=
[
20
,
30
]
# prepare input/output
x1
=
block
.
new_var
(
"x"
)
x1
.
set_shape
(
x_shape
)
x2
=
block
.
new_var
(
"y"
)
x2
.
set_shape
(
y_shape
)
out
=
block
.
new_var
(
"out"
)
# prepare the operator
mul_op_desc
=
block
.
append_op
()
mul_op_desc
.
set_type
(
"mul"
)
mul_op_desc
.
set_input
(
"X"
,
[
"x"
])
mul_op_desc
.
set_input
(
"Y"
,
[
"y"
])
mul_op_desc
.
set_output
(
"Out"
,
[
"out"
])
mul_op_desc
.
set_attr
(
"x_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"y_num_col_dims"
,
1
)
core
.
Operator
.
infer_shape
(
mul_op_desc
,
block
)
self
.
assertEqual
(
out
.
shape
(),
[
x_shape
[
0
],
y_shape
[
1
]])
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录