Unverified commit 78959a39, authored by Leo Chen, committed by GitHub

[NPU] fix cast op (#32121)

* fix npu kernel of cast op to handle casting to same dtype

* add comments
Parent 4638fe9a
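For context, the case this commit special-cases is a cast whose out_dtype equals the input dtype. Below is a minimal repro sketch in Python; it is hypothetical usage, not part of the commit, and assumes an Ascend (NPU) build of Paddle where paddle.set_device('npu:0') is available.

import numpy as np
import paddle

# Hypothetical repro (assumes an NPU build with at least one Ascend device).
# Casting a tensor to the dtype it already has used to go through the ACL
# Cast op, which could produce wrong values; with this fix the kernel simply
# copies the tensor.
paddle.set_device('npu:0')
x = paddle.to_tensor(np.arange(6, dtype='int32').reshape([2, 3]))
y = paddle.cast(x, 'int32')  # destination dtype == source dtype
print(np.array_equal(y.numpy(), x.numpy()))  # expected: True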
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef PADDLE_WITH_ASCEND_CL
 #include <memory>
 #include <string>
@@ -41,46 +40,56 @@ class CastNPUKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* x = ctx.Input<Tensor>("X");
     int dtype = ctx.Attr<int>("out_dtype");
     auto* out = ctx.Output<Tensor>("Out");
     auto place = ctx.GetPlace();
-    auto iter = DTYPE_2_ACL_DTYPE.find(static_cast<framework::proto::VarType::Type>(dtype));
+    if (x->type() == dtype) {
+      // NOTE(zhiqiu): NPU cast op may result in wrong value, so
+      // add special case here.
+      VLOG(4) << "cast to same dtype:" << dtype;
+      out->mutable_data(place, x->type());
+      framework::TensorCopy(
+          *x, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), out);
+      return;
+    }
+
+    auto iter = DTYPE_2_ACL_DTYPE.find(
+        static_cast<framework::proto::VarType::Type>(dtype));
     int aclDtype = iter->second;
     if (dtype == framework::proto::VarType::FP32) {
       out->mutable_data<float>(place);
     } else if (dtype == framework::proto::VarType::FP16) {
       out->mutable_data<paddle::platform::float16>(place);
     } else if (dtype == framework::proto::VarType::INT16) {
       out->mutable_data<int16_t>(place);
     } else if (dtype == framework::proto::VarType::INT32) {
       out->mutable_data<int32_t>(place);
     } else if (dtype == framework::proto::VarType::INT64) {
       out->mutable_data<int64_t>(place);
     } else if (dtype == framework::proto::VarType::FP64) {
       out->mutable_data<double>(place);
     } else if (dtype == framework::proto::VarType::BOOL) {
       out->mutable_data<bool>(place);
     }
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
-    auto runner = NpuOpRunner("Cast", {*x}, {*out}, {{"dst_type", static_cast<int32_t>(aclDtype)}});
+    auto runner = NpuOpRunner("Cast", {*x}, {*out},
+                              {{"dst_type", static_cast<int32_t>(aclDtype)}});
     runner.Run(stream);
   }
 };
 }  // namespace operators
-}  // namespace paddleaclDtype
+}  // namespace paddle
 namespace ops = paddle::operators;
 REGISTER_OP_NPU_KERNEL(
-    cast,
-    ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
+    cast, ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int>,
@@ -88,5 +97,4 @@ REGISTER_OP_NPU_KERNEL(
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::CastNPUKernel<paddle::platform::NPUDeviceContext,
                        paddle::platform::float16>);
-#endif
@@ -50,6 +50,7 @@ class TestCast1(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False)

+
 class TestCast2(OpTest):
     def setUp(self):
         self.set_npu()
@@ -71,5 +72,28 @@ class TestCast2(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
+
+
+class TestCast3(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "cast"
+        self.place = paddle.NPUPlace(0)
+        ipt = np.random.random(size=[10, 10]) + 1
+        self.inputs = {'X': ipt.astype('int32')}
+        self.outputs = {'Out': ipt.astype('int32')}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.INT32),
+            'out_dtype': int(core.VarDesc.VarType.INT32)
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
+
+
 if __name__ == '__main__':
     unittest.main()
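The new TestCast3 exercises the same-dtype branch for int32. A further case for float16 could follow the same pattern; the class below is a hypothetical sketch, not part of this commit, and assumes the scaffolding the test file above already uses (numpy as np, paddle, paddle.fluid.core as core, and the OpTest base class).

class TestCastFP16ToFP16(OpTest):
    # Hypothetical extra case mirroring TestCast3: float16 -> float16 should
    # also take the same-dtype copy branch added by this commit.
    def setUp(self):
        self.set_npu()
        self.op_type = "cast"
        self.place = paddle.NPUPlace(0)
        ipt = np.random.random(size=[10, 10]).astype('float16')
        self.inputs = {'X': ipt}
        self.outputs = {'Out': ipt}
        self.attrs = {
            'in_dtype': int(core.VarDesc.VarType.FP16),
            'out_dtype': int(core.VarDesc.VarType.FP16)
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)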