From 78959a39dd025d82e2dcf8f95b34d07d888e0843 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Wed, 7 Apr 2021 15:48:34 +0800
Subject: [PATCH] [NPU] fix cast op (#32121)

* fix npu kernel of cast op to handle casting to same dtype

* add comments
---
 paddle/fluid/operators/cast_op_npu.cc         | 44 +++++++++++--------
 .../tests/unittests/npu/test_cast_op_npu.py   | 24 ++++++++++
 2 files changed, 50 insertions(+), 18 deletions(-)
 mode change 100755 => 100644 paddle/fluid/operators/cast_op_npu.cc

diff --git a/paddle/fluid/operators/cast_op_npu.cc b/paddle/fluid/operators/cast_op_npu.cc
old mode 100755
new mode 100644
index 282ac6c1f63..0de0f5e4505
--- a/paddle/fluid/operators/cast_op_npu.cc
+++ b/paddle/fluid/operators/cast_op_npu.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_ASCEND_CL
 #include <memory>
 #include <string>
 
@@ -41,46 +40,56 @@ class CastNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<Tensor>("X");
     int dtype = ctx.Attr<int>("out_dtype");
-    auto* out = ctx.Output<Tensor>("Out");
-    auto place = ctx.GetPlace();
-
-    auto iter = DTYPE_2_ACL_DTYPE.find(static_cast<framework::proto::VarType::Type>(dtype));
+    auto* out = ctx.Output<Tensor>("Out");
+    auto place = ctx.GetPlace();
+
+    if (x->type() == dtype) {
+      // NOTE(zhiqiu): NPU cast op may result in wrong value, so
+      // add special case here.
+      VLOG(4) << "cast to same dtype:" << dtype;
+      out->mutable_data(place, x->type());
+      framework::TensorCopy(
+          *x, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), out);
+      return;
+    }
+
+    auto iter = DTYPE_2_ACL_DTYPE.find(
+        static_cast<framework::proto::VarType::Type>(dtype));
     int aclDtype = iter->second;
 
     if (dtype == framework::proto::VarType::FP32) {
-      out->mutable_data<float>(place);
+      out->mutable_data<float>(place);
     } else if (dtype == framework::proto::VarType::FP16) {
-      out->mutable_data<paddle::platform::float16>(place);
+      out->mutable_data<paddle::platform::float16>(place);
     } else if (dtype == framework::proto::VarType::INT16) {
-      out->mutable_data<int16_t>(place);
+      out->mutable_data<int16_t>(place);
     } else if (dtype == framework::proto::VarType::INT32) {
-      out->mutable_data<int32_t>(place);
+      out->mutable_data<int32_t>(place);
     } else if (dtype == framework::proto::VarType::INT64) {
-      out->mutable_data<int64_t>(place);
+      out->mutable_data<int64_t>(place);
     } else if (dtype == framework::proto::VarType::FP64) {
-      out->mutable_data<double>(place);
+      out->mutable_data<double>(place);
     } else if (dtype == framework::proto::VarType::BOOL) {
-      out->mutable_data<bool>(place);
+      out->mutable_data<bool>(place);
     }
 
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
 
-    auto runner = NpuOpRunner("Cast", {*x}, {*out}, {{"dst_type", static_cast<int32_t>(aclDtype)}});
+    auto runner = NpuOpRunner("Cast", {*x}, {*out},
+                              {{"dst_type", static_cast<int32_t>(aclDtype)}});
     runner.Run(stream);
   }
 };
 
 }  // namespace operators
-}  // namespace paddleaclDtype
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_NPU_KERNEL(
-    cast,
-    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
+    cast, ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int>,
@@ -88,5 +97,4 @@ REGISTER_OP_NPU_KERNEL(
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext,
-                       paddle::platform::float16>);
-#endif
+                       paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
index 97a10f66576..ae48866b7b9 100755
--- a/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
@@ -50,6 +50,7 @@ class TestCast1(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place,
                                      check_dygraph=False)
 
+
 class TestCast2(OpTest):
     def setUp(self):
         self.set_npu()
         self.op_type = "cast"
@@ -71,5 +72,28 @@ class TestCast2(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
 
+
+class TestCast3(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "cast"
+        self.place = paddle.NPUPlace(0)
+
+        ipt = np.random.random(size=[10, 10]) + 1
+        self.inputs = {'X': ipt.astype('int32')}
+        self.outputs = {'Out': ipt.astype('int32')}
+
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.INT32),
+            'out_dtype': int(core.VarDesc.VarType.INT32)
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
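
For anyone who wants to exercise the case this patch special-cases outside
of the OpTest harness, here is a minimal sketch. It assumes an NPU-enabled
PaddlePaddle build (PADDLE_WITH_ASCEND_CL) with one visible Ascend device;
the 'npu:0' device string and the dygraph round-trip are illustrative
assumptions, not part of the patch.

    # Same-dtype cast: with this patch the NPU kernel takes the TensorCopy
    # branch instead of launching the ACL "Cast" op, which could return
    # wrong values when src dtype == dst dtype.
    import numpy as np
    import paddle

    paddle.set_device('npu:0')  # assumes an NPU build and an attached device

    x = paddle.to_tensor(np.arange(6, dtype='int32').reshape(2, 3))
    y = paddle.cast(x, 'int32')  # out_dtype == in_dtype -> plain copy

    # The output should be identical to the input.
    np.testing.assert_array_equal(y.numpy(), x.numpy())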