Unverified commit 78959a39, authored by Leo Chen, committed by GitHub

[NPU] fix cast op (#32121)

* fix NPU kernel of cast op to handle casting to the same dtype

* add comments
Parent 4638fe9a
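The user-visible scenario this commit fixes, as a minimal sketch (it assumes a PaddlePaddle build with Ascend NPU support and an available NPU device; the shape and values are only illustrative):

import numpy as np
import paddle

paddle.set_device('npu:0')

x = paddle.to_tensor(np.random.randint(0, 100, size=[10, 10]).astype('int32'))
# Casting a tensor to its own dtype used to be sent to the NPU Cast op, which
# could return wrong values; with this fix the kernel performs a plain tensor
# copy instead, so the output must equal the input.
y = paddle.cast(x, 'int32')
np.testing.assert_array_equal(x.numpy(), y.numpy())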
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
@@ -41,46 +40,56 @@ class CastNPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
int dtype = ctx.Attr<int>("out_dtype");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
if (x->type() == dtype) {
// NOTE(zhiqiu): NPU cast op may result in wrong value, so
// add special case here.
VLOG(4) << "cast to same dtype:" << dtype;
out->mutable_data(place, x->type());
framework::TensorCopy(
*x, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
return;
}
auto iter = DTYPE_2_ACL_DTYPE.find(
static_cast<framework::proto::VarType::Type>(dtype));
int aclDtype = iter->second;
if (dtype == framework::proto::VarType::FP32) {
out->mutable_data<float>(place);
} else if (dtype == framework::proto::VarType::FP16) {
out->mutable_data<paddle::platform::float16>(place);
} else if (dtype == framework::proto::VarType::INT16) {
out->mutable_data<int16_t>(place);
} else if (dtype == framework::proto::VarType::INT32) {
out->mutable_data<int32_t>(place);
} else if (dtype == framework::proto::VarType::INT64) {
out->mutable_data<int64_t>(place);
} else if (dtype == framework::proto::VarType::FP64) {
out->mutable_data<double>(place);
} else if (dtype == framework::proto::VarType::BOOL) {
out->mutable_data<bool>(place);
}
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
auto runner = NpuOpRunner("Cast", {*x}, {*out},
{{"dst_type", static_cast<int32_t>(aclDtype)}});
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
cast, ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int16_t>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int32_t>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, int>,
@@ -88,5 +97,4 @@ REGISTER_OP_NPU_KERNEL(
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::CastNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
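A rough Python-side sketch of the control flow in the kernel above (an illustration only, not Paddle internals; numpy arrays stand in for NPU tensors and cast_npu_like is a made-up helper name):

import numpy as np

def cast_npu_like(x, out_dtype):
    if x.dtype == out_dtype:
        # Same-dtype case: the NPU Cast op may produce wrong values here,
        # so the fixed kernel simply copies the input tensor to the output.
        return x.copy()
    # Different dtype: the real kernel looks up the matching ACL dtype,
    # allocates the output buffer, and runs the ACL "Cast" operator on the
    # NPU stream; astype stands in for that step here.
    return x.astype(out_dtype)

x = np.arange(6, dtype=np.int32).reshape(2, 3)
assert np.array_equal(cast_npu_like(x, np.int32), x)       # copy path
assert cast_npu_like(x, np.float16).dtype == np.float16    # cast path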
@@ -50,6 +50,7 @@ class TestCast1(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
class TestCast2(OpTest):
def setUp(self):
self.set_npu()
@@ -71,5 +72,28 @@ class TestCast2(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
class TestCast3(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "cast"
self.place = paddle.NPUPlace(0)
ipt = np.random.random(size=[10, 10]) + 1
self.inputs = {'X': ipt.astype('int32')}
self.outputs = {'Out': ipt.astype('int32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.INT32),
'out_dtype': int(core.VarDesc.VarType.INT32)
}
def set_npu(self):
self.__class__.use_npu = True
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False, atol=1e-3)
if __name__ == '__main__':
unittest.main()
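To exercise only the new same-dtype case, a hedged example (it assumes the file above is saved as test_cast_op_npu.py, importable from the working directory, and that an NPU build of Paddle is installed):

import unittest
from test_cast_op_npu import TestCast3  # module name assumed from the test file

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestCast3)
unittest.TextTestRunner(verbosity=2).run(suite)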