diff --git a/paddle/fluid/operators/cast_op_npu.cc b/paddle/fluid/operators/cast_op_npu.cc
old mode 100755
new mode 100644
index 282ac6c1f6306f1c453c6ec6d082345a452d6186..0de0f5e4505795f69f1d80e2bbc1600250fc7391
--- a/paddle/fluid/operators/cast_op_npu.cc
+++ b/paddle/fluid/operators/cast_op_npu.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifdef PADDLE_WITH_ASCEND_CL
 #include <memory>
 #include <string>
 
@@ -41,46 +40,56 @@ class CastNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<Tensor>("X");
     int dtype = ctx.Attr<int>("out_dtype");
-    auto* out = ctx.Output<Tensor>("Out");
-    auto place = ctx.GetPlace();
-
-    auto iter = DTYPE_2_ACL_DTYPE.find(static_cast<framework::proto::VarType::Type>(dtype));
+
+    if (x->type() == dtype) {
+      // NOTE(zhiqiu): NPU cast op may result in wrong value, so
+      // add special case here.
+      VLOG(4) << "cast to same dtype:" << dtype;
+      out->mutable_data(place, x->type());
+      framework::TensorCopy(
+          *x, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), out);
+      return;
+    }
+
+    auto iter = DTYPE_2_ACL_DTYPE.find(
+        static_cast<framework::proto::VarType::Type>(dtype));
     int aclDtype = iter->second;
 
     if (dtype == framework::proto::VarType::FP32) {
-      out->mutable_data<float>(place);
+      out->mutable_data<float>(place);
     } else if (dtype == framework::proto::VarType::FP16) {
-      out->mutable_data<paddle::platform::float16>(place);
+      out->mutable_data<paddle::platform::float16>(place);
     } else if (dtype == framework::proto::VarType::INT16) {
-      out->mutable_data<int16_t>(place);
+      out->mutable_data<int16_t>(place);
     } else if (dtype == framework::proto::VarType::INT32) {
-      out->mutable_data<int32_t>(place);
+      out->mutable_data<int32_t>(place);
     } else if (dtype == framework::proto::VarType::INT64) {
-      out->mutable_data<int64_t>(place);
+      out->mutable_data<int64_t>(place);
    } else if (dtype == framework::proto::VarType::FP64) {
-      out->mutable_data<double>(place);
+      out->mutable_data<double>(place);
     } else if (dtype == framework::proto::VarType::BOOL) {
-      out->mutable_data<bool>(place);
+      out->mutable_data<bool>(place);
     }
 
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
 
-    auto runner = NpuOpRunner("Cast", {*x}, {*out}, {{"dst_type", static_cast<int32_t>(aclDtype)}});
+    auto runner = NpuOpRunner("Cast", {*x}, {*out},
+                              {{"dst_type", static_cast<int32_t>(aclDtype)}});
     runner.Run(stream);
   }
 };
 
 }  // namespace operators
-}  // namespace paddleaclDtype
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP_NPU_KERNEL(
-    cast,
-    ops::CastNPUKernel,
+    cast, ops::CastNPUKernel,
     ops::CastNPUKernel,
     ops::CastNPUKernel,
     ops::CastNPUKernel,
@@ -88,5 +97,4 @@ REGISTER_OP_NPU_KERNEL(
     ops::CastNPUKernel,
     ops::CastNPUKernel,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext,
-                       paddle::platform::float16>);
-#endif
+                       paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
index 97a10f6657617dcfd6c57d5bc9b9ca2b142f9e7e..ae48866b7b969d5e7f2d7bf0dc9ed93c46aed4bb 100755
--- a/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_cast_op_npu.py
@@ -50,6 +50,7 @@ class TestCast1(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False)
 
+
 class TestCast2(OpTest):
     def setUp(self):
         self.set_npu()
@@ -71,5 +72,28 @@ class TestCast2(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
 
+
+class TestCast3(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "cast"
+        self.place = paddle.NPUPlace(0)
+
+        ipt = np.random.random(size=[10, 10]) + 1
+        self.inputs = {'X': ipt.astype('int32')}
+        self.outputs = {'Out': ipt.astype('int32')}
+
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.INT32),
+            'out_dtype': int(core.VarDesc.VarType.INT32)
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
+
+
 if __name__ == '__main__':
     unittest.main()
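For context, a minimal usage sketch of the behaviour the new same-dtype branch and the added TestCast3 case cover: casting an int32 tensor to int32 on an NPU should now return an unchanged copy (via TensorCopy) instead of going through the ACL Cast operator. This sketch is not part of the diff; it assumes a Paddle build with Ascend NPU support (PADDLE_WITH_ASCEND_CL) and an NPU device at index 0.

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[10, 10], dtype='int32')
    # Source and destination dtype are identical, so the NPU kernel takes the
    # new TensorCopy path rather than calling the ACL "Cast" operator.
    y = paddle.cast(x, 'int32')

exe = paddle.static.Executor(paddle.NPUPlace(0))
exe.run(startup_prog)
ipt = (np.random.random(size=[10, 10]) + 1).astype('int32')
out, = exe.run(main_prog, feed={'x': ipt}, fetch_list=[y])
assert np.array_equal(out, ipt)  # values should be preserved exactly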