From 14cf420ec2801531f506216e2e8353b05a97499f Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 8 Dec 2020 08:20:44 +0800 Subject: [PATCH] revert cast eigen kernel (#29445) --- paddle/fluid/operators/cast_op.h | 49 +++---------------- .../fluid/tests/unittests/test_cast_op.py | 12 ----- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index 91276ba6e8..8fa0416049 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -48,52 +48,17 @@ struct CastOpFunctor { } }; -template -static void CastFunction(const framework::ExecutionContext& context) { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - - auto in_t = framework::EigenVector::Flatten(*in); - out->mutable_data(context.GetPlace()); - auto out_t = framework::EigenVector::Flatten(*out); - auto& place = - *context.template device_context().eigen_device(); - out_t.device(place) = in_t.template cast(); -} - template class CastOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto out_type = static_cast( - context.Attr("out_dtype")); - - if (out_type == paddle::framework::proto::VarType::FP64) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::FP32) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::FP16) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::INT64) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::INT32) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::UINT8) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::BOOL) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::COMPLEX64) { - CastFunction(context); - } else if (out_type == paddle::framework::proto::VarType::COMPLEX128) { - CastFunction(context); - } else { - // NOTE(chenweihang): if else branch do nothing, the output var will - // be non-initialized in dygraph, which will throw error if the - // non-initialized var is used as the next op's input - PADDLE_THROW(platform::errors::Unimplemented( - "Now does not support casting Tensor to `%s` data type.", - framework::DataTypeToString(out_type))); - } + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + framework::VisitDataType( + static_cast( + context.Attr("out_dtype")), + CastOpFunctor( + in, out, context.template device_context())); } }; diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index 44fdd8c74b..0fc3dccab4 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -90,18 +90,6 @@ class TestCastOpError(unittest.TestCase): self.assertRaises(TypeError, test_dtype_type) -class TestCastOpErrorInDygraph(unittest.TestCase): - def test_non_support_out_dtype(self): - paddle.disable_static() - - with self.assertRaises(NotImplementedError): - tensor = paddle.randn([10, 10], 'float32') - core.ops.cast(tensor, 'in_dtype', core.VarDesc.VarType.FP32, - 'out_dtype', core.VarDesc.VarType.INT16) - - paddle.enable_static() - - if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab