Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
4e832b23
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4e832b23
编写于
7月 01, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support implicit type conversion for pynative mode
上级
4bdd8e16
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
269 addition
and
31 deletion
+269
-31
mindspore/ccsrc/operator/composite/do_signature.cc
mindspore/ccsrc/operator/composite/do_signature.cc
+1
-3
mindspore/ccsrc/operator/composite/do_signature.h
mindspore/ccsrc/operator/composite/do_signature.h
+2
-0
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+102
-26
tests/st/pynative/test_implicit_conversion.py
tests/st/pynative/test_implicit_conversion.py
+81
-0
tests/ut/python/pynative_mode/test_implicit_conversion.py
tests/ut/python/pynative_mode/test_implicit_conversion.py
+81
-0
tests/vm_impl/vm_me.py
tests/vm_impl/vm_me.py
+2
-2
未找到文件。
mindspore/ccsrc/operator/composite/do_signature.cc
浏览文件 @
4e832b23
...
...
@@ -31,12 +31,10 @@
namespace
mindspore
{
// namespace to support composite operators definition
namespace
prim
{
namespace
{
using
PatternListType
=
std
::
initializer_list
<
BaseRef
>
;
const
std
::
map
<
TypeId
,
size_t
>
type_map
=
{{
kNumberTypeBool
,
1
},
{
kNumberTypeInt8
,
2
},
{
kNumberTypeUInt8
,
3
},
{
kNumberTypeInt16
,
4
},
{
kNumberTypeInt32
,
5
},
{
kNumberTypeInt64
,
6
},
{
kNumberTypeFloat16
,
7
},
{
kNumberTypeFloat32
,
8
},
{
kNumberTypeFloat64
,
9
}};
namespace
{
const
std
::
vector
<
Signature
>
&
GetSignature
(
const
ValuePtr
&
function
)
{
static
const
auto
empty
=
std
::
vector
<
Signature
>
();
if
(
function
->
isa
<
Primitive
>
()
&&
function
->
cast
<
PrimitivePtr
>
()
->
has_signature
())
{
...
...
mindspore/ccsrc/operator/composite/do_signature.h
浏览文件 @
4e832b23
...
...
@@ -56,6 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph {
};
using
RWSignaturePtr
=
std
::
shared_ptr
<
DoSignatureMetaFuncGraph
>
;
extern
const
std
::
map
<
TypeId
,
size_t
>
type_map
;
AnfNodePtr
GenerateCNode
(
const
FuncGraphPtr
&
func_graph
,
const
std
::
string
&
func_name
,
const
ValuePtr
&
function
,
const
AbstractBasePtrList
&
args_spec_list
,
const
AnfNodePtrList
&
old_node_inputs
);
}
// namespace prim
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
4e832b23
...
...
@@ -160,36 +160,102 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector
return
type_indexes
;
}
std
::
map
<
SignatureEnumDType
,
size_t
>
GetDstType
(
const
py
::
tuple
&
py_args
,
std
::
map
<
SignatureEnumDType
,
TypeId
>
GetDstType
(
const
py
::
tuple
&
py_args
,
const
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
&
type_indexes
)
{
std
::
map
<
SignatureEnumDType
,
size_t
>
dst_type
;
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
;
for
(
auto
it
=
type_indexes
.
begin
();
it
!=
type_indexes
.
end
();
(
void
)
++
it
)
{
auto
type
=
it
->
first
;
auto
indexes
=
it
->
second
;
if
(
indexes
.
size
()
<
2
)
{
if
(
type
==
SignatureEnumDType
::
kDTypeEmptyDefaultValue
||
indexes
.
size
()
<
2
)
{
continue
;
}
size_t
m_index
=
indexes
[
0
];
for
(
size_t
i
=
1
;
i
<
indexes
.
size
();
++
i
)
{
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
indexes
[
i
]]))
{
m_index
=
indexes
[
i
];
size_t
priority
=
0
;
TypeId
max_type
=
TypeId
::
kTypeUnknown
;
bool
has_float
=
false
;
bool
has_int
=
false
;
for
(
size_t
index
:
indexes
)
{
if
(
!
has_float
&&
py
::
isinstance
<
py
::
float_
>
(
py_args
[
index
]))
{
has_float
=
true
;
}
if
(
!
has_int
&&
!
py
::
isinstance
<
py
::
bool_
>
(
py_args
[
index
])
&&
py
::
isinstance
<
py
::
int_
>
(
py_args
[
index
]))
{
has_int
=
true
;
}
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
index
]))
{
auto
arg
=
py
::
cast
<
tensor
::
TensorPtr
>
(
py_args
[
index
]);
TypeId
arg_type_id
=
arg
->
data_type
();
auto
type_priority
=
prim
::
type_map
.
find
(
arg_type_id
);
if
(
type_priority
->
second
>
priority
)
{
max_type
=
type_priority
->
first
;
priority
=
type_priority
->
second
;
}
}
}
if
(
max_type
==
TypeId
::
kNumberTypeBool
)
{
if
(
has_int
)
{
max_type
=
TypeId
::
kNumberTypeInt32
;
}
if
(
has_float
)
{
max_type
=
TypeId
::
kNumberTypeFloat32
;
}
}
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
m
_index
));
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
m
ax_type
));
}
return
dst_type
;
}
std
::
string
TypeIdToMsTypeStr
(
const
TypeId
&
type_id
)
{
switch
(
type_id
)
{
case
kNumberTypeFloat16
:
return
"float16"
;
case
kNumberTypeFloat32
:
return
"float32"
;
case
kNumberTypeFloat64
:
return
"float64"
;
case
kNumberTypeInt8
:
return
"int8"
;
case
kNumberTypeInt16
:
return
"int16"
;
case
kNumberTypeInt32
:
return
"int32"
;
case
kNumberTypeInt64
:
return
"int64"
;
case
kNumberTypeUInt8
:
return
"uint8"
;
case
kNumberTypeUInt16
:
return
"uint16"
;
case
kNumberTypeUInt32
:
return
"uint32"
;
case
kNumberTypeUInt64
:
return
"uint64"
;
case
kNumberTypeBool
:
return
"bool_"
;
default:
MS_LOG
(
EXCEPTION
)
<<
"For implicit type conversion, not support the type: "
<<
TypeIdToType
(
type_id
);
}
}
py
::
object
DoAutoCast
(
const
py
::
object
arg
,
const
TypeId
&
type_id
)
{
py
::
tuple
args
(
3
);
std
::
string
module_name
=
"mindspore.ops.functional"
;
std
::
string
op_name
=
"cast"
;
args
[
0
]
=
parse
::
python_adapter
::
GetPyFn
(
module_name
,
op_name
);
args
[
1
]
=
"Cast"
;
std
::
string
dst_type_str
=
TypeIdToMsTypeStr
(
type_id
);
module_name
=
"mindspore.common.dtype"
;
py
::
object
dst_type
=
parse
::
python_adapter
::
GetPyFn
(
module_name
,
dst_type_str
);
py
::
tuple
inputs
(
2
);
inputs
[
0
]
=
arg
;
inputs
[
1
]
=
dst_type
;
args
[
2
]
=
inputs
;
return
RunOp
(
args
)[
0
];
}
py
::
tuple
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
list
&
args
,
py
::
tuple
*
const
out_args
,
py
::
list
*
const
out_args_list
)
{
auto
&
py_args
=
*
out_args
;
py
::
tuple
input_mask
(
args
.
size
());
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
if
(
py
::
hasattr
(
args
[
i
],
"__parameter__"
))
{
input_mask
[
i
]
=
true
;
}
else
{
input_mask
[
i
]
=
false
;
}
input_mask
[
i
]
=
py
::
hasattr
(
args
[
i
],
"__parameter__"
);
py_args
[
i
]
=
GetTupleObj
(
args
[
i
]);
}
auto
signature
=
prim
->
signatures
();
...
...
@@ -197,26 +263,36 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
int
empty_dtype_count
=
std
::
count
(
dtypes
.
begin
(),
dtypes
.
end
(),
SignatureEnumDType
::
kDTypeEmptyDefaultValue
);
if
(
dtypes
.
size
()
==
0
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
if
(
dtypes
.
empty
()
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
return
input_mask
;
}
auto
type_indexes
=
GetTypeIndex
(
dtypes
);
auto
dst_type
=
GetDstType
(
py_args
,
type_indexes
);
for
(
size_t
i
=
0
;
i
<
py_args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
if
(
dtypes
[
i
]
==
SignatureEnumDType
::
kDTypeEmptyDefaultValue
)
{
continue
;
}
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
if
(
it
!=
dst_type
.
end
()
&&
it
->
second
!=
i
&&
(
py
::
isinstance
<
py
::
int_
>
(
py_args
[
i
])
||
py
::
isinstance
<
py
::
float_
>
(
py_args
[
i
])))
{
auto
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
py_args
[
it
->
second
]);
if
(
py
::
isinstance
<
py
::
int_
>
(
py_args
[
i
]))
{
py_args
[
i
]
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
int_
>
(
py_args
[
i
]),
tensor_ptr
->
Dtype
());
(
*
out_args_list
)[
i
]
=
py_args
[
i
];
}
else
{
double
arg_value
=
py
::
cast
<
py
::
float_
>
(
py_args
[
i
]);
py_args
[
i
]
=
std
::
make_shared
<
tensor
::
Tensor
>
(
arg_value
,
tensor_ptr
->
Dtype
());
(
*
out_args_list
)[
i
]
=
py_args
[
i
];
}
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
kTypeUnknown
)
{
continue
;
}
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
i
]))
{
auto
arg
=
py
::
cast
<
tensor
::
TensorPtr
>
(
py_args
[
i
]);
if
(
arg
->
data_type
()
==
it
->
second
)
{
continue
;
}
if
(
signature
[
i
].
rw
==
SignatureEnumRW
::
kRWWrite
)
{
MS_LOG
(
EXCEPTION
)
<<
"In op '"
<<
prim
->
name
()
<<
"',
\n
"
<<
"the type of writable argument is '"
<<
TypeIdToMsTypeStr
(
arg
->
data_type
())
<<
"', "
<<
"but the largest type in the same SignatureEumDtype is '"
<<
TypeIdToMsTypeStr
(
it
->
second
)
<<
"'. The writable arg type is not equal to the largest type, "
<<
"so can not cast automatically."
;
}
}
py
::
object
cast_output
=
DoAutoCast
(
py_args
[
i
],
it
->
second
);
(
*
out_args
)[
i
]
=
cast_output
;
(
*
out_args_list
)[
i
]
=
cast_output
;
}
return
input_mask
;
}
...
...
tests/st/pynative/test_implicit_conversion.py
0 → 100644
浏览文件 @
4e832b23
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test implicit conversion """
import
numpy
as
np
from
mindspore
import
Tensor
def
test_float_tensor_and_int_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
2
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2.1
,
2.2
,
2.3
],
[
2.4
,
2.5
,
2.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_bool_tensor_and_float_add
():
x
=
Tensor
(
np
.
array
([[
True
,
False
],
[
False
,
True
]],
dtype
=
np
.
bool_
))
y
=
3.3
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
4.3
,
3.3
],
[
3.3
,
4.3
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_bool_tensor_and_int_add
():
x
=
Tensor
(
np
.
array
([[
True
,
False
],
[
False
,
True
]],
dtype
=
np
.
bool_
))
y
=
3
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
4
,
3
],
[
3
,
4
]],
dtype
=
np
.
int32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_bool_and_int_tensor_add
():
x
=
True
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2
,
3
,
4
],
[
5
,
6
,
7
]],
dtype
=
np
.
int32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_float_tensor_and_int_tensor_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
4.4
,
5.5
,
6.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_float_tensor_and_float_tensor_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float64
))
y
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
3.0
],
[
4.0
,
5.0
,
6.0
]],
dtype
=
np
.
float32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
4.4
,
5.5
,
6.6
]],
dtype
=
np
.
float64
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_int_tensor_and_int_tensor_add
():
x
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int16
))
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2
,
4
,
6
],
[
8
,
10
,
12
]],
dtype
=
np
.
int32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_float_tensor_and_bool_tensors_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
array
([[
True
,
True
,
True
],
[
False
,
False
,
False
]],
dtype
=
np
.
bool_
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
1.2
,
1.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
tests/ut/python/pynative_mode/test_implicit_conversion.py
0 → 100644
浏览文件 @
4e832b23
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test implicit conversion """
import
numpy
as
np
from
mindspore
import
Tensor
def
test_float_tensor_and_int_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
2
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2.1
,
2.2
,
2.3
],
[
2.4
,
2.5
,
2.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_bool_tensor_and_float_add
():
x
=
Tensor
(
np
.
array
([[
True
,
False
],
[
False
,
True
]],
dtype
=
np
.
bool_
))
y
=
3.3
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
4.3
,
3.3
],
[
3.3
,
4.3
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_bool_tensor_and_int_add
():
x
=
Tensor
(
np
.
array
([[
True
,
False
],
[
False
,
True
]],
dtype
=
np
.
bool_
))
y
=
3
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
4
,
3
],
[
3
,
4
]],
dtype
=
np
.
int32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_bool_and_int_tensor_add
():
x
=
True
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2
,
3
,
4
],
[
5
,
6
,
7
]],
dtype
=
np
.
int32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_float_tensor_and_int_tensor_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
4.4
,
5.5
,
6.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_float_tensor_and_float_tensor_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
array
([[
1.0
,
2.0
,
3.0
],
[
4.0
,
5.0
,
6.0
]],
dtype
=
np
.
float16
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
2.2
,
3.3
],
[
4.4
,
5.5
,
6.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_int_tensor_and_int_tensor_add
():
x
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int8
))
y
=
Tensor
(
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
np
.
int32
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
2
,
4
,
6
],
[
8
,
10
,
12
]],
dtype
=
np
.
int32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
def
test_float_tensor_and_bool_tensors_add
():
x
=
Tensor
(
np
.
array
([[
0.1
,
0.2
,
0.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
array
([[
True
,
True
,
True
],
[
False
,
False
,
False
]],
dtype
=
np
.
bool_
))
ret_actual
=
x
+
y
ret_expect
=
Tensor
(
np
.
array
([[
1.1
,
1.2
,
1.3
],
[
0.4
,
0.5
,
0.6
]],
dtype
=
np
.
float32
))
assert
(
ret_actual
.
asnumpy
()
==
ret_expect
.
asnumpy
()).
all
()
tests/vm_impl/vm_me.py
浏览文件 @
4e832b23
...
...
@@ -403,7 +403,7 @@ def max_pool_grad(x, dout, pool_h, pool_w, stride):
"""Grad of max pooling."""
dout
=
dout
.
transpose
(
0
,
2
,
3
,
1
)
pool_size
=
pool_h
*
pool_w
dmax
=
np
.
zeros
((
dout
.
size
,
pool_size
))
dmax
=
np
.
zeros
((
dout
.
size
,
pool_size
)
,
dout
.
dtype
)
col
=
im2col
(
x
,
pool_h
,
pool_w
,
stride
)
col
=
col
.
reshape
(
-
1
,
pool_h
*
pool_w
)
arg_max
=
np
.
argmax
(
col
,
axis
=
1
)
...
...
@@ -418,7 +418,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
"""Grad of max pooling with argmax."""
dout
=
dout
.
transpose
(
0
,
2
,
3
,
1
)
pool_size
=
pool_h
*
pool_w
dmax
=
np
.
zeros
((
dout
.
size
,
pool_size
))
dmax
=
np
.
zeros
((
dout
.
size
,
pool_size
)
,
dout
.
dtype
)
dmax
[
np
.
arange
(
arg_max
.
size
),
arg_max
.
flatten
()]
=
dout
.
flatten
()
dmax
=
dmax
.
reshape
(
dout
.
shape
+
(
pool_size
,))
dcol
=
dmax
.
reshape
(
dmax
.
shape
[
0
]
*
dmax
.
shape
[
1
]
*
dmax
.
shape
[
2
],
-
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录