From 873ee4e3802bfdf10eb86b1c8ee46aa2523e18dd Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Wed, 20 Oct 2021 14:28:47 +0800 Subject: [PATCH] adapt to cann5.0.3_alpha3. (#36106) --- cmake/external/ascend.cmake | 4 +++- .../operators/collective/c_embedding_op_npu.cc | 14 ++++++++++++++ paddle/fluid/operators/fill_constant_op_npu.cc | 10 ++++++++++ paddle/fluid/operators/lookup_table_v2_op_npu.cc | 3 +++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/cmake/external/ascend.cmake b/cmake/external/ascend.cmake index b643923cdd..03bc7784e9 100644 --- a/cmake/external/ascend.cmake +++ b/cmake/external/ascend.cmake @@ -92,6 +92,8 @@ macro(find_ascend_toolkit_version ascend_toolkit_version_info) file(READ ${ascend_toolkit_version_info} ASCEND_TOOLKIT_VERSION_CONTENTS) string(REGEX MATCH "version=([0-9]+\.[0-9]+\.[0-9]+\.[a-z]*[0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}") string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.[0-9]+\.[a-z]*[0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}") + string(REGEX REPLACE "[a-z|\.]" "" CANN_VERSION ${ASCEND_TOOLKIT_VERSION}) + add_definitions("-DCANN_VERSION_CODE=${CANN_VERSION}") if(NOT ASCEND_TOOLKIT_VERSION) set(ASCEND_TOOLKIT_VERSION "???") else() @@ -118,4 +120,4 @@ endif() find_ascend_toolkit_version(${ASCEND_TOOLKIT_DIR}/ascend_toolkit_install.info) find_ascend_driver_version(${ASCEND_DIR}/driver/version.info) -endif() \ No newline at end of file +endif() diff --git a/paddle/fluid/operators/collective/c_embedding_op_npu.cc b/paddle/fluid/operators/collective/c_embedding_op_npu.cc index c2d6072238..021e5790af 100644 --- a/paddle/fluid/operators/collective/c_embedding_op_npu.cc +++ b/paddle/fluid/operators/collective/c_embedding_op_npu.cc @@ -68,10 +68,21 @@ void shard_index(const Tensor &table_t, const Tensor &ids_t, int64_t start_idx, ignore_tensor.Resize(ids_t.dims()); NpuOpRunner sub_runner; +#if (CANN_VERSION_CODE >= 503003) + Tensor factor_tensor(ids_t.type()); + factor_tensor.mutable_data({1}, context.GetPlace()); + TensorFromVector(std::vector{static_cast(start_idx)}, + context.device_context(), &factor_tensor); + sub_runner.SetType("Sub") + .AddInput(ids_t) + .AddInput(factor_tensor) + .AddOutput(id_t); +#else sub_runner.SetType("Sub") .AddInput(ids_t) .AddInput(std::vector{static_cast(start_idx)}) .AddOutput(id_t); +#endif sub_runner.Run(); NpuOpRunner lessequal1_runner; @@ -137,6 +148,9 @@ void NPUGetIdsEmbedding(const framework::ExecutionContext &context) { .AddInput(table_t_pad) .AddInput(ids_t_local) .AddInput(std::vector{0}) +#if (CANN_VERSION_CODE >= 503003) + .AddAttrs({{"batch_dims", 0}}) +#endif .AddOutput(*output_t); runner.Run(); } diff --git a/paddle/fluid/operators/fill_constant_op_npu.cc b/paddle/fluid/operators/fill_constant_op_npu.cc index ae0148a9bf..16a2433f5c 100644 --- a/paddle/fluid/operators/fill_constant_op_npu.cc +++ b/paddle/fluid/operators/fill_constant_op_npu.cc @@ -66,11 +66,21 @@ class FillConstantNPUKernel : public framework::OpKernel { out_var->mutable_data(shape, ctx.GetPlace()); NpuOpRunner runner; +#if (CANN_VERSION_CODE >= 503003) + runner.SetType("FillD") + .AddInput(tensor_value) + .AddOutput(*out_var) + .AddAttrs( + {{ "dims", + framework::vectorize(shape) }}) + .Run(stream); +#else runner.SetType("Fill") .AddInput(framework::vectorize(shape)) .AddInput(tensor_value) .AddOutput(*out_var) .Run(stream); +#endif } }; } // namespace operators diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 387cd92b69..b75ae8a658 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -40,6 +40,9 @@ class LookupTableV2NPUKernel : public framework::OpKernel { .AddInput(*table_t) .AddInput(*ids_t) .AddInput(std::vector{0}) +#if (CANN_VERSION_CODE >= 503003) + .AddAttrs({{"batch_dims", 0}}) +#endif .AddOutput(*output_t); runner.Run(); } -- GitLab