未验证 提交 42dd0f1b 编写于 作者: Q qipengh 提交者: GitHub

[MLU]support cast double type (#43058)

* [MLU]support cast double type

* [MLU]fix cast test
上级 0fdb3ced
......@@ -44,37 +44,7 @@ class CastMLUKernel : public framework::OpKernel<T> {
framework::DataTypeToString(src_type),
framework::DataTypeToString(dst_type)));
switch (dst_type) {
case VT::FP32:
output->mutable_data<float>(place);
break;
case VT::FP16:
output->mutable_data<paddle::platform::float16>(place);
break;
case VT::INT32:
output->mutable_data<int32_t>(place);
break;
case VT::INT16:
output->mutable_data<int16_t>(place);
break;
case VT::INT8:
output->mutable_data<int8_t>(place);
break;
case VT::UINT8:
output->mutable_data<uint8_t>(place);
break;
case VT::BOOL:
output->mutable_data<bool>(place);
break;
case VT::INT64:
output->mutable_data<int64_t>(place);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"Not supported cast %d -> %d",
framework::DataTypeToString(src_type),
framework::DataTypeToString(dst_type)));
}
output->mutable_data(place, framework::TransToPhiDataType(dst_type));
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc output_desc(*output);
......
......@@ -75,6 +75,9 @@ inline cnnlDataType_t ToCnnlDataType(
case DataType::FLOAT32:
type = CNNL_DTYPE_FLOAT;
break;
case DataType::FLOAT64:
type = CNNL_DTYPE_DOUBLE;
break;
case DataType::INT8:
type = CNNL_DTYPE_INT8;
break;
......
......@@ -61,6 +61,25 @@ class TestCastOpFp16ToFp32(OpTest):
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestCastOpFp32ToFp64(OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float64')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP64)
}
self.op_type = 'cast'
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册