未验证 提交 873ee4e3 编写于 作者: W wuhuachaocoding 提交者: GitHub

adapt to cann5.0.3_alpha3. (#36106)

上级 605e7f08
...@@ -92,6 +92,8 @@ macro(find_ascend_toolkit_version ascend_toolkit_version_info) ...@@ -92,6 +92,8 @@ macro(find_ascend_toolkit_version ascend_toolkit_version_info)
file(READ ${ascend_toolkit_version_info} ASCEND_TOOLKIT_VERSION_CONTENTS) file(READ ${ascend_toolkit_version_info} ASCEND_TOOLKIT_VERSION_CONTENTS)
string(REGEX MATCH "version=([0-9]+\.[0-9]+\.[0-9]+\.[a-z]*[0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}") string(REGEX MATCH "version=([0-9]+\.[0-9]+\.[0-9]+\.[a-z]*[0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}")
string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.[0-9]+\.[a-z]*[0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}") string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.[0-9]+\.[a-z]*[0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}")
string(REGEX REPLACE "[a-z|\.]" "" CANN_VERSION ${ASCEND_TOOLKIT_VERSION})
add_definitions("-DCANN_VERSION_CODE=${CANN_VERSION}")
if(NOT ASCEND_TOOLKIT_VERSION) if(NOT ASCEND_TOOLKIT_VERSION)
set(ASCEND_TOOLKIT_VERSION "???") set(ASCEND_TOOLKIT_VERSION "???")
else() else()
......
...@@ -68,10 +68,21 @@ void shard_index(const Tensor &table_t, const Tensor &ids_t, int64_t start_idx, ...@@ -68,10 +68,21 @@ void shard_index(const Tensor &table_t, const Tensor &ids_t, int64_t start_idx,
ignore_tensor.Resize(ids_t.dims()); ignore_tensor.Resize(ids_t.dims());
NpuOpRunner sub_runner; NpuOpRunner sub_runner;
#if (CANN_VERSION_CODE >= 503003)
Tensor factor_tensor(ids_t.type());
factor_tensor.mutable_data<T>({1}, context.GetPlace());
TensorFromVector(std::vector<T>{static_cast<T>(start_idx)},
context.device_context(), &factor_tensor);
sub_runner.SetType("Sub")
.AddInput(ids_t)
.AddInput(factor_tensor)
.AddOutput(id_t);
#else
sub_runner.SetType("Sub") sub_runner.SetType("Sub")
.AddInput(ids_t) .AddInput(ids_t)
.AddInput(std::vector<T>{static_cast<T>(start_idx)}) .AddInput(std::vector<T>{static_cast<T>(start_idx)})
.AddOutput(id_t); .AddOutput(id_t);
#endif
sub_runner.Run(); sub_runner.Run();
NpuOpRunner lessequal1_runner; NpuOpRunner lessequal1_runner;
...@@ -137,6 +148,9 @@ void NPUGetIdsEmbedding(const framework::ExecutionContext &context) { ...@@ -137,6 +148,9 @@ void NPUGetIdsEmbedding(const framework::ExecutionContext &context) {
.AddInput(table_t_pad) .AddInput(table_t_pad)
.AddInput(ids_t_local) .AddInput(ids_t_local)
.AddInput(std::vector<int32_t>{0}) .AddInput(std::vector<int32_t>{0})
#if (CANN_VERSION_CODE >= 503003)
.AddAttrs({{"batch_dims", 0}})
#endif
.AddOutput(*output_t); .AddOutput(*output_t);
runner.Run(); runner.Run();
} }
......
...@@ -66,11 +66,21 @@ class FillConstantNPUKernel : public framework::OpKernel<T> { ...@@ -66,11 +66,21 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
out_var->mutable_data<T>(shape, ctx.GetPlace()); out_var->mutable_data<T>(shape, ctx.GetPlace());
NpuOpRunner runner; NpuOpRunner runner;
#if (CANN_VERSION_CODE >= 503003)
runner.SetType("FillD")
.AddInput(tensor_value)
.AddOutput(*out_var)
.AddAttrs(
{{ "dims",
framework::vectorize(shape) }})
.Run(stream);
#else
runner.SetType("Fill") runner.SetType("Fill")
.AddInput(framework::vectorize(shape)) .AddInput(framework::vectorize(shape))
.AddInput(tensor_value) .AddInput(tensor_value)
.AddOutput(*out_var) .AddOutput(*out_var)
.Run(stream); .Run(stream);
#endif
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -40,6 +40,9 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> { ...@@ -40,6 +40,9 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
.AddInput(*table_t) .AddInput(*table_t)
.AddInput(*ids_t) .AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0}) .AddInput(std::vector<int32_t>{0})
#if (CANN_VERSION_CODE >= 503003)
.AddAttrs({{"batch_dims", 0}})
#endif
.AddOutput(*output_t); .AddOutput(*output_t);
runner.Run(); runner.Run();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册