未验证 提交 affe25b7 编写于 作者: C Chenxiao Niu 提交者: GitHub

add mlu interp_v2(nearest&bilinear). (#43383)

上级 31ddaae2
...@@ -38,7 +38,8 @@ inline std::vector<int> get_new_shape( ...@@ -38,7 +38,8 @@ inline std::vector<int> get_new_shape(
"The shape of dimension tensor should be [1]," "The shape of dimension tensor should be [1],"
"but received d%.", "but received d%.",
tensor->dims())); tensor->dims()));
if (platform::is_gpu_place(tensor->place())) { if (platform::is_gpu_place(tensor->place()) ||
platform::is_mlu_place(tensor->place())) {
framework::Tensor temp; framework::Tensor temp;
paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>())); vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
...@@ -55,7 +56,8 @@ inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) { ...@@ -55,7 +56,8 @@ inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
std::vector<T> vec_new_data; std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>(); auto* new_data = new_data_tensor->data<T>();
framework::Tensor cpu_starts_tensor; framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) { if (platform::is_gpu_place(new_data_tensor->place()) ||
platform::is_mlu_place(new_data_tensor->place())) {
paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(), paddle::framework::TensorCopySync(*new_data_tensor, platform::CPUPlace(),
&cpu_starts_tensor); &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>(); new_data = cpu_starts_tensor.data<T>();
......
此差异已折叠。
...@@ -1925,9 +1925,9 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { ...@@ -1925,9 +1925,9 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
const cnnlTensorDescriptor_t output_desc, void* output) { const cnnlTensorDescriptor_t output_desc, void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx); cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS( PADDLE_ENFORCE_MLU_SUCCESS(cnnlInterpBackward_v2(
cnnlInterpBackward(handle, align_corners, half_pixel_centers, mode, handle, align_corners, half_pixel_centers, mode, NULL, true, input_desc,
input_desc, input, output_desc, output)); input, output_desc, output));
} }
/* static */ void MLUCnnl::Cast(const ExecutionContext& ctx, /* static */ void MLUCnnl::Cast(const ExecutionContext& ctx,
......
...@@ -41,6 +41,20 @@ const std::map<std::string, cnnlReduceOp_t> MLUReduceOpMap = { ...@@ -41,6 +41,20 @@ const std::map<std::string, cnnlReduceOp_t> MLUReduceOpMap = {
{"reduce_prod", CNNL_REDUCE_MUL}, {"reduce_prod", CNNL_REDUCE_MUL},
}; };
// Maps Paddle interpolation-mode strings (the `interp_method` attribute of
// interp_v2 ops) to the corresponding CNNL forward-interpolation enum values.
// Looked up via GetMLUCnnlInterpMode below.
const std::map<std::string, cnnlInterpMode_t> MLUInterpModeMap = {
    {"bilinear", CNNL_INTERP_BILINEAR},
    {"nearest", CNNL_INTERP_NEAREST},
    {"linear", CNNL_INTERP_LINEAR},
    {"trilinear", CNNL_INTERP_TRILINEAR},
    {"bicubic", CNNL_INTERP_BICUBIC}};
// Maps Paddle interpolation-mode strings to the CNNL backward-interpolation
// enum values used by the grad kernels (cnnlInterpBackward_v2).
// Looked up via GetMLUCnnlInterpBackwardMode below.
const std::map<std::string, cnnlInterpBackwardMode_t> MLUInterpBackwardModeMap =
    {{"bilinear", CNNL_INTERP_BACKWARD_BILINEAR},
     {"nearest", CNNL_INTERP_BACKWARD_NEAREST},
     {"linear", CNNL_INTERP_BACKWARD_LINEAR},
     {"trilinear", CNNL_INTERP_BACKWARD_TRILINEAR},
     {"bicubic", CNNL_INTERP_BACKWARD_BICUBIC}};
inline cnnlReduceOp_t GetMLUCnnlReduceOp(const std::string reduce_name) { inline cnnlReduceOp_t GetMLUCnnlReduceOp(const std::string reduce_name) {
auto iter = MLUReduceOpMap.find(reduce_name); auto iter = MLUReduceOpMap.find(reduce_name);
if (iter != MLUReduceOpMap.end()) { if (iter != MLUReduceOpMap.end()) {
...@@ -50,6 +64,25 @@ inline cnnlReduceOp_t GetMLUCnnlReduceOp(const std::string reduce_name) { ...@@ -50,6 +64,25 @@ inline cnnlReduceOp_t GetMLUCnnlReduceOp(const std::string reduce_name) {
"Not support reduce op type of MLU Device: %s", reduce_name)); "Not support reduce op type of MLU Device: %s", reduce_name));
} }
// Translates a Paddle interpolation-mode string (e.g. "bilinear", "nearest")
// into the CNNL forward-interpolation enum via MLUInterpModeMap.
//
// Takes the string by const reference to avoid copying it on every call
// (the parameter is only read for the map lookup).
//
// Throws platform::errors::InvalidArgument when the mode is not supported
// on the MLU device.
inline cnnlInterpMode_t GetMLUCnnlInterpMode(const std::string& interp_mode) {
  auto iter = MLUInterpModeMap.find(interp_mode);
  if (iter != MLUInterpModeMap.end()) {
    return iter->second;
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not support interp mode of MLU Device: %s", interp_mode));
}
// Translates a Paddle interpolation-mode string into the CNNL
// backward-interpolation enum (used by interp grad kernels) via
// MLUInterpBackwardModeMap.
//
// Takes the string by const reference to avoid copying it on every call
// (the parameter is only read for the map lookup).
//
// Throws platform::errors::InvalidArgument when the mode is not supported
// on the MLU device.
inline cnnlInterpBackwardMode_t GetMLUCnnlInterpBackwardMode(
    const std::string& interp_mode) {
  auto iter = MLUInterpBackwardModeMap.find(interp_mode);
  if (iter != MLUInterpBackwardModeMap.end()) {
    return iter->second;
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Not support interp mode of MLU Device: %s", interp_mode));
}
inline const void* GetBasePtr(const Tensor* t) { return t->data(); } inline const void* GetBasePtr(const Tensor* t) { return t->data(); }
inline void* GetBasePtr(Tensor* t) { return t->data(); } inline void* GetBasePtr(Tensor* t) { return t->data(); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册