提交 aef2c198 编写于 作者: V VectorSL

cast support more types

上级 ae50c37c
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
...@@ -75,15 +76,7 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { ...@@ -75,15 +76,7 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) {
std::string dst_type; std::string dst_type;
TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0);
if (output_type == kFloat32->type_id()) { dst_type = TypeId2String(output_type);
dst_type = "float32";
} else if (output_type == kFloat16->type_id()) {
dst_type = "float16";
} else if (output_type == kInt32->type_id()) {
dst_type = "int32";
} else {
MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString();
}
AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
} }
......
...@@ -21,10 +21,39 @@ cast_op_info = AkgGpuRegOp("Cast") \ ...@@ -21,10 +21,39 @@ cast_op_info = AkgGpuRegOp("Cast") \
.output(0, "output") \ .output(0, "output") \
.attr("dst_type", "required", "str") \ .attr("dst_type", "required", "str") \
.dtype_format(DataType.F16_Default, DataType.F32_Default) \ .dtype_format(DataType.F16_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \ .dtype_format(DataType.F16_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default) \ .dtype_format(DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
.get_op_info() .get_op_info()
......
...@@ -70,3 +70,275 @@ def test_cast1(): ...@@ -70,3 +70,275 @@ def test_cast1():
assert type0 == 'float32' assert type0 == 'float32'
type1 = output[1].asnumpy().dtype type1 = output[1].asnumpy().dtype
assert type1 == 'float32' assert type1 == 'float32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast2():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float16))
t0 = mstype.int32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float16))
t1 = mstype.float64
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int32'
type1 = output[1].asnumpy().dtype
assert type1 == 'float64'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast3():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float16))
t0 = mstype.int32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32))
t1 = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int32'
type1 = output[1].asnumpy().dtype
assert type1 == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast4():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32))
t0 = mstype.float16
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32))
t1 = mstype.int8
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float16'
type1 = output[1].asnumpy().dtype
assert type1 == 'int8'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast5():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32))
t0 = mstype.uint8
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int32))
t1 = mstype.bool_
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'uint8'
type1 = output[1].asnumpy().dtype
assert type1 == 'bool'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast6():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t0 = mstype.float64
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t1 = mstype.float32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float64'
type1 = output[1].asnumpy().dtype
assert type1 == 'float32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast7():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t0 = mstype.float32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t1 = mstype.float16
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float32'
type1 = output[1].asnumpy().dtype
assert type1 == 'float16'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast8():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t0 = mstype.int32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t1 = mstype.int16
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int32'
type1 = output[1].asnumpy().dtype
assert type1 == 'int16'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast9():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int8))
t0 = mstype.int64
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t1 = mstype.float16
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int64'
type1 = output[1].asnumpy().dtype
assert type1 == 'float16'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast10():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t0 = mstype.int8
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t1 = mstype.float64
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int8'
type1 = output[1].asnumpy().dtype
assert type1 == 'float64'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast11():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t0 = mstype.int16
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t1 = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int16'
type1 = output[1].asnumpy().dtype
assert type1 == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast12():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.bool))
t0 = mstype.int64
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.uint8))
t1 = mstype.float32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int64'
type1 = output[1].asnumpy().dtype
assert type1 == 'float32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast13():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.uint8))
t0 = mstype.int32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.uint8))
t1 = mstype.float16
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'int32'
type1 = output[1].asnumpy().dtype
assert type1 == 'float16'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast14():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t0 = mstype.float64
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t1 = mstype.float32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float64'
type1 = output[1].asnumpy().dtype
assert type1 == 'float32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast15():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t0 = mstype.float16
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t1 = mstype.int32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float16'
type1 = output[1].asnumpy().dtype
assert type1 == 'int32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast16():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t0 = mstype.float16
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int64))
t1 = mstype.float64
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float16'
type1 = output[1].asnumpy().dtype
assert type1 == 'float64'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast17():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t0 = mstype.float32
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.int16))
t1 = mstype.float16
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'float32'
type1 = output[1].asnumpy().dtype
assert type1 == 'float16'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册