未验证 提交 b4eb413e 编写于 作者: Z zn 提交者: GitHub

[MLU]support reduce tensors on mlu (#40000)

* [MLU]support reduce tensors on mlu

* [MLU]fix compiler options
上级 0ad25fb9
...@@ -33,6 +33,7 @@ if(NOT WIN32) ...@@ -33,6 +33,7 @@ if(NOT WIN32)
endif() endif()
if(WITH_CNCL) if(WITH_CNCL)
cc_library(cncl_context SRCS cncl_context.cc DEPS collective_helper device_context tensor var_type_traits) cc_library(cncl_context SRCS cncl_context.cc DEPS collective_helper device_context tensor var_type_traits)
cc_library(reducer SRCS reducer.cc DEPS layer)
endif() endif()
if(WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL) if(WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL)
cc_library(heter_ccl_context SRCS heter_ccl_context.cc DEPS collective_helper device_context tensor var_type_traits) cc_library(heter_ccl_context SRCS heter_ccl_context.cc DEPS collective_helper device_context tensor var_type_traits)
...@@ -41,7 +42,7 @@ if(NOT WIN32) ...@@ -41,7 +42,7 @@ if(NOT WIN32)
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_GLOO) if(WITH_GLOO)
cc_library(imperative_gloo_context SRCS gloo_context.cc DEPS collective_helper device_context tensor var_type_traits) cc_library(imperative_gloo_context SRCS gloo_context.cc DEPS collective_helper device_context tensor var_type_traits)
if ( WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL) )) if ( WIN32 OR (NOT (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_ASCEND_CL OR WITH_CNCL) ))
cc_library(reducer SRCS reducer.cc DEPS layer) cc_library(reducer SRCS reducer.cc DEPS layer)
endif() endif()
endif() endif()
......
...@@ -31,7 +31,7 @@ namespace imperative { ...@@ -31,7 +31,7 @@ namespace imperative {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \ defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \
defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_CNCL)
// div the nranks // div the nranks
void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) { void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
framework::Tensor *tensor = framework::Tensor *tensor =
...@@ -67,6 +67,9 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) { ...@@ -67,6 +67,9 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
#ifdef PADDLE_WITH_XPU_BKCL #ifdef PADDLE_WITH_XPU_BKCL
// TODO(liuyuhui) support xpu about div nranks in the future // TODO(liuyuhui) support xpu about div nranks in the future
#endif #endif
} else if (platform::is_mlu_place(tensor->place())) {
// TODO(zhangna)
VLOG(4) << "divnrank for mlu not support yet";
} }
} }
...@@ -222,6 +225,56 @@ void SplitTensorsWithType<platform::XPUDeviceContext>( ...@@ -222,6 +225,56 @@ void SplitTensorsWithType<platform::XPUDeviceContext>(
} }
#endif #endif
#ifdef PADDLE_WITH_CNCL
// context is used to select the stream for concat
template <>
void ConcatTensorsWithType<platform::MLUDeviceContext>(
const platform::MLUDeviceContext &context,
const std::vector<framework::Tensor> &dense_tensors_,
framework::Variable *p_dense_contents,
framework::proto::VarType::Type type) {
switch (type) {
case framework::proto::VarType::FP16:
ConcatTensorsForAllReduce<platform::MLUDeviceContext, platform::float16>(
context, dense_tensors_, p_dense_contents);
break;
case framework::proto::VarType::FP32:
ConcatTensorsForAllReduce<platform::MLUDeviceContext, float>(
context, dense_tensors_, p_dense_contents);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it concats tensors for "
"allreduce.",
framework::DataTypeToString(type)));
}
}
// context is used to select the stream for split
template <>
void SplitTensorsWithType<platform::MLUDeviceContext>(
const platform::MLUDeviceContext &context,
framework::Variable *p_dense_contents,
std::vector<framework::Tensor> *p_dense_tensors,
framework::proto::VarType::Type type) {
switch (type) {
case framework::proto::VarType::FP16:
SplitTensorsForAllReduce<platform::MLUDeviceContext, platform::float16>(
context, p_dense_contents, p_dense_tensors);
break;
case framework::proto::VarType::FP32:
SplitTensorsForAllReduce<platform::MLUDeviceContext, float>(
context, p_dense_contents, p_dense_tensors);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it splits tensors for "
"allreduce.",
framework::DataTypeToString(type)));
}
}
#endif
void Group::ConcatTensors(const platform::DeviceContext &context) { void Group::ConcatTensors(const platform::DeviceContext &context) {
auto place = context.GetPlace(); auto place = context.GetPlace();
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
...@@ -253,6 +306,16 @@ void Group::ConcatTensors(const platform::DeviceContext &context) { ...@@ -253,6 +306,16 @@ void Group::ConcatTensors(const platform::DeviceContext &context) {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat npu grads since it's not compiled with HCCL," "Paddle can't concat npu grads since it's not compiled with HCCL,"
"Please recompile or reinstall Paddle with HCCL support.")); "Please recompile or reinstall Paddle with HCCL support."));
#endif
} else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_CNCL
ConcatTensorsWithType(
static_cast<const platform::MLUDeviceContext &>(context),
dense_tensors_, &dense_contents_, dtype_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat mlu grads since it's not compiled with CNCL,"
"Please recompile or reinstall Paddle with CNCL support."));
#endif #endif
} else if (platform::is_cpu_place(place)) { } else if (platform::is_cpu_place(place)) {
ConcatTensorsWithType( ConcatTensorsWithType(
...@@ -295,6 +358,16 @@ void Group::SplitTensors(const platform::DeviceContext &context) { ...@@ -295,6 +358,16 @@ void Group::SplitTensors(const platform::DeviceContext &context) {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split npu grad since it's not compiled with HCCL," "Paddle can't split npu grad since it's not compiled with HCCL,"
"Please recompile or reinstall Paddle with HCCL support.")); "Please recompile or reinstall Paddle with HCCL support."));
#endif
} else if (platform::is_mlu_place(place)) {
#ifdef PADDLE_WITH_CNCL
SplitTensorsWithType(
static_cast<const platform::MLUDeviceContext &>(context),
&dense_contents_, &dense_tensors_, dtype_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split mlu grad since it's not compiled with CNCL,"
"Please recompile or reinstall Paddle with CNCL support."));
#endif #endif
} else if (platform::is_cpu_place(place)) { } else if (platform::is_cpu_place(place)) {
SplitTensorsWithType( SplitTensorsWithType(
...@@ -746,6 +819,11 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) { ...@@ -746,6 +819,11 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
// TODO(liuyuhui) support XPU set constant // TODO(liuyuhui) support XPU set constant
VLOG(3) << "XPU doesn't support set_constant"; VLOG(3) << "XPU doesn't support set_constant";
} }
#elif defined(PADDLE_WITH_CNCL)
if (platform::is_mlu_place(group_tensor.place())) {
// TODO(liuyuhui) support MLU set constant
VLOG(3) << "MLU doesn't support set_constant";
}
#else #else
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place_);
if (HasGrad(var_index)) { if (HasGrad(var_index)) {
...@@ -847,11 +925,12 @@ void Reducer::MarkGroupReady(size_t group_index) { ...@@ -847,11 +925,12 @@ void Reducer::MarkGroupReady(size_t group_index) {
} }
}); });
#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \ #elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
defined(PADDLE_WITH_CNCL)
FusedAllReduceSchedule(run_order, group, next_group_); FusedAllReduceSchedule(run_order, group, next_group_);
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"Not compiled with BKCL or NCCL or GLOO.")); "Not compiled with BKCL or NCCL or CNCL or GLOO."));
#endif #endif
} }
} }
......
...@@ -45,7 +45,7 @@ namespace imperative { ...@@ -45,7 +45,7 @@ namespace imperative {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \ defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \
defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_CNCL)
template <typename T> template <typename T>
struct DivNRanksFunctor { struct DivNRanksFunctor {
......
...@@ -21,6 +21,6 @@ cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info s ...@@ -21,6 +21,6 @@ cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info s
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy) cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
cc_test(test_eager SRCS test_eager.cc DEPS tracer layer prepared_operator mul_op) cc_test(test_eager SRCS test_eager.cc DEPS tracer layer prepared_operator mul_op)
if (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL) if (WITH_NCCL OR WITH_RCCL OR WITH_XPU_BKCL OR WITH_CNCL)
cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy) cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
endif() endif()
...@@ -72,8 +72,10 @@ void GroupConcatSplit(Place place, size_t size) { ...@@ -72,8 +72,10 @@ void GroupConcatSplit(Place place, size_t size) {
value.push_back(static_cast<T>(1.0 * j)); value.push_back(static_cast<T>(1.0 * j));
} }
if (std::is_same<Place, platform::CUDAPlace>::value) { if (std::is_same<Place, platform::CUDAPlace>::value ||
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) std::is_same<Place, platform::MLUPlace>::value) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_CNCL)
paddle::memory::Copy(place, data, cpu_place, value.data(), paddle::memory::Copy(place, data, cpu_place, value.data(),
sizeof(T) * value.size(), 0); sizeof(T) * value.size(), 0);
#endif #endif
...@@ -180,5 +182,19 @@ TEST(TestGroup, TestXPUConcatSplit) { ...@@ -180,5 +182,19 @@ TEST(TestGroup, TestXPUConcatSplit) {
} }
#endif #endif
#if defined(PADDLE_WITH_CNCL)
TEST(TestGroup, TestMLUConcatSplit) {
platform::MLUPlace mlu_place(0);
platform::CPUPlace cpu_place;
int size = 3;
GroupConcatSplit<float>(cpu_place, size);
GroupConcatSplit<float>(mlu_place, size);
size = 15;
GroupConcatSplit<float>(cpu_place, size);
GroupConcatSplit<float>(mlu_place, size);
}
#endif
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -5,6 +5,8 @@ endif() ...@@ -5,6 +5,8 @@ endif()
# please add new math_library in alphabetical order # please add new math_library in alphabetical order
if (WITH_ASCEND_CL) if (WITH_ASCEND_CL)
math_library(concat_and_split DEPS concat_and_split_functor npu_op_runner) math_library(concat_and_split DEPS concat_and_split_functor npu_op_runner)
elseif (WITH_MLU)
math_library(concat_and_split DEPS concat_and_split_functor mlu_baseop)
else() else()
math_library(concat_and_split DEPS concat_and_split_functor) math_library(concat_and_split DEPS concat_and_split_functor)
endif() endif()
......
...@@ -18,6 +18,9 @@ limitations under the License. */ ...@@ -18,6 +18,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif #endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
...@@ -226,6 +229,90 @@ class SplitFunctor<platform::NPUDeviceContext, T> { ...@@ -226,6 +229,90 @@ class SplitFunctor<platform::NPUDeviceContext, T> {
}; };
#endif #endif
#ifdef PADDLE_WITH_MLU
template <typename T>
class ConcatFunctor<platform::MLUDeviceContext, T> {
public:
void operator()(const platform::MLUDeviceContext& context,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
int dev_id = context.GetPlace().GetDeviceId();
platform::MLUDeviceGuard guard(dev_id);
auto ins_size = input.size();
const int axis_t = axis;
const int ins_size_t = ins_size;
auto place = context.GetPlace();
output->mutable_data<T>(place);
// mlu should do sth
// init ins tensors
std::vector<const void*> inputs;
std::vector<MLUCnnlTensorDesc> input_descs;
std::vector<cnnlTensorDescriptor_t> desc_vector;
for (size_t i = 0; i < ins_size; i++) {
input_descs.emplace_back(MLUCnnlTensorDesc(
input[i], CNNL_LAYOUT_ARRAY, ToCnnlDataType(input[i].dtype())));
desc_vector.push_back(input_descs.back().get());
inputs.push_back(input[i].data());
}
// init out tensors
MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(output->dtype()));
// MLU should do sth
MLUCnnl::Concat(context, ins_size_t, axis_t, desc_vector.data(),
inputs.data(), output_desc.get(), GetBasePtr(output));
}
};
template <typename T>
class SplitFunctor<platform::MLUDeviceContext, T> {
public:
void operator()(const platform::MLUDeviceContext& context,
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
if (input.numel() == 0) {
return;
}
int dev_id = context.GetPlace().GetDeviceId();
platform::MLUDeviceGuard guard(dev_id);
auto in_dims = input.dims();
auto out_size = outputs->size();
std::vector<framework::DDim> outs_dims(out_size, in_dims);
for (size_t i = 0; i < out_size; ++i) {
outs_dims[i][axis] = ref_inputs[i]->dims()[axis];
}
// init out tensors
std::vector<void*> vct_tensor;
std::vector<MLUCnnlTensorDesc> output_descs;
std::vector<cnnlTensorDescriptor_t> desc_vector;
for (size_t i = 0; i < out_size; i++) {
(*outputs)[i]->Resize(outs_dims[i]);
(*outputs)[i]->mutable_data<T>(context.GetPlace());
output_descs.emplace_back(
MLUCnnlTensorDesc(*(*outputs)[i], CNNL_LAYOUT_ARRAY,
ToCnnlDataType((*outputs)[i]->dtype())));
desc_vector.push_back(output_descs.back().get());
vct_tensor.push_back(GetBasePtr((*outputs)[i]));
}
// init in tensors
MLUCnnlTensorDesc input_desc(input, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(input.dtype()));
// MLU should do sth
MLUCnnl::Split(context, out_size, axis, input_desc.get(), input.data(),
desc_vector.data(), vct_tensor.data());
}
};
#endif
#define DEFINE_FUNCTOR(type) \ #define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CPUDeviceContext, type>; \ template class ConcatFunctor<platform::CPUDeviceContext, type>; \
template class SplitFunctor<platform::CPUDeviceContext, type>; template class SplitFunctor<platform::CPUDeviceContext, type>;
...@@ -248,6 +335,19 @@ DEFINE_XPU_FUNCTOR(float) ...@@ -248,6 +335,19 @@ DEFINE_XPU_FUNCTOR(float)
FOR_ALL_TYPES(DEFINE_NPU_FUNCTOR) FOR_ALL_TYPES(DEFINE_NPU_FUNCTOR)
#endif #endif
#ifdef PADDLE_WITH_MLU
#define DEFINE_MLU_FUNCTOR(type) \
template class ConcatFunctor<platform::MLUDeviceContext, type>; \
template class SplitFunctor<platform::MLUDeviceContext, type>;
DEFINE_MLU_FUNCTOR(float)
DEFINE_MLU_FUNCTOR(platform::float16)
DEFINE_MLU_FUNCTOR(int64_t)
DEFINE_MLU_FUNCTOR(bool)
DEFINE_MLU_FUNCTOR(int)
DEFINE_MLU_FUNCTOR(int8_t)
DEFINE_MLU_FUNCTOR(int16_t)
DEFINE_MLU_FUNCTOR(uint8_t)
#endif
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -499,6 +499,27 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { ...@@ -499,6 +499,27 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
output_desc, output)); output_desc, output));
} }
/* static */ void MLUCnnl::Concat(const MLUDeviceContext& dev_ctx,
const int pack_num, const int axis,
const cnnlTensorDescriptor_t inputs_desc[],
const void* const inputs[],
const cnnlTensorDescriptor_t output_desc,
void* output) {
cnnlHandle_t handle = dev_ctx.cnnl_handle();
size_t workspace_size = 0;
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlGetConcatWorkspaceSize(handle, pack_num, &workspace_size));
Tensor workspace(paddle::experimental::DataType::INT8);
workspace.Resize(framework::DDim({static_cast<int64_t>(workspace_size)}));
void* workspace_ptr = workspace.mutable_data(dev_ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlConcat(handle, pack_num, axis, inputs_desc,
inputs, workspace_ptr, workspace_size,
output_desc, output));
}
/* static */ void MLUCnnl::Div( /* static */ void MLUCnnl::Div(
const ExecutionContext& ctx, cnnlComputationPreference_t prefer, const ExecutionContext& ctx, cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t in0_desc, const void* in0, const cnnlTensorDescriptor_t in0_desc, const void* in0,
...@@ -977,6 +998,27 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { ...@@ -977,6 +998,27 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
output_descs, output_ptrs)); output_descs, output_ptrs));
} }
/* static */ void MLUCnnl::Split(const MLUDeviceContext& dev_ctx, int split_num,
int axis,
const cnnlTensorDescriptor_t input_desc,
const void* input_ptr,
const cnnlTensorDescriptor_t output_descs[],
void* output_ptrs[]) {
cnnlHandle_t handle = dev_ctx.cnnl_handle();
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlGetSplitWorkspaceSize(handle, split_num, &workspace_size));
Tensor workspace(paddle::experimental::DataType::INT8);
workspace.Resize(framework::DDim({static_cast<int64_t>(workspace_size)}));
void* workspace_ptr = workspace.mutable_data(dev_ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSplit(handle, split_num, axis, input_desc,
input_ptr, workspace_ptr, workspace_size,
output_descs, output_ptrs));
}
/* static */ void MLUCnnl::GatherFunctor( /* static */ void MLUCnnl::GatherFunctor(
const ExecutionContext& ctx, const int axis, const int batch_dims, const ExecutionContext& ctx, const int axis, const int batch_dims,
const cnnlTensorDescriptor_t params_desc, const void* params, const cnnlTensorDescriptor_t params_desc, const void* params,
......
...@@ -403,6 +403,11 @@ class MLUCnnl { ...@@ -403,6 +403,11 @@ class MLUCnnl {
const void* const inputs[], const void* const inputs[],
const cnnlTensorDescriptor_t output_desc, void* output); const cnnlTensorDescriptor_t output_desc, void* output);
static void Concat(const MLUDeviceContext& dev_ctx, const int pack_num,
const int axis, const cnnlTensorDescriptor_t inputs_desc[],
const void* const inputs[],
const cnnlTensorDescriptor_t output_desc, void* output);
static void Cast(const ExecutionContext& ctx, cnnlCastDataType_t cast_type, static void Cast(const ExecutionContext& ctx, cnnlCastDataType_t cast_type,
const cnnlTensorDescriptor_t input_desc, const void* input, const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t output_desc, void* output); const cnnlTensorDescriptor_t output_desc, void* output);
...@@ -566,6 +571,12 @@ class MLUCnnl { ...@@ -566,6 +571,12 @@ class MLUCnnl {
const cnnlTensorDescriptor_t output_descs[], const cnnlTensorDescriptor_t output_descs[],
void* output_ptrs[]); void* output_ptrs[]);
static void Split(const MLUDeviceContext& dev_ctx, int split_num, int axis,
const cnnlTensorDescriptor_t input_desc,
const void* input_ptr,
const cnnlTensorDescriptor_t output_descs[],
void* output_ptrs[]);
static void Scale(const ExecutionContext& ctx, const int axis, static void Scale(const ExecutionContext& ctx, const int axis,
const cnnlTensorDescriptor_t input_desc, const void* input, const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t alpha_desc, const void* alpha, const cnnlTensorDescriptor_t alpha_desc, const void* alpha,
......
...@@ -109,6 +109,11 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, ...@@ -109,6 +109,11 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
auto& npu_ctx = reinterpret_cast<const platform::NPUDeviceContext&>(ctx); auto& npu_ctx = reinterpret_cast<const platform::NPUDeviceContext&>(ctx);
memory::Copy(npu_place, dst + i * dst_after, npu_place, memory::Copy(npu_place, dst + i * dst_after, npu_place,
src + i * src_after, sizeof(T) * size, npu_ctx.stream()); src + i * src_after, sizeof(T) * size, npu_ctx.stream());
#elif defined(PADDLE_WITH_MLU)
auto& mlu_place = place;
auto& mlu_ctx = reinterpret_cast<const platform::MLUDeviceContext&>(ctx);
memory::Copy(mlu_place, dst + i * dst_after, mlu_place,
src + i * src_after, sizeof(T) * size, mlu_ctx.stream());
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"Paddle is not compiled with GPU.")); "Paddle is not compiled with GPU."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册