Unverified commit 14cf420e, authored by Zhang Ting and committed by GitHub

revert cast eigen kernel (#29445)

Parent: d77566b3
@@ -48,52 +48,17 @@ struct CastOpFunctor {
   }
 };
 
-template <typename DeviceContext, typename InT, typename OutT>
-static void CastFunction(const framework::ExecutionContext& context) {
-  auto* in = context.Input<framework::Tensor>("X");
-  auto* out = context.Output<framework::Tensor>("Out");
-  auto in_t = framework::EigenVector<InT>::Flatten(*in);
-  out->mutable_data<OutT>(context.GetPlace());
-  auto out_t = framework::EigenVector<OutT>::Flatten(*out);
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
-  out_t.device(place) = in_t.template cast<OutT>();
-}
-
 template <typename DeviceContext, typename InT>
 class CastOpKernel : public framework::OpKernel<InT> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto out_type = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("out_dtype"));
-    if (out_type == paddle::framework::proto::VarType::FP64) {
-      CastFunction<DeviceContext, InT, double>(context);
-    } else if (out_type == paddle::framework::proto::VarType::FP32) {
-      CastFunction<DeviceContext, InT, float>(context);
-    } else if (out_type == paddle::framework::proto::VarType::FP16) {
-      CastFunction<DeviceContext, InT, paddle::platform::float16>(context);
-    } else if (out_type == paddle::framework::proto::VarType::INT64) {
-      CastFunction<DeviceContext, InT, int64_t>(context);
-    } else if (out_type == paddle::framework::proto::VarType::INT32) {
-      CastFunction<DeviceContext, InT, int>(context);
-    } else if (out_type == paddle::framework::proto::VarType::UINT8) {
-      CastFunction<DeviceContext, InT, uint8_t>(context);
-    } else if (out_type == paddle::framework::proto::VarType::BOOL) {
-      CastFunction<DeviceContext, InT, bool>(context);
-    } else if (out_type == paddle::framework::proto::VarType::COMPLEX64) {
-      CastFunction<DeviceContext, InT, paddle::platform::complex64>(context);
-    } else if (out_type == paddle::framework::proto::VarType::COMPLEX128) {
-      CastFunction<DeviceContext, InT, paddle::platform::complex128>(context);
-    } else {
-      // NOTE(chenweihang): if else branch do nothing, the output var will
-      // be non-initialized in dygraph, which will throw error if the
-      // non-initialized var is used as the next op's input
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Now does not support casting Tensor to `%s` data type.",
-          framework::DataTypeToString(out_type)));
-    }
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("out_dtype")),
+        CastOpFunctor<DeviceContext, InT>(
+            in, out, context.template device_context<DeviceContext>()));
   }
 };
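The restored kernel no longer enumerates output types by hand: framework::VisitDataType maps the runtime out_dtype attribute to a concrete C++ type and calls the functor's templated apply<OutT>() for it, so the per-type code lives in one templated member instead of an if/else chain. Below is a minimal, self-contained sketch of that visitor-dispatch pattern under stated assumptions; the DType enum, VisitDType, and CastFunctor names here are illustrative stand-ins, not Paddle's actual API.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Illustrative runtime dtype tag (stand-in for framework::proto::VarType).
enum class DType { FP32, FP64, INT32 };

// Map the runtime tag to a concrete type and invoke the visitor's
// templated apply<OutT>() -- the same shape of dispatch that
// framework::VisitDataType performs in the diff above.
template <typename Visitor>
void VisitDType(DType dtype, Visitor visitor) {
  switch (dtype) {
    case DType::FP32: visitor.template apply<float>(); break;
    case DType::FP64: visitor.template apply<double>(); break;
    case DType::INT32: visitor.template apply<int32_t>(); break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}

// Simplified stand-in for CastOpFunctor: casts every input element to the
// visited output type and prints it.
template <typename InT>
struct CastFunctor {
  const std::vector<InT>& in;
  explicit CastFunctor(const std::vector<InT>& input) : in(input) {}

  template <typename OutT>
  void apply() const {
    for (InT v : in) std::cout << static_cast<OutT>(v) << ' ';
    std::cout << '\n';
  }
};

int main() {
  std::vector<float> data = {1.5f, 2.5f, 3.5f};
  // out_dtype arrives as a runtime value, e.g. from an op attribute.
  VisitDType(DType::INT32, CastFunctor<float>(data));  // prints: 1 2 3
}

The design benefit, and presumably the point of keeping the functor form, is that supporting a new output type means extending one dispatch point rather than editing every kernel's branch ladder.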
@@ -90,18 +90,6 @@ class TestCastOpError(unittest.TestCase):
 
         self.assertRaises(TypeError, test_dtype_type)
 
-class TestCastOpErrorInDygraph(unittest.TestCase):
-    def test_non_support_out_dtype(self):
-        paddle.disable_static()
-        with self.assertRaises(NotImplementedError):
-            tensor = paddle.randn([10, 10], 'float32')
-            core.ops.cast(tensor, 'in_dtype', core.VarDesc.VarType.FP32,
-                          'out_dtype', core.VarDesc.VarType.INT16)
-        paddle.enable_static()
-
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()