From 6e154fc6ef13a426fd7d57ef2f9988af21da7297 Mon Sep 17 00:00:00 2001 From: Chenxiao Niu Date: Fri, 30 Dec 2022 15:39:20 +0800 Subject: [PATCH] [MLU] cherry-pick from develop to release/2.4 (#48313) * [MLU] fix compute error of dropout op (#45923) * [MLU] add mergedAdam kernel. (#45965) * [MLU] add int64 support for mlu one_hot_v2 (#46313) * [MLU] fix profiler compile failure (#46208) * [MLU] add barrier_op kernel. (#46417) * [MLU] fluid: add mluop (#46429) * [MLU] add huber_loss kernel. (#46455) * [MLU] add mlu kernel for add_reduce_max_grad (#45651) Co-authored-by: liupeiyu * [MLU] add_fluid_mluop_yolo_box (#46573) * [MLU] fix phi::Tensor compile error of mlu. (#46649) * [MLU] add fluid MLUOps prior_box (#46585) * [MLU] fix cmake error (#46772) * [MLU]fix unittest of sync_bn (#46797) * [MLU] add masterparam support for mlu adamw. (#46804) * [MLU] add int64 support for allgather. (#46830) * [MLU] fix compile error & add mlu blacklist function. (#47439) * [MLU] fix softmax_with_cross_entropy failed in 370-X8. * [MLU] fix cncl stuck caused by multiple initializations. * [MLU] fix code style check. Co-authored-by: qipengh Co-authored-by: cifar10 <41565156+cifar10@users.noreply.github.com> Co-authored-by: Lux et Veritas <1004239791@qq.com> Co-authored-by: liupeiyu Co-authored-by: ronnywang --- cmake/neuware.cmake | 4 +- paddle/fluid/imperative/prepared_operator.cc | 48 +++ .../operators/collective/barrier_op_mlu.cc | 63 +++ .../collective/c_allgather_op_mlu.cc | 54 ++- .../fluid/operators/detection/CMakeLists.txt | 7 +- .../operators/detection/prior_box_op_mlu.cc | 104 +++++ .../operators/detection/yolo_box_op_mlu.cc | 137 +++++++ paddle/fluid/operators/dropout_op_mlu.cc | 71 ++-- paddle/fluid/operators/huber_loss_op_mlu.cc | 187 +++++++++ paddle/fluid/operators/mlu/mlu_baseop.cc | 381 +++++++++++++++++- paddle/fluid/operators/mlu/mlu_baseop.h | 187 +++++++++ paddle/fluid/operators/one_hot_v2_op_mlu.cc | 4 +- .../fluid/operators/optimizers/adam_op_mlu.cc | 275 ++++++++++++- paddle/fluid/operators/pool_op_mlu.cc | 7 +- .../operators/reduce_ops/reduce_max_op_mlu.cc | 110 +++++ .../fluid/operators/strided_slice_op_mlu.cc | 5 + .../platform/device/mlu/device_context.cc | 11 +- .../platform/device/mlu/device_context.h | 19 + paddle/fluid/platform/device/mlu/enforce.h | 10 + paddle/fluid/platform/device/mlu/mlu_info.cc | 7 + paddle/fluid/platform/device/mlu/mlu_info.h | 8 +- paddle/fluid/platform/profiler/profiler.cc | 5 + .../unittests/mlu/sync_batch_norm_op_mlu.py | 47 ++- .../mlu/test_collective_api_base_mlu.py | 108 ++--- .../unittests/mlu/test_collective_base_mlu.py | 222 +++++----- .../unittests/mlu/test_dropout_op_mlu.py | 302 +++++++------- .../unittests/mlu/test_grid_sampler_op_mlu.py | 115 +++--- .../unittests/mlu/test_huber_loss_op_mlu.py | 132 ++++++ .../unittests/mlu/test_merged_adam_op_mlu.py | 228 +++++++++++ .../unittests/mlu/test_prior_box_op_mlu.py | 214 ++++++++++ .../unittests/mlu/test_reduce_sum_op_mlu.py | 49 +-- .../tests/unittests/mlu/test_slice_op_mlu.py | 168 ++++---- .../test_softmax_with_cross_entropy_op_mlu.py | 53 +-- .../mlu/test_sync_batch_norm_base_mlu.py | 260 +++++++----- .../test_sync_batch_norm_op_mlu_baseline.py | 11 +- .../unittests/mlu/test_yolo_box_op_mlu.py | 299 ++++++++++++++ tools/dockerfile/Dockerfile.mlu | 37 +- 37 files changed, 3247 insertions(+), 702 deletions(-) create mode 100644 paddle/fluid/operators/collective/barrier_op_mlu.cc create mode 100644 paddle/fluid/operators/detection/prior_box_op_mlu.cc create mode 100644 paddle/fluid/operators/detection/yolo_box_op_mlu.cc create mode 100644 paddle/fluid/operators/huber_loss_op_mlu.cc create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py diff --git a/cmake/neuware.cmake b/cmake/neuware.cmake index 16dbf16899b..8c873f35b7f 100644 --- a/cmake/neuware.cmake +++ b/cmake/neuware.cmake @@ -15,12 +15,14 @@ set(NEUWARE_LIB_DIR ${NEUWARE_HOME}/lib64) include_directories(${NEUWARE_INCLUDE_DIR}) set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so) +set(MLUOP_LIB ${NEUWARE_LIB_DIR}/libmluops.so) set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so) set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so) set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so) generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake") -set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB}) +set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${MLUOP_LIB} ${CNRT_LIB} ${CNDRV_LIB} + ${CNPAPI_LIB}) if(WITH_CNCL) message(STATUS "Compile with CNCL!") diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 62bbf77a2df..345f3af0a6d 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -146,6 +146,48 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, kernel_signature_(std::move(kernel_signature)), phi_kernel_(phi_kernel) {} +#ifdef PADDLE_WITH_MLU + +static void tokenize(const std::string& ops, + char delim, + std::unordered_set* op_set) { + std::string::size_type beg = 0; + for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos; + ++end) { + op_set->insert(ops.substr(beg, end - beg)); + beg = end + 1; + } + + op_set->insert(ops.substr(beg)); +} + +static bool is_in_mlu_black_list(const std::string& op_name) { + static bool inited = false; + static std::unordered_set mlu_black_list; + static std::mutex s_mtx; + if (!inited) { + std::lock_guard guard(s_mtx); + if (!inited) { + if (std::getenv("MLU_BLACK_LIST") != nullptr) { + std::string ops(std::getenv("MLU_BLACK_LIST")); + tokenize(ops, ',', &mlu_black_list); + } + inited = true; + VLOG(3) << "MLU Black List: "; + for (auto iter = mlu_black_list.begin(); iter != mlu_black_list.end(); + ++iter) { + VLOG(3) << *iter << " "; + } + } + } + if (mlu_black_list.find(op_name) != mlu_black_list.end()) { + return true; + } + return false; +} + +#endif + template PreparedOp PrepareImpl( const NameVarMap& ins, @@ -194,6 +236,12 @@ PreparedOp PrepareImpl( #endif +#ifdef PADDLE_WITH_MLU + if (is_in_mlu_black_list(op.Type())) { + expected_kernel_key.place_ = platform::CPUPlace(); + } +#endif + bool has_phi_kernel = false; const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type()); diff --git a/paddle/fluid/operators/collective/barrier_op_mlu.cc b/paddle/fluid/operators/collective/barrier_op_mlu.cc new file mode 100644 index 00000000000..d463e66fe62 --- /dev/null +++ b/paddle/fluid/operators/collective/barrier_op_mlu.cc @@ -0,0 +1,63 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/barrier_op.h" +#if defined(PADDLE_WITH_CNCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/mlu/cncl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class BarrierOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_CNCL) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + + auto place = ctx.GetPlace(); + cnclDataType_t dtype = + platform::ToCNCLDataType(framework::TransToProtoVarType(in->dtype())); + int64_t numel = in->numel(); + const void* sendbuff = in->data(); + void* recvbuff = out->mutable_data(place); + + int rid = ctx.Attr("ring_id"); + auto cncl_comm = platform::CNCLCommContext::Instance().Get(rid, place); + auto* comm = cncl_comm->comm(); + auto comm_stream = cncl_comm->stream(); + auto& dev_ctx = + ctx.template device_context(); + cnclReduceOp_t cncl_red_type = cnclSum; + dev_ctx.Wait(); + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce( + sendbuff, recvbuff, numel, dtype, cncl_red_type, comm, comm_stream)); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream)); +#else + PADDLE_THROW(platform::errors::Unavailable( + "PaddlePaddle should compile with CNCL.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(barrier, ops::BarrierOpMLUKernel); diff --git a/paddle/fluid/operators/collective/c_allgather_op_mlu.cc b/paddle/fluid/operators/collective/c_allgather_op_mlu.cc index fc3ad8a006e..347349ac7a4 100644 --- a/paddle/fluid/operators/collective/c_allgather_op_mlu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op_mlu.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_allgather_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" #if defined(PADDLE_WITH_CNCL) #include "paddle/fluid/platform/collective_helper.h" @@ -27,15 +28,14 @@ template class CAllGatherOpMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); #if defined(PADDLE_WITH_CNCL) - auto x = ctx.Input("X"); - auto out = ctx.Output("Out"); - cnclDataType_t dtype = - platform::ToCNCLDataType(framework::TransToProtoVarType(x->dtype())); + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); int nranks = ctx.Attr("nranks"); int rid = ctx.Attr("ring_id"); - auto place = ctx.GetPlace(); auto comm = platform::CNCLCommContext::Instance().Get(rid, place); PADDLE_ENFORCE_EQ( nranks, @@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel { out->mutable_data(out_dims, place); uint32_t send_numel = x->numel(); - void* send_buff = reinterpret_cast(const_cast(x->data())); - void* recv_buff = reinterpret_cast(out->data()); + void* send_buff; + void* recv_buff; + phi::DenseTensor in_tensor, out_tensor; + if (framework::TransToProtoVarType(x->dtype()) == + framework::proto::VarType::INT64) { + // cast from int64 to int32 since cncl do not support int64 + in_tensor.mutable_data(x->dims(), place); + out_tensor.mutable_data(out->dims(), place); + MLUCnnlTensorDesc x_int64_desc(*x); + MLUCnnlTensorDesc x_int32_desc(in_tensor); + cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32); + MLUCnnl::Cast(ctx, + cast_type, + x_int64_desc.get(), + GetBasePtr(x), + x_int32_desc.get(), + GetBasePtr(&in_tensor)); + send_buff = reinterpret_cast(in_tensor.data()); + recv_buff = reinterpret_cast(out_tensor.data()); + } else { + in_tensor.ShareDataWith(*x); + out_tensor.ShareDataWith(*out); + send_buff = reinterpret_cast(in_tensor.data()); + recv_buff = reinterpret_cast(out_tensor.data()); + } mluStream stream = nullptr; if (ctx.Attr("use_calc_stream")) { - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); stream = static_cast(dev_ctx)->stream(); } else { stream = comm->stream(); } + cnclDataType_t dtype = platform::ToCNCLDataType( + framework::TransToProtoVarType(in_tensor.dtype())); PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather( send_buff, recv_buff, send_numel, dtype, comm->comm(), stream)); + if (framework::TransToProtoVarType(x->dtype()) == + framework::proto::VarType::INT64) { + // cast back from int64 out_tensor to out + MLUCnnlTensorDesc out_int64_desc(*out); + MLUCnnlTensorDesc out_int32_desc(out_tensor); + cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64); + MLUCnnl::Cast(ctx, + cast_type, + out_int32_desc.get(), + GetBasePtr(&out_tensor), + out_int64_desc.get(), + GetBasePtr(out)); + } #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with MLU.")); @@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather, ops::CAllGatherOpMLUKernel, ops::CAllGatherOpMLUKernel, ops::CAllGatherOpMLUKernel, + ops::CAllGatherOpMLUKernel, ops::CAllGatherOpMLUKernel); diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index d965e1ace5f..81860c60492 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -42,19 +42,23 @@ if(WITH_XPU) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc) detection_library(prior_box_op SRCS prior_box_op.cc) + detection_library(yolo_box_op SRCS yolo_box_op.cc) detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) elseif(WITH_MLU) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_mlu.cc) - detection_library(prior_box_op SRCS prior_box_op.cc) + detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_mlu.cc) + detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op_mlu.cc) elseif(WITH_ASCEND_CL) detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_npu.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_npu.cc) + detection_library(yolo_box_op SRCS yolo_box_op.cc) else() detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu) detection_library(prior_box_op SRCS prior_box_op.cc) + detection_library(yolo_box_op SRCS yolo_box_op.cc) # detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) endif() @@ -73,7 +77,6 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc) detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) -detection_library(yolo_box_op SRCS yolo_box_op.cc) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc diff --git a/paddle/fluid/operators/detection/prior_box_op_mlu.cc b/paddle/fluid/operators/detection/prior_box_op_mlu.cc new file mode 100644 index 00000000000..04402f6ae20 --- /dev/null +++ b/paddle/fluid/operators/detection/prior_box_op_mlu.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/detection/prior_box_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +template +class PriorBoxMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* image = ctx.Input("Image"); + auto* boxes = ctx.Output("Boxes"); + auto* variances = ctx.Output("Variances"); + float step_w = ctx.Attr("step_w"); + float step_h = ctx.Attr("step_h"); + float offset = ctx.Attr("offset"); + bool clip = ctx.Attr("clip"); + bool min_max_aspect_ratios_order = + ctx.Attr("min_max_aspect_ratios_order"); + + int im_width = image->dims()[3]; + int im_height = image->dims()[2]; + int width = input->dims()[3]; + int height = input->dims()[2]; + + auto aspect_ratios = ctx.Attr>("aspect_ratios"); + bool flip = ctx.Attr("flip"); + std::vector new_aspect_ratios; + ExpandAspectRatios(aspect_ratios, flip, &new_aspect_ratios); + auto& dev_ctx = ctx.template device_context(); + phi::DenseTensor ratios; + paddle::framework::TensorFromVector(new_aspect_ratios, dev_ctx, &ratios); + MLUOpTensorDesc new_aspect_ratios_desc(ratios); + + auto min_sizes = ctx.Attr>("min_sizes"); + phi::DenseTensor min; + paddle::framework::TensorFromVector(min_sizes, dev_ctx, &min); + MLUOpTensorDesc min_sizes_desc(min); + + auto max_sizes = ctx.Attr>("max_sizes"); + phi::DenseTensor max; + paddle::framework::TensorFromVector(max_sizes, dev_ctx, &max); + MLUOpTensorDesc max_sizes_desc(max); + + auto variances_attr = ctx.Attr>("variances"); + phi::DenseTensor var_tensor; + paddle::framework::TensorFromVector(variances_attr, dev_ctx, &var_tensor); + MLUOpTensorDesc variances_attr_desc(var_tensor); + + auto place = ctx.GetPlace(); + + boxes->mutable_data(place); + variances->mutable_data(place); + + MLUOpTensorDesc var_desc(*variances); + MLUOpTensorDesc output_desc(*boxes); + MLUOP::OpPriorBox(ctx, + min_sizes_desc.get(), + GetBasePtr(&min), + new_aspect_ratios_desc.get(), + GetBasePtr(&ratios), + variances_attr_desc.get(), + GetBasePtr(&var_tensor), + max_sizes_desc.get(), + GetBasePtr(&max), + height, + width, + im_height, + im_width, + step_h, + step_w, + offset, + clip, + min_max_aspect_ratios_order, + output_desc.get(), + GetBasePtr(boxes), + var_desc.get(), + GetBasePtr(variances)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(prior_box, ops::PriorBoxMLUKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op_mlu.cc b/paddle/fluid/operators/detection/yolo_box_op_mlu.cc new file mode 100644 index 00000000000..739c05805d6 --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op_mlu.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { +template +class YoloBoxMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* img_size = ctx.Input("ImgSize"); + auto* boxes = ctx.Output("Boxes"); + auto* scores = ctx.Output("Scores"); + const std::vector anchors = ctx.Attr>("anchors"); + auto class_num = ctx.Attr("class_num"); + auto conf_thresh = ctx.Attr("conf_thresh"); + auto downsample_ratio = ctx.Attr("downsample_ratio"); + auto clip_bbox = ctx.Attr("clip_bbox"); + auto scale = ctx.Attr("scale_x_y"); + auto iou_aware = ctx.Attr("iou_aware"); + auto iou_aware_factor = ctx.Attr("iou_aware_factor"); + + int anchor_num = anchors.size() / 2; + int64_t size = anchors.size(); + auto dim_x = x->dims(); + int n = dim_x[0]; + int s = anchor_num; + int h = dim_x[2]; + int w = dim_x[3]; + + // The output of mluOpYoloBox: A 4-D tensor with shape [N, anchor_num, 4, + // H*W], the coordinates of boxes, and a 4-D tensor with shape [N, + // anchor_num, :attr:`class_num`, H*W], the classification scores of boxes. + std::vector boxes_dim_mluops({n, s, 4, h * w}); + std::vector scores_dim_mluops({n, s, class_num, h * w}); + + // In Paddle framework: A 3-D tensor with shape [N, M, 4], the coordinates + // of boxes, and a 3-D tensor with shape [N, M, :attr:`class_num`], the + // classification scores of boxes. + std::vector boxes_out_dim({n, s, h * w, 4}); + std::vector scores_out_dim({n, s, h * w, class_num}); + + auto& dev_ctx = ctx.template device_context(); + phi::DenseTensor boxes_tensor_mluops = + ctx.AllocateTmpTensor({n, s, 4, h * w}, dev_ctx); + phi::DenseTensor scores_tensor_mluops = + ctx.AllocateTmpTensor({n, s, class_num, h * w}, + dev_ctx); + MLUOpTensorDesc boxes_trans_desc_mluops( + 4, boxes_dim_mluops.data(), ToMluOpDataType()); + MLUCnnlTensorDesc boxes_trans_desc_cnnl( + 4, boxes_dim_mluops.data(), ToCnnlDataType()); + MLUOpTensorDesc scores_trans_desc_mluops( + 4, scores_dim_mluops.data(), ToMluOpDataType()); + MLUCnnlTensorDesc scores_trans_desc_cnnl( + 4, scores_dim_mluops.data(), ToCnnlDataType()); + + boxes->mutable_data(ctx.GetPlace()); + scores->mutable_data(ctx.GetPlace()); + FillMLUTensorWithHostValue(ctx, static_cast(0), boxes); + FillMLUTensorWithHostValue(ctx, static_cast(0), scores); + + MLUOpTensorDesc x_desc(*x, MLUOP_LAYOUT_ARRAY, ToMluOpDataType()); + MLUOpTensorDesc img_size_desc( + *img_size, MLUOP_LAYOUT_ARRAY, ToMluOpDataType()); + Tensor anchors_temp(framework::TransToPhiDataType(VT::INT32)); + anchors_temp.Resize({size}); + paddle::framework::TensorFromVector( + anchors, ctx.device_context(), &anchors_temp); + MLUOpTensorDesc anchors_desc(anchors_temp); + MLUCnnlTensorDesc boxes_desc_cnnl( + 4, boxes_out_dim.data(), ToCnnlDataType()); + MLUCnnlTensorDesc scores_desc_cnnl( + 4, scores_out_dim.data(), ToCnnlDataType()); + + MLUOP::OpYoloBox(ctx, + x_desc.get(), + GetBasePtr(x), + img_size_desc.get(), + GetBasePtr(img_size), + anchors_desc.get(), + GetBasePtr(&anchors_temp), + class_num, + conf_thresh, + downsample_ratio, + clip_bbox, + scale, + iou_aware, + iou_aware_factor, + boxes_trans_desc_mluops.get(), + GetBasePtr(&boxes_tensor_mluops), + scores_trans_desc_mluops.get(), + GetBasePtr(&scores_tensor_mluops)); + const std::vector perm = {0, 1, 3, 2}; + + // transpose the boxes from [N, S, 4, H*W] to [N, S, H*W, 4] + MLUCnnl::Transpose(ctx, + perm, + 4, + boxes_trans_desc_cnnl.get(), + GetBasePtr(&boxes_tensor_mluops), + boxes_desc_cnnl.get(), + GetBasePtr(boxes)); + + // transpose the scores from [N, S, class_num, H*W] to [N, S, H*W, + // class_num] + MLUCnnl::Transpose(ctx, + perm, + 4, + scores_trans_desc_cnnl.get(), + GetBasePtr(&scores_tensor_mluops), + scores_desc_cnnl.get(), + GetBasePtr(scores)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(yolo_box, ops::YoloBoxMLUKernel); diff --git a/paddle/fluid/operators/dropout_op_mlu.cc b/paddle/fluid/operators/dropout_op_mlu.cc index 923e6cc5ed9..142e047e6c2 100644 --- a/paddle/fluid/operators/dropout_op_mlu.cc +++ b/paddle/fluid/operators/dropout_op_mlu.cc @@ -39,8 +39,17 @@ class DropoutMLUKernel : public framework::OpKernel { MLUCnnlTensorDesc x_desc(*x); MLUCnnlTensorDesc out_desc(*out); - if (!is_test) { - // exec dropout op for training only. + if (is_test && is_upscale) { + // dropout op for inference: out = input. + framework::TensorCopy( + *x, + ctx.GetPlace(), + ctx.template device_context(), + out); + return; + } else if (!is_test) { + // dropout op for training: out = input * mask / ( 1.0 - dropout_prob ) or + // out = input * mask. int seed_data = 0; if (seed_tensor) { if (platform::is_mlu_place(seed_tensor->place())) { @@ -79,50 +88,44 @@ class DropoutMLUKernel : public framework::OpKernel { const int device_id = ctx.GetPlace().GetDeviceId(); auto mlu_gen_random = GetMLURandomGenerator(ctx, device_id, seed_data); - const float prob = is_upscale ? dropout_prob : 0.0f; + // compute out = input * mask / ( 1.0 - dropout_prob ) MLUCnnl::FusedDropout(ctx, mlu_gen_random->get(), x_desc.get(), GetBasePtr(x), - prob, + dropout_prob, GetBasePtr(&(mlu_gen_random->get_state())), mask_desc.get(), GetBasePtr(mask), out_desc.get(), GetBasePtr(out)); - } else { - // exec dropout op for inference only. + if (is_upscale) { - framework::TensorCopy( - *x, - ctx.GetPlace(), - ctx.template device_context(), - out); - } else { - auto scale = static_cast(1.0f - dropout_prob); - Tensor scale_tensor(x->dtype()); - scale_tensor.mutable_data({1}, ctx.GetPlace()); - MLUCnnlTensorDesc scale_desc(scale_tensor); - MLUCnnl::Fill(ctx, - CNNL_POINTER_MODE_HOST, - &scale, - scale_desc.get(), - GetBasePtr(&scale_tensor)); - - auto data_type = ToCnnlDataType(); - MLUCnnlOpTensorDesc op_tensor_desc( - CNNL_OP_TENSOR_MUL, data_type, CNNL_NOT_PROPAGATE_NAN); - MLUCnnl::OpTensor(ctx, - op_tensor_desc.get(), - x_desc.get(), - GetBasePtr(x), - scale_desc.get(), - GetBasePtr(&scale_tensor), - out_desc.get(), - GetBasePtr(out), - data_type); + return; } } + + // In downgrade_in_infer mode, need to multiply (1.0f - dropout_prob). + Tensor scale_tensor(x->dtype()); + Tensor bias_tensor(x->dtype()); + scale_tensor.mutable_data({1}, ctx.GetPlace()); + bias_tensor.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc bias_desc(bias_tensor); + FillMLUTensorWithHostValue( + ctx, static_cast(1.0f - dropout_prob), &scale_tensor); + FillMLUTensorWithHostValue(ctx, static_cast(0.0f), &bias_tensor); + + MLUCnnl::Scale(ctx, + 0, + is_test ? x_desc.get() : out_desc.get(), + is_test ? GetBasePtr(x) : GetBasePtr(out), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(out)); } }; diff --git a/paddle/fluid/operators/huber_loss_op_mlu.cc b/paddle/fluid/operators/huber_loss_op_mlu.cc new file mode 100644 index 00000000000..4387037ad01 --- /dev/null +++ b/paddle/fluid/operators/huber_loss_op_mlu.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +template +class HuberLossMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = GetDevCtxFromCTX(ctx); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* residual = ctx.Output("Residual"); + auto* out = ctx.Output("Out"); + auto delta = ctx.Attr("delta"); + + auto place = ctx.GetPlace(); + + // compute y-x + cnnlDataType_t data_type = ToCnnlDataType(); + residual->mutable_data(x->dims(), place); + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlOpTensorDesc sub_op_desc( + CNNL_OP_TENSOR_SUB, data_type, CNNL_NOT_PROPAGATE_NAN); + MLUCnnl::OpTensor(ctx, + sub_op_desc.get(), + x_desc.get(), + GetBasePtr(y), + x_desc.get(), + GetBasePtr(x), + x_desc.get(), + GetBasePtr(residual), + data_type); + + // compute smoothl1loss + out->mutable_data(x->dims(), place); + cnnlSmoothL1LossAlgorithm_t smoothl1_algo = + CNNL_SMOOTHL1LOSS_REDUCTION_NONE; // defines whether to do reduction + // here + MLUCnnl::SmoothL1LossForward(ctx, + x_desc.get(), + GetBasePtr(x), + x_desc.get(), /* target has same shape as x */ + GetBasePtr(y), + static_cast(delta), + smoothl1_algo, + x_desc.get(), /* out has same shape as x */ + GetBasePtr(out)); + + // compute multiply by delta + Tensor scale_tensor, bias_tensor; + scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + bias_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + FillMLUTensorWithHostValue(ctx, static_cast(delta), &scale_tensor); + FillMLUTensorWithHostValue(ctx, static_cast(0.f), &bias_tensor); + const int axis = std::max(out->dims().size() - 1, 0); + + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc bias_desc(bias_tensor); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnl::Scale(ctx, + axis, + out_desc.get(), + GetBasePtr(out), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(out)); + } +}; + +template +class HuberLossGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = GetDevCtxFromCTX(ctx); + auto* residual = ctx.Input("Residual"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto delta = ctx.Attr("delta"); + + auto place = ctx.GetPlace(); + + Tensor t_grad_rd; + t_grad_rd = + ctx.AllocateTmpTensor(residual->dims(), dev_ctx); + MLUCnnlTensorDesc t_grad_rd_desc(t_grad_rd); + if (dx || dy) { + Tensor t_zero; + t_zero = + ctx.AllocateTmpTensor(residual->dims(), dev_ctx); + FillMLUTensorWithHostValue(ctx, static_cast(0.f), &t_zero); + + MLUCnnlTensorDesc residual_desc(*residual); + MLUCnnlTensorDesc dout_desc(*dout); + + cnnlSmoothL1LossAlgorithm_t smoothl1_algo = + CNNL_SMOOTHL1LOSS_REDUCTION_NONE; // defines whether to do reduction + // here + MLUCnnl::SmoothL1LossBackward(ctx, + residual_desc.get(), + GetBasePtr(residual), + residual_desc.get(), + GetBasePtr(&t_zero), + dout_desc.get(), + GetBasePtr(dout), + static_cast(delta), + smoothl1_algo, + t_grad_rd_desc.get(), + GetBasePtr(&t_grad_rd)); + } + // compute multiply by delta + Tensor scale_tensor, bias_tensor; + scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + bias_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + + FillMLUTensorWithHostValue(ctx, static_cast(0.f), &bias_tensor); + const int axis = std::max(t_grad_rd.dims().size() - 1, 0); + + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc bias_desc(bias_tensor); + + if (dx) { + dx->mutable_data(place); + FillMLUTensorWithHostValue(ctx, static_cast(-delta), &scale_tensor); + MLUCnnlTensorDesc out_desc(*dx); + MLUCnnl::Scale(ctx, + axis, + t_grad_rd_desc.get(), + GetBasePtr(&t_grad_rd), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(dx)); + } + if (dy) { + dy->mutable_data(place); + FillMLUTensorWithHostValue(ctx, static_cast(delta), &scale_tensor); + MLUCnnlTensorDesc out_desc(*dy); + MLUCnnl::Scale(ctx, + axis, + t_grad_rd_desc.get(), + GetBasePtr(&t_grad_rd), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(dy)); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(huber_loss, + ops::HuberLossMLUKernel, + ops::HuberLossMLUKernel); +REGISTER_OP_MLU_KERNEL(huber_loss_grad, + ops::HuberLossGradMLUKernel, + ops::HuberLossGradMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 4cd754775d9..a09d79e8d08 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -256,6 +256,186 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() { } } +class MLUOpTensorDescPool { + public: + mluOpTensorDescriptor_t Pop() { + mluOpTensorDescriptor_t raw_desc; + if (q_.try_dequeue(raw_desc)) { + return raw_desc; + } else { + mluOpCreateTensorDescriptor(&raw_desc); + return raw_desc; + } + } + + void Recycle(mluOpTensorDescriptor_t desc) { + mluOpResetTensorDescriptor(desc); + q_.enqueue(desc); + } + + ~MLUOpTensorDescPool() { + auto size = q_.size_approx(); + if (size > 0) { + std::vector vec(size); + q_.try_dequeue_bulk(vec.data(), size); + for (auto desc : vec) { + mluOpDestroyTensorDescriptor(desc); + } + } + } + + private: + moodycamel::ConcurrentQueue q_; +}; + +static MLUOpTensorDescPool g_mluop_tensor_desc_pool; + +MLUOpTensorDesc& MLUOpTensorDesc::operator=(MLUOpTensorDesc&& rhs) { + if (raw_tensor_desc) { + g_mluop_tensor_desc_pool.Recycle(raw_tensor_desc); + } + raw_tensor_desc = rhs.raw_tensor_desc; + rhs.raw_tensor_desc = nullptr; + return *this; +} + +MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim, + const int dim_sizes[], + const mluOpDataType_t tensor_dtype) { + raw_tensor_desc = g_mluop_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc, + MLUOP_LAYOUT_ARRAY, + tensor_dtype, + tensor_dim, + dim_sizes)); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim, + const int dim_sizes[], + const mluOpDataType_t tensor_dtype, + const mluOpTensorLayout_t layout) { + raw_tensor_desc = g_mluop_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor( + raw_tensor_desc, layout, tensor_dtype, tensor_dim, dim_sizes)); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim, + const int dim_sizes[], + const mluOpDataType_t tensor_dtype, + int position) + : MLUOpTensorDesc(tensor_dim, dim_sizes, tensor_dtype) { + PADDLE_ENFORCE_MLU_SUCCESS( + mluOpSetTensorDescriptorPosition(raw_tensor_desc, position)); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const mluOpDataType_t tensor_dtype) { + std::vector dim_sizes_int32(tensor_dim); + std::vector::const_iterator int64_cbegin(dim_sizes); + std::vector::const_iterator int64_cend(dim_sizes + tensor_dim); + std::transform(int64_cbegin, + int64_cend, + dim_sizes_int32.begin(), + &CheckedNarrowing); + raw_tensor_desc = g_mluop_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc, + MLUOP_LAYOUT_ARRAY, + tensor_dtype, + tensor_dim, + dim_sizes_int32.data())); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const mluOpDataType_t tensor_dtype, + const mluOpTensorLayout_t layout) { + std::vector dim_sizes_int32(tensor_dim); + std::vector::const_iterator int64_cbegin(dim_sizes); + std::vector::const_iterator int64_cend(dim_sizes + tensor_dim); + std::transform(int64_cbegin, + int64_cend, + dim_sizes_int32.begin(), + &CheckedNarrowing); + raw_tensor_desc = g_mluop_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc, + layout, + tensor_dtype, + tensor_dim, + dim_sizes_int32.data())); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const mluOpDataType_t tensor_dtype, + int position) { + std::vector dim_sizes_int32(tensor_dim); + std::vector::const_iterator int64_cbegin(dim_sizes); + std::vector::const_iterator int64_cend(dim_sizes + tensor_dim); + std::transform(int64_cbegin, + int64_cend, + dim_sizes_int32.begin(), + &CheckedNarrowing); + raw_tensor_desc = g_mluop_tensor_desc_pool.Pop(); + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc, + MLUOP_LAYOUT_ARRAY, + tensor_dtype, + tensor_dim, + dim_sizes_int32.data())); + PADDLE_ENFORCE_MLU_SUCCESS( + mluOpSetTensorDescriptorPosition(raw_tensor_desc, position)); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor, + const mluOpTensorLayout_t layout, + const mluOpDataType_t tensor_dtype) { + auto dims = phi::vectorize(tensor.dims()); + int tensor_dim = dims.size(); + raw_tensor_desc = g_mluop_tensor_desc_pool.Pop(); + if (tensor_dim == 0) { + int scalar_dims[1] = {1}; + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor( + raw_tensor_desc, layout, tensor_dtype, 1, scalar_dims)); + } else { + std::vector tensor_dim_sizes_int(dims.begin(), dims.end()); + PADDLE_ENFORCE_MLU_SUCCESS( + mluOpSetTensorDescriptor(raw_tensor_desc, + layout, + tensor_dtype, + tensor_dim, + tensor_dim_sizes_int.data())); + } +} + +MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor) + : MLUOpTensorDesc( + tensor, MLUOP_LAYOUT_ARRAY, ToMluOpDataType(tensor.dtype())) {} + +MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor, + mluOpTensorLayout_t layout, + const mluOpDataType_t tensor_dtype, + int position) + : MLUOpTensorDesc(tensor, layout, tensor_dtype) { + PADDLE_ENFORCE_MLU_SUCCESS( + mluOpSetTensorDescriptorPosition(raw_tensor_desc, position)); +} + +MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor, + mluOpTensorLayout_t layout, + const mluOpDataType_t tensor_dtype, + int position, + float scale) + : MLUOpTensorDesc(tensor, layout, tensor_dtype) { + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptorPositionAndScale( + raw_tensor_desc, position, scale)); +} + +MLUOpTensorDesc::~MLUOpTensorDesc() { + if (raw_tensor_desc) { + g_mluop_tensor_desc_pool.Recycle(raw_tensor_desc); + } +} + MLUCnnlActivationDesc::MLUCnnlActivationDesc( const cnnlActivationMode_t act_mode, const float ceof) { PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); @@ -1563,17 +1743,35 @@ MLURNNDesc::~MLURNNDesc() { void* indices_out) { cnnlHandle_t handle = GetHandleFromCTX(ctx); - PADDLE_ENFORCE_MLU_SUCCESS(cnnlTopKTensor(handle, - input_desc, - input, - k, - dim, - largest, - sorted, - values_output_desc, - values_out, - indices_output_desc, - indices_out)); + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetTopKTensorWorkspaceSize(handle, + input_desc, + k, + dim, + largest, + values_output_desc, + indices_output_desc, + &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlTopKTensor_v3(handle, + input_desc, + input, + k, + dim, + largest, + sorted, + false /*lower_index_first*/, + workspace_ptr, + workspace_size, + values_output_desc, + values_out, + indices_output_desc, + indices_out)); } /* static */ void MLUCnnl::StridedSlice( @@ -4527,6 +4725,78 @@ MLURNNDesc::~MLURNNDesc() { output)); } +/* static */ void MLUCnnl::SmoothL1LossForward( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t t_desc, + const void* target, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t y_desc, + void* y) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetSmoothL1LossForwardWorkspaceSize( + handle, x_desc, algorithm, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSmoothL1LossForward_v2(handle, + x_desc, + x, + t_desc, + target, + beta, + algorithm, + workspace_ptr, + workspace_size, + y_desc, + y)); +} + +/* static */ void MLUCnnl::SmoothL1LossBackward( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t target_desc, + const void* target, + const cnnlTensorDescriptor_t dy_desc, + const void* dy, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t dx_desc, + void* dx) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetSmoothL1LossBackwardWorkspaceSize( + handle, x_desc, algorithm, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSmoothL1LossBackward_v2(handle, + x_desc, + x, + target_desc, + target, + dy_desc, + dy, + beta, + algorithm, + workspace_ptr, + workspace_size, + dx_desc, + dx)); +} + /* static */ void MLUCnnl::EmbeddingForward( const ExecutionContext& ctx, const int padding_idx, @@ -5148,5 +5418,94 @@ MLURNNDesc::~MLURNNDesc() { diff_x)); } +/* static */ void MLUOP::OpYoloBox(const ExecutionContext& ctx, + const mluOpTensorDescriptor_t x_desc, + const void* x, + const mluOpTensorDescriptor_t img_size_desc, + const void* img_size, + const mluOpTensorDescriptor_t anchors_desc, + const void* anchors, + const int class_num, + const float conf_thresh, + const int downsample_ratio, + const bool clip_bbox, + const float scale, + const bool iou_aware, + const float iou_aware_factor, + const mluOpTensorDescriptor_t boxes_desc, + void* boxes, + const mluOpTensorDescriptor_t scores_desc, + void* scores) { + mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS(mluOpYoloBox(handle, + x_desc, + x, + img_size_desc, + img_size, + anchors_desc, + anchors, + class_num, + conf_thresh, + downsample_ratio, + clip_bbox, + scale, + iou_aware, + iou_aware_factor, + boxes_desc, + boxes, + scores_desc, + scores)); +} + +/* static */ void MLUOP::OpPriorBox( + const ExecutionContext& ctx, + const mluOpTensorDescriptor_t min_sizes_desc, + const void* min_sizes, + const mluOpTensorDescriptor_t aspect_ratios_desc, + const void* aspect_ratios, + const mluOpTensorDescriptor_t variances_desc, + const void* variances, + const mluOpTensorDescriptor_t max_sizes_desc, + const void* max_sizes, + const int height, + const int width, + const int im_height, + const int im_width, + const float step_h, + const float step_w, + const float offset, + const bool clip, + const bool min_max_aspect_ratios_order, + const mluOpTensorDescriptor_t output_desc, + void* output, + const mluOpTensorDescriptor_t var_desc, + void* var) { + mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS(mluOpPriorBox(handle, + min_sizes_desc, + min_sizes, + aspect_ratios_desc, + aspect_ratios, + variances_desc, + variances, + max_sizes_desc, + max_sizes, + height, + width, + im_height, + im_width, + step_h, + step_w, + offset, + clip, + min_max_aspect_ratios_order, + output_desc, + output, + var_desc, + var)); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index e56331b2728..f2c6a792ece 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include +#include #include #include @@ -138,6 +139,54 @@ inline cnnlDataType_t ToCnnlDataType() { return ToCnnlDataType(type); } +inline mluOpDataType_t ToMluOpDataType( + const paddle::experimental::DataType& dtype) { + mluOpDataType_t type = MLUOP_DTYPE_FLOAT; + switch (dtype) { + case DataType::FLOAT16: + type = MLUOP_DTYPE_HALF; + break; + case DataType::FLOAT32: + type = MLUOP_DTYPE_FLOAT; + break; + case DataType::FLOAT64: + type = MLUOP_DTYPE_DOUBLE; + break; + case DataType::INT8: + type = MLUOP_DTYPE_INT8; + break; + case DataType::INT16: + type = MLUOP_DTYPE_INT16; + break; + case DataType::INT32: + type = MLUOP_DTYPE_INT32; + break; + case DataType::INT64: + type = MLUOP_DTYPE_INT64; + break; + case DataType::BOOL: + type = MLUOP_DTYPE_BOOL; + break; + case DataType::UINT8: + type = MLUOP_DTYPE_UINT8; + break; + default: + break; + } + return type; +} + +inline mluOpDataType_t ToMluOpDataType( + const paddle::framework::proto::VarType::Type& type) { + return ToMluOpDataType(framework::TransToPhiDataType(type)); +} + +template +inline mluOpDataType_t ToMluOpDataType() { + auto type = framework::ToDataType(std::type_index(typeid(T))); + return ToMluOpDataType(type); +} + // Converts (via narrowing) a type T value to a type U, and checks that the // value has no value change due to the conversion. template @@ -152,6 +201,10 @@ inline static cnnlHandle_t GetHandleFromCTX(const ExecutionContext& ctx) { return ctx.template device_context().cnnl_handle(); } +inline static mluOpHandle_t GetMLUOpHandleFromCTX(const ExecutionContext& ctx) { + return ctx.template device_context().mluOp_handle(); +} + inline static const MLUDeviceContext& GetDevCtxFromCTX( const ExecutionContext& ctx) { return ctx.template device_context(); @@ -281,6 +334,74 @@ class MLUCnnlTensorDesc { cnnlTensorDescriptor_t raw_tensor_desc = nullptr; }; +class MLUOpTensorDesc { + public: + MLUOpTensorDesc() {} + + // SE_DISALLOW_COPY_AND_ASSIGN + MLUOpTensorDesc(const MLUOpTensorDesc& desc) = delete; + MLUOpTensorDesc& operator=(const MLUOpTensorDesc&) = delete; + + MLUOpTensorDesc(MLUOpTensorDesc&& rhs) + : raw_tensor_desc(rhs.raw_tensor_desc) { + rhs.raw_tensor_desc = nullptr; + } + + MLUOpTensorDesc& operator=(MLUOpTensorDesc&& rhs); + + MLUOpTensorDesc(const int tensor_dim, + const int dim_sizes[], + const mluOpDataType_t tensor_dtype); + + MLUOpTensorDesc(const int tensor_dim, + const int dim_sizes[], + const mluOpDataType_t tensor_dtype, + const mluOpTensorLayout_t layout); + + MLUOpTensorDesc(const int tensor_dim, + const int dim_sizes[], + const mluOpDataType_t tensor_dtype, + int position); + + MLUOpTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const mluOpDataType_t tensor_dtype); + + MLUOpTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const mluOpDataType_t tensor_dtype, + const mluOpTensorLayout_t layout); + + MLUOpTensorDesc(const int tensor_dim, + const int64_t dim_sizes[], + const mluOpDataType_t tensor_dtype, + int position); + + MLUOpTensorDesc(const Tensor& tensor, + const mluOpTensorLayout_t layout, + const mluOpDataType_t tensor_dtype); + + explicit MLUOpTensorDesc(const Tensor& tensor); + + MLUOpTensorDesc(const Tensor& tensor, + mluOpTensorLayout_t layout, + const mluOpDataType_t tensor_dtype, + int position); + + MLUOpTensorDesc(const Tensor& tensor, + mluOpTensorLayout_t layout, + const mluOpDataType_t tensor_dtype, + int position, + float scale); + + ~MLUOpTensorDesc(); + + const mluOpTensorDescriptor_t get() const { return raw_tensor_desc; } + + private: + mluOpTensorDescriptor_t raw_tensor_desc = nullptr; +}; + class MLUCnnlActivationDesc { public: MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete; @@ -1921,6 +2042,28 @@ class MLUCnnl { const cnnlTensorDescriptor_t output_desc, void* output); + static void SmoothL1LossForward(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t t_desc, + const void* target, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t y_desc, + void* y); + + static void SmoothL1LossBackward(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t target_desc, + const void* target, + const cnnlTensorDescriptor_t dy_desc, + const void* dy, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t dx_desc, + void* dx); + static void EmbeddingForward(const ExecutionContext& ctx, const int padding_idx, const cnnlTensorDescriptor_t weight_desc, @@ -2149,6 +2292,50 @@ class MLUCnnl { void* diff_x); }; +class MLUOP { + public: + static void OpYoloBox(const ExecutionContext& ctx, + const mluOpTensorDescriptor_t x_desc, + const void* x, + const mluOpTensorDescriptor_t img_size_desc, + const void* img_size, + const mluOpTensorDescriptor_t anchors_desc, + const void* anchors, + const int class_num, + const float conf_thresh, + const int downsample_ratio, + const bool clip_bbox, + const float scale, + const bool iou_aware, + const float iou_aware_factor, + const mluOpTensorDescriptor_t boxes_desc, + void* boxes, + const mluOpTensorDescriptor_t scores_desc, + void* scores); + + static void OpPriorBox(const ExecutionContext& ctx, + const mluOpTensorDescriptor_t min_sizes_desc, + const void* min_sizes, + const mluOpTensorDescriptor_t aspect_ratios_desc, + const void* aspect_ratios, + const mluOpTensorDescriptor_t variances_desc, + const void* variances, + const mluOpTensorDescriptor_t max_sizes_desc, + const void* max_sizes, + const int height, + const int width, + const int im_height, + const int im_width, + const float step_h, + const float step_w, + const float offset, + const bool clip, + const bool min_max_aspect_ratios_order, + const mluOpTensorDescriptor_t output_desc, + void* output, + const mluOpTensorDescriptor_t var_desc, + void* var); +}; const std::map, std::vector>> TransPermMap = { // trans_mode, (forward_perm, backward_perm) diff --git a/paddle/fluid/operators/one_hot_v2_op_mlu.cc b/paddle/fluid/operators/one_hot_v2_op_mlu.cc index f574cc525f1..21e8975e37d 100644 --- a/paddle/fluid/operators/one_hot_v2_op_mlu.cc +++ b/paddle/fluid/operators/one_hot_v2_op_mlu.cc @@ -97,4 +97,6 @@ class OneHotV2MLUKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_MLU_KERNEL(one_hot_v2, ops::OneHotV2MLUKernel); +REGISTER_OP_MLU_KERNEL(one_hot_v2, + ops::OneHotV2MLUKernel, + ops::OneHotV2MLUKernel); diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc index ecc527d5c72..aff468cc3c8 100644 --- a/paddle/fluid/operators/optimizers/adam_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc @@ -291,11 +291,38 @@ class AdamWMLUKernel : public AdamMLUKernel { skip_update = skip_update_vec[0]; } bool with_decay = ctx.Attr("with_decay"); + const bool multi_precision = ctx.Attr("multi_precision"); + auto* param_out = ctx.Output("ParamOut"); + auto* master_param_out = ctx.Output("MasterParamOut"); + const auto* master_param = ctx.Input("MasterParam"); + VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay; if (!skip_update && with_decay) { - if (ctx.HasInput("MasterParam")) { - PADDLE_THROW(platform::errors::Unimplemented( - "Master Param is not supported on MLU")); + auto* param = ctx.Input("Param"); + MLUCnnlTensorDesc param_desc(*param); + if (multi_precision) { + VLOG(3) << "[adamw] multi_precision, cast masterparam to param."; + bool has_master = + ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + PADDLE_ENFORCE_EQ( + has_master, + true, + platform::errors::InvalidArgument( + "The Input(MasterParam) and Output(MasterParamOut) " + "should not be null when " + "the attr `multi_precision` is true")); + // cast masterparam (fp32) to param (fp16), then paramout (fp16) to + // masterparamout (fp32) + MLUCnnlTensorDesc master_param_desc(*master_param); + cnnlCastDataType_t cast_type = GetCastDataType( + framework::TransToProtoVarType(master_param->dtype()), + framework::TransToProtoVarType(param->dtype())); + MLUCnnl::Cast(ctx, + cast_type, + master_param_desc.get(), + GetBasePtr(master_param), + param_desc.get(), + const_cast(GetBasePtr(param))); } else { const auto* param_var = ctx.InputVar("Param"); PADDLE_ENFORCE_EQ(param_var->IsType(), @@ -305,13 +332,12 @@ class AdamWMLUKernel : public AdamMLUKernel { "but the received is %s", ctx.InputNames("Param").front(), framework::ToTypeName(param_var->Type()))); - auto* param = ctx.Input("Param"); + auto* lr = ctx.Input("LearningRate"); float coeff = ctx.Attr("coeff"); // update param with decay coeff: mul(-1 * lr, coeff * param) + param MLUCnnlTensorDesc lr_desc(*lr); - MLUCnnlTensorDesc param_desc(*param); MLUCnnlOpTensorDesc mul_op_desc( CNNL_OP_TENSOR_MUL, ToCnnlDataType(), CNNL_NOT_PROPAGATE_NAN); @@ -330,9 +356,244 @@ class AdamWMLUKernel : public AdamMLUKernel { } } AdamMLUKernel::Compute(ctx); + if (multi_precision) { + VLOG(3) << "[adamw] multi_precision, cast paramout to masterparamout."; + // cast paramout to masterparamout + master_param_out->mutable_data(ctx.GetPlace()); + cnnlCastDataType_t cast_type = GetCastDataType( + framework::TransToProtoVarType(param_out->dtype()), + framework::TransToProtoVarType(master_param_out->dtype())); + MLUCnnlTensorDesc param_out_desc(*param_out); + MLUCnnlTensorDesc master_param_out_desc(*master_param_out); + + MLUCnnl::Cast(ctx, + cast_type, + param_out_desc.get(), + GetBasePtr(param_out), + master_param_out_desc.get(), + GetBasePtr(master_param_out)); + } } }; +template +class MergedAdamMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // Get inputs and outputs + auto params = ctx.MultiInput("Param"); + auto grads = ctx.MultiInput("Grad"); + auto lrs = ctx.MultiInput("LearningRate"); + auto mom1s = ctx.MultiInput("Moment1"); + auto mom2s = ctx.MultiInput("Moment2"); + auto beta1_pows = ctx.MultiInput("Beta1Pow"); + auto beta2_pows = ctx.MultiInput("Beta2Pow"); + auto master_params = ctx.MultiInput("MasterParam"); + auto param_outs = ctx.MultiOutput("ParamOut"); + auto mom1_outs = ctx.MultiOutput("Moment1Out"); + auto mom2_outs = ctx.MultiOutput("Moment2Out"); + auto beta1_pow_outs = ctx.MultiOutput("Beta1PowOut"); + auto beta2_pow_outs = ctx.MultiOutput("Beta2PowOut"); + + // Check validation of inputs and outputs + size_t param_num = params.size(); + PADDLE_ENFORCE_EQ(param_num, + param_outs.size(), + platform::errors::InvalidArgument( + "The size of Output(ParamOut) must be equal to " + "Input(Param), but got the size of Output(ParamOut) " + "is %d, the size of Input(Param) is %d.", + param_outs.size(), + param_num)); + + bool skip_update = false; + if (ctx.HasInput("SkipUpdate")) { + auto* skip_update_tensor = ctx.Input("SkipUpdate"); + PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), + 1, + platform::errors::InvalidArgument( + "Input(SkipUpdate) size must be 1, but get %d", + skip_update_tensor->numel())); + std::vector skip_update_vec; + paddle::framework::TensorToVector( + *skip_update_tensor, ctx.device_context(), &skip_update_vec); + ctx.device_context().Wait(); + skip_update = skip_update_vec[0]; + } + // skip_update=true, just copy input to output, and TensorCopy will call + // mutable_data + + if (skip_update) { + VLOG(4) << "MergedAdam skip update"; + for (size_t i = 0; i < param_num; ++i) { + framework::TensorCopy( + *params[i], + ctx.GetPlace(), + ctx.template device_context(), + param_outs[i]); + framework::TensorCopy( + *mom1s[i], + ctx.GetPlace(), + ctx.template device_context(), + mom1_outs[i]); + framework::TensorCopy( + *mom2s[i], + ctx.GetPlace(), + ctx.template device_context(), + mom2_outs[i]); + framework::TensorCopy( + *beta1_pows[i], + beta1_pows[i]->place(), + ctx.template device_context(), + beta1_pow_outs[i]); + framework::TensorCopy( + *beta2_pows[i], + beta2_pows[i]->place(), + ctx.template device_context(), + beta2_pow_outs[i]); + } + return; + } + + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + // Get beta1, beta2 and epsilon from attribute. + const Tensor* beta1_tensor = nullptr; + const Tensor* beta2_tensor = nullptr; + const Tensor* epsilon_tensor = nullptr; + + Tensor beta1_tmp(experimental::DataType::FLOAT32); + Tensor beta2_tmp(experimental::DataType::FLOAT32); + Tensor epsilon_tmp(experimental::DataType::FLOAT32); + + T beta1 = static_cast(ctx.Attr("beta1")); + T beta2 = static_cast(ctx.Attr("beta2")); + T epsilon = static_cast(ctx.Attr("epsilon")); + beta1_tmp.mutable_data({1}, ctx.GetPlace()); + beta2_tmp.mutable_data({1}, ctx.GetPlace()); + epsilon_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta1_tmp_desc(beta1_tmp); + MLUCnnlTensorDesc beta2_tmp_desc(beta2_tmp); + MLUCnnlTensorDesc epsilon_tmp_desc(epsilon_tmp); + MLUCnnl::Fill(ctx, + CNNL_POINTER_MODE_HOST, + &beta1, + beta1_tmp_desc.get(), + GetBasePtr(&beta1_tmp)); + MLUCnnl::Fill(ctx, + CNNL_POINTER_MODE_HOST, + &beta2, + beta2_tmp_desc.get(), + GetBasePtr(&beta2_tmp)); + MLUCnnl::Fill(ctx, + CNNL_POINTER_MODE_HOST, + &epsilon, + epsilon_tmp_desc.get(), + GetBasePtr(&epsilon_tmp)); + beta1_tensor = &beta1_tmp; + beta2_tensor = &beta2_tmp; + epsilon_tensor = &epsilon_tmp; + + // Loop to compute + for (size_t i = 0; i < param_num; ++i) { + VLOG(4) << "[MergedAdam] loop: " << i; + param_outs[i]->ShareDataWith(*params[i]); + mom1_outs[i]->ShareDataWith(*mom1s[i]); + mom2_outs[i]->ShareDataWith(*mom2s[i]); + + LoDTensor beta1_pow_tmp; + LoDTensor beta2_pow_tmp; + if (beta1_pows[i]->place() == platform::CPUPlace()) { + T beta1 = *beta1_pows[i]->data(); + beta1_pow_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta1_pow_tmp_desc(beta1_pow_tmp); + MLUCnnl::Fill(ctx, + CNNL_POINTER_MODE_HOST, + &beta1, + beta1_pow_tmp_desc.get(), + GetBasePtr(&beta1_pow_tmp)); + beta1_pows[i] = &beta1_pow_tmp; + } + if (beta2_pows[i]->place() == platform::CPUPlace()) { + T beta2 = *beta2_pows[i]->data(); + beta2_pow_tmp.mutable_data({1}, ctx.GetPlace()); + MLUCnnlTensorDesc beta2_pow_tmp_desc(beta2_pow_tmp); + MLUCnnl::Fill(ctx, + CNNL_POINTER_MODE_HOST, + &beta2, + beta2_pow_tmp_desc.get(), + GetBasePtr(&beta2_pow_tmp)); + beta2_pows[i] = &beta2_pow_tmp; + } + + VLOG(3) << "beta1_pow.numel() : " << beta1_pows[i]->numel() + << "beta2_pow.numel() : " << beta2_pows[i]->numel(); + VLOG(3) << "param.numel(): " << params[i]->numel(); + PADDLE_ENFORCE_EQ(beta1_pow_outs[i]->numel(), + 1, + platform::errors::InvalidArgument( + "beta1 pow output size should be 1, but received " + "value is:%d.", + beta1_pow_outs[i]->numel())); + + PADDLE_ENFORCE_EQ(beta2_pow_outs[i]->numel(), + 1, + platform::errors::InvalidArgument( + "beta2 pow output size should be 1, but received " + "value is:%d.", + beta2_pow_outs[i]->numel())); + MLUCnnlTensorDesc param_desc(*params[i]); + MLUCnnlTensorDesc mom1_desc(*mom1s[i]); + MLUCnnlTensorDesc mom2_desc(*mom2s[i]); + MLUCnnlTensorDesc grad_desc(*grads[i]); + MLUCnnl::ApplyAdam(ctx, + param_desc.get(), + GetBasePtr(param_outs[i]), + mom1_desc.get(), + GetBasePtr(mom1_outs[i]), + mom2_desc.get(), + GetBasePtr(mom2_outs[i]), + grad_desc.get(), + GetBasePtr(grads[i]), + GetBasePtr(lrs[i]), + GetBasePtr(beta1_tensor), + GetBasePtr(beta2_tensor), + GetBasePtr(beta1_pows[i]), + GetBasePtr(beta2_pows[i]), + GetBasePtr(epsilon_tensor), + /*use_nesterov*/ false); + if (!use_global_beta_pow) { + beta1_pow_outs[i]->mutable_data(ctx.GetPlace()); + beta2_pow_outs[i]->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc beta1_desc(*beta1_tensor); + MLUCnnlOpTensorDesc mul_op_desc( + CNNL_OP_TENSOR_MUL, ToCnnlDataType(), CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(ctx, + mul_op_desc.get(), + beta1_desc.get(), + GetBasePtr(beta1_pows[i]), + beta1_desc.get(), + GetBasePtr(beta1_tensor), + beta1_desc.get(), + GetBasePtr(beta1_pow_outs[i]), + ToCnnlDataType()); + + MLUCnnl::OpTensor(ctx, + mul_op_desc.get(), + beta1_desc.get(), + GetBasePtr(beta2_pows[i]), + beta1_desc.get(), + GetBasePtr(beta2_tensor), + beta1_desc.get(), + GetBasePtr(beta2_pow_outs[i]), + ToCnnlDataType()); + } + } + } +}; } // namespace operators } // namespace paddle @@ -346,3 +607,7 @@ REGISTER_OP_MLU_KERNEL(adam, REGISTER_OP_MLU_KERNEL(adamw, ops::AdamWMLUKernel, ops::AdamWMLUKernel); + +REGISTER_OP_MLU_KERNEL(merged_adam, + ops::MergedAdamMLUKernel, + ops::MergedAdamMLUKernel); diff --git a/paddle/fluid/operators/pool_op_mlu.cc b/paddle/fluid/operators/pool_op_mlu.cc index 5eaf8dbff88..988eb182a16 100644 --- a/paddle/fluid/operators/pool_op_mlu.cc +++ b/paddle/fluid/operators/pool_op_mlu.cc @@ -141,10 +141,9 @@ class MLUPoolOpKernel : public framework::OpKernel { handle, pool_mode, out_w, out_h, &extra_input_size); if (extra_input_size > 0) { - phi::CPUContext cpu_ctx; - framework::Tensor extra_host_tensor = - ctx.AllocateTmpTensor( - {static_cast(extra_input_size)}, cpu_ctx); + framework::Tensor extra_host_tensor; + extra_host_tensor.mutable_data( + {static_cast(extra_input_size)}, platform::CPUPlace()); cnnlInitPoolingExtraInput(handle, pool_desc.get(), trans_in_x_desc.get(), diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc index 310c1db205d..75b0c1f16de 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op_mlu.cc @@ -92,6 +92,112 @@ class ReduceMaxMLUKernel : public framework::OpKernel { } }; +template +class ReduceMaxGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Input("Out"); + auto* out_grad = context.Input(framework::GradVarName("Out")); + auto reduce_dims = context.Attr>("dim"); + bool reduce_all = context.Attr("reduce_all"); + int in_dtype = context.Attr("in_dtype"); + + PADDLE_ENFORCE_EQ( + in_dtype == -1, + true, + platform::errors::InvalidArgument( + "MLU only support in_dtype == -1 in reduce_max_grad op.")); + auto* x_grad = context.Output(framework::GradVarName("X")); + x_grad->mutable_data(context.GetPlace()); + + auto place = context.GetPlace(); + + // broadcast + auto x_dims_vec = phi::vectorize(x->dims()); + if (reduce_all) { + reduce_dims.clear(); + for (size_t d = 0; d < x_dims_vec.size(); ++d) { + reduce_dims.push_back(static_cast(d)); + } + } + + Tensor tmp_out, tmp_out_grad; + auto tmp_out_dims_vec = x_dims_vec; + for (auto d : reduce_dims) { + if (d < 0) { + d += x_dims_vec.size(); + } + tmp_out_dims_vec[d] = 1; + } + + tmp_out.ShareDataWith(*out); + tmp_out.Resize(phi::make_ddim(tmp_out_dims_vec)); + tmp_out_grad.ShareDataWith(*out_grad); + tmp_out_grad.Resize(phi::make_ddim(tmp_out_dims_vec)); + + Tensor transformed_out(x->type()); + transformed_out.Resize(phi::make_ddim(x_dims_vec)); + transformed_out.mutable_data(place); + + MLUCnnlTensorDesc tmp_out_desc(tmp_out); + MLUCnnlTensorDesc transformed_out_desc(transformed_out); + + MLUCnnl::BroadcastTo(context, + tmp_out_desc.get(), + GetBasePtr(&tmp_out), + transformed_out_desc.get(), + GetBasePtr(&transformed_out)); + + Tensor transformed_out_grad(x->type()); + transformed_out_grad.Resize(phi::make_ddim(x_dims_vec)); + transformed_out_grad.mutable_data(place); + MLUCnnlTensorDesc tmp_out_grad_desc(tmp_out_grad); + MLUCnnlTensorDesc transformed_out_grad_desc(transformed_out_grad); + + MLUCnnl::BroadcastTo(context, + tmp_out_grad_desc.get(), + GetBasePtr(&tmp_out_grad), + transformed_out_grad_desc.get(), + GetBasePtr(&transformed_out_grad)); + + // compare + Tensor equal_cond; + equal_cond.mutable_data(x_grad->dims(), place); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc equal_cond_desc(equal_cond); + + MLUCnnl::Logic(context, + CNNL_LOGIC_OP_EQ, + x_desc.get(), + GetBasePtr(x), + transformed_out_desc.get(), + GetBasePtr(&transformed_out), + equal_cond_desc.get(), + GetBasePtr(&equal_cond)); + + // select + Tensor t_zero; + t_zero.mutable_data(x_grad->dims(), place); + FillMLUTensorWithHostValue(context, static_cast(0), &t_zero); + t_zero.Resize(x_grad->dims()); + + MLUCnnlTensorDesc t_zero_desc(t_zero); + MLUCnnlTensorDesc x_grad_desc(*x_grad); + + MLUCnnl::Select(context, + equal_cond_desc.get(), + GetBasePtr(&equal_cond), + transformed_out_grad_desc.get(), + GetBasePtr(&transformed_out_grad), + t_zero_desc.get(), + GetBasePtr(&t_zero), + x_grad_desc.get(), + GetBasePtr(x_grad)); + } +}; + } // namespace operators } // namespace paddle @@ -102,3 +208,7 @@ REGISTER_OP_MLU_KERNEL(reduce_max, ops::ReduceMaxMLUKernel, ops::ReduceMaxMLUKernel, ops::ReduceMaxMLUKernel); +REGISTER_OP_MLU_KERNEL(reduce_max_grad, + ops::ReduceMaxGradMLUKernel, + ops::ReduceMaxGradMLUKernel, + ops::ReduceMaxGradMLUKernel); diff --git a/paddle/fluid/operators/strided_slice_op_mlu.cc b/paddle/fluid/operators/strided_slice_op_mlu.cc index 95972d81592..81d5b9089a9 100644 --- a/paddle/fluid/operators/strided_slice_op_mlu.cc +++ b/paddle/fluid/operators/strided_slice_op_mlu.cc @@ -19,6 +19,11 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = phi::DenseTensor; +using Variable = framework::Variable; +using LoDTensorArray = framework::LoDTensorArray; +using DDim = framework::DDim; + static void ProcessStridedSliceParams( const std::vector& axes, const DDim& input_dims, diff --git a/paddle/fluid/platform/device/mlu/device_context.cc b/paddle/fluid/platform/device/mlu/device_context.cc index 087b4803320..796d7006834 100644 --- a/paddle/fluid/platform/device/mlu/device_context.cc +++ b/paddle/fluid/platform/device/mlu/device_context.cc @@ -28,11 +28,13 @@ MLUContext::MLUContext(const MLUPlace& place, const int priority) { MLUDeviceGuard guard(place_.device); stream_.reset(new stream::MLUStream(place_, priority)); InitCNNLContext(); + InitMLUOPContext(); } MLUContext::~MLUContext() { MLUDeviceGuard guard(place_.device); DestoryCNNLContext(); + DestoryMLUOPContext(); } MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { @@ -41,6 +43,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { driver_version_ = GetMLUDriverVersion(place_.device); runtime_version_ = GetMLURuntimeVersion(place_.device); cnnl_version_ = GetMLUCnnlVersion(place_.device); + mluOp_version_ = GetMLUOpVersion(place_.device); LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << static_cast(place_.device) @@ -51,7 +54,9 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { << ", Runtime API Version: " << runtime_version_ / 10000 << "." << (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100 << ", Cnnl API Version: " << cnnl_version_ / 10000 << "." - << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100; + << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100 + << ", MluOp API Version: " << mluOp_version_ / 10000 << "." + << (mluOp_version_ / 100) % 100 << "." << mluOp_version_ % 100; default_ctx_.reset(new MLUContext(place_)); } @@ -70,6 +75,10 @@ mluCnnlHandle MLUDeviceContext::cnnl_handle() const { return context()->CnnlHandle(); } +mluOpHandle MLUDeviceContext::mluOp_handle() const { + return context()->MluOpHandle(); +} + mluStream MLUDeviceContext::stream() const { return context()->RawStream(); } #endif diff --git a/paddle/fluid/platform/device/mlu/device_context.h b/paddle/fluid/platform/device/mlu/device_context.h index d8bb7623159..e1028667bc2 100644 --- a/paddle/fluid/platform/device/mlu/device_context.h +++ b/paddle/fluid/platform/device/mlu/device_context.h @@ -53,12 +53,19 @@ class MLUContext { const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; } + const mluOpHandle& MluOpHandle() const { return mluOp_handle_; } + private: void InitCNNLContext() { PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_)); PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream())); } + void InitMLUOPContext() { + PADDLE_ENFORCE_MLU_SUCCESS(mluOpCreate(&mluOp_handle_)); + PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetQueue(mluOp_handle_, RawStream())); + } + void DestoryCNNLContext() { if (cnnl_handle_) { PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_)); @@ -66,10 +73,18 @@ class MLUContext { cnnl_handle_ = nullptr; } + void DestoryMLUOPContext() { + if (mluOp_handle_) { + PADDLE_ENFORCE_MLU_SUCCESS(mluOpDestroy(mluOp_handle_)); + } + mluOp_handle_ = nullptr; + } + MLUPlace place_; std::unique_ptr eigen_device_; std::unique_ptr stream_; mluCnnlHandle cnnl_handle_; + mluOpHandle mluOp_handle_; DISABLE_COPY_AND_ASSIGN(MLUContext); }; @@ -89,6 +104,9 @@ class MLUDeviceContext : public DeviceContext { /*! \brief Return cnnl handle in the device context. */ mluCnnlHandle cnnl_handle() const; + /*! \brief Return mluOp handle in the device context. */ + mluOpHandle mluOp_handle() const; + /*! \brief Return mlu stream in the device context. */ mluStream stream() const; @@ -135,6 +153,7 @@ class MLUDeviceContext : public DeviceContext { int driver_version_; int runtime_version_; int cnnl_version_; + int mluOp_version_; MLUPlace place_; std::shared_ptr default_ctx_; diff --git a/paddle/fluid/platform/device/mlu/enforce.h b/paddle/fluid/platform/device/mlu/enforce.h index 05327a771d8..8b0d0bb36f5 100644 --- a/paddle/fluid/platform/device/mlu/enforce.h +++ b/paddle/fluid/platform/device/mlu/enforce.h @@ -41,6 +41,7 @@ struct MLUStatusType {}; DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT); DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL); +DEFINE_MLU_STATUS_TYPE(mluOpStatus, MLUOP_STATUS_SUCCESS, MLUOP); DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN); #ifdef PADDLE_WITH_CNCL DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL); @@ -68,6 +69,15 @@ inline std::string build_mlu_error_msg(cnnlStatus stat) { return sout.str(); } +/*************** MLU OP ERROR ***************/ +inline bool is_error(mluOpStatus stat) { return stat != MLUOP_STATUS_SUCCESS; } + +inline std::string build_mlu_error_msg(mluOpStatus stat) { + std::ostringstream sout; + sout << "MLU OP error(" << stat << "), " << mluOpGetErrorString(stat) << ". "; + return sout.str(); +} + /*************** CN API ERROR ***************/ inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; } diff --git a/paddle/fluid/platform/device/mlu/mlu_info.cc b/paddle/fluid/platform/device/mlu/mlu_info.cc index e27720849e0..a2e063397bd 100644 --- a/paddle/fluid/platform/device/mlu/mlu_info.cc +++ b/paddle/fluid/platform/device/mlu/mlu_info.cc @@ -126,6 +126,13 @@ int GetMLUCnnlVersion(int id) { return x * 10000 + y * 100 + z; } +int GetMLUOpVersion(int id) { + CheckDeviceId(id); + int x, y, z; + mluOpGetLibVersion(&x, &y, &z); + return x * 10000 + y * 100 + z; +} + int GetMLUCurrentDeviceId() { int device_id; PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id)); diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h index 14f37879ef0..c0cd24f00fb 100644 --- a/paddle/fluid/platform/device/mlu/mlu_info.h +++ b/paddle/fluid/platform/device/mlu/mlu_info.h @@ -16,10 +16,11 @@ limitations under the License. */ #ifdef PADDLE_WITH_MLU #include -#include #include #include +#include #include +#include #ifdef PADDLE_WITH_CNCL #include #endif @@ -30,11 +31,13 @@ namespace paddle { using cnStatus = CNresult; using cnrtStatus = cnrtRet_t; using cnnlStatus = cnnlStatus_t; +using mluOpStatus = mluOpStatus_t; #ifdef PADDLE_WITH_CNCL using cnclStatus = cnclResult_t; #endif using mluStream = cnrtQueue_t; using mluCnnlHandle = cnnlHandle_t; +using mluOpHandle = mluOpHandle_t; using mluEventHandle = cnrtNotifier_t; using mluDeviceHandle = CNdev; @@ -49,6 +52,9 @@ int GetMLURuntimeVersion(int id); //! Get the cnnl version of the ith MLU. int GetMLUCnnlVersion(int id); +//! Get the mluOp version of the ith MLU. +int GetMLUOpVersion(int id); + //! Get the total number of MLU devices in system. int GetMLUDeviceCount(); diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index 5957c4c24ca..72fb647a04e 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -29,7 +29,10 @@ #include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h" #include "paddle/fluid/platform/profiler/extra_info.h" #include "paddle/fluid/platform/profiler/host_tracer.h" +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/enforce.h" #include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h" +#endif #include "paddle/fluid/platform/profiler/trace_event_collector.h" #include "paddle/fluid/platform/profiler/utils.h" @@ -80,9 +83,11 @@ Profiler::Profiler(const ProfilerOptions& options, if (trace_switch.test(kProfileGPUOptionBit)) { tracers_.emplace_back(&CudaTracer::GetInstance(), false); } +#ifdef PADDLE_WITH_MLU if (trace_switch.test(kProfileMLUOptionBit)) { tracers_.emplace_back(&MluTracer::GetInstance(), false); } +#endif if (trace_switch.test(kProfileCustomDeviceOptionBit)) { for (const auto& dev_type : custom_device_types) { tracers_.emplace_back(&CustomTracer::GetInstance(dev_type), false); diff --git a/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py index 5c1b8b602f2..e7c50ec2880 100644 --- a/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py @@ -34,7 +34,10 @@ import unittest from multiprocessing import Process import paddle.fluid.layers as layers from functools import reduce -from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main +from test_sync_batch_norm_base_mlu import ( + TestSyncBatchNormRunnerBase, + runtime_main, +) from op_test import OpTest, _set_use_system_allocator from test_sync_batch_norm_op import create_or_get_tensor @@ -44,11 +47,11 @@ paddle.enable_static() class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): - def __init__(self): self.global_ring_id = 0 self.dtype = np.float32 + self.bn_dtype = np.float32 self.N = 8 self.C = 16 self.H = 32 @@ -56,29 +59,36 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): self.dshape = [self.N, self.C, self.H, self.W] self.atol = 1e-3 - def get_model(self, - main, - startup, - place, - layout, - seed, - sync_bn=False, - only_forward=False): + def get_model( + self, + main, + startup, + place, + layout, + seed, + sync_bn=False, + only_forward=False, + ): """Build program.""" use_cudnn = False with fluid.unique_name.guard(): with fluid.program_guard(main, startup): - data = fluid.layers.data(name='input', - shape=self.dshape, - dtype=self.dtype, - append_batch_size=False) + data = fluid.layers.data( + name='input', + shape=self.dshape, + dtype=self.dtype, + append_batch_size=False, + ) conv = fluid.layers.conv2d( input=data, num_filters=32, filter_size=1, param_attr=fluid.ParamAttr(name='conv2d_weight'), bias_attr=False, - use_cudnn=use_cudnn) + use_cudnn=use_cudnn, + ) + if self.bn_dtype == np.float16: + conv = fluid.layers.cast(conv, 'float16') bn = fluid.layers.batch_norm( conv, param_attr=fluid.ParamAttr(name='bn_scale'), @@ -86,9 +96,10 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): moving_mean_name='bn_moving_mean', moving_variance_name='bn_moving_variance', data_layout=layout, - is_test=only_forward) - # if self.dtype == np.float16: - # bn = fluid.layers.cast(bn, 'float32') + is_test=only_forward, + ) + if self.bn_dtype == np.float16: + bn = fluid.layers.cast(bn, 'float32') sigmoid = fluid.layers.sigmoid(bn) out = fluid.layers.reduce_sum(sigmoid) # if not sync_bn: diff --git a/python/paddle/fluid/tests/unittests/mlu/test_collective_api_base_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_collective_api_base_mlu.py index 1b3ce961115..5ee74790756 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_collective_api_base_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_collective_api_base_mlu.py @@ -41,10 +41,10 @@ def DataTypeCast(date_type): class TestCollectiveAPIRunnerBase(object): - def get_model(self, train_prog, startup_prog, rank, indata=None): raise NotImplementedError( - "get model should be implemented by child class.") + "get model should be implemented by child class." + ) def run_trainer(self, args): train_prog = fluid.Program() @@ -66,12 +66,12 @@ class TestCollectiveAPIRunnerBase(object): fetch_list = [] for elem in result: fetch_list.append(elem.name) - out = exe.run(train_prog, - feed={'tindata': indata}, - fetch_list=fetch_list) + out = exe.run( + train_prog, feed={'tindata': indata}, fetch_list=fetch_list + ) else: out = self.get_model(train_prog, startup_prog, rank, indata) - #print(out, sys.stderr) + # print(out, sys.stderr) sys.stdout.buffer.write(pickle.dumps(out)) @@ -96,19 +96,20 @@ from contextlib import closing class TestDistBase(unittest.TestCase): - def setUp(self): self._port_set = set() self._trainers = 2 self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( - self._find_free_port(), self._find_free_port()) + self._find_free_port(), + self._find_free_port(), + ) self._python_interp = sys.executable def _find_free_port(self): - def __free_port(): - with closing(socket.socket(socket.AF_INET, - socket.SOCK_STREAM)) as s: + with closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as s: s.bind(('', 0)) return s.getsockname()[1] @@ -121,13 +122,13 @@ class TestDistBase(unittest.TestCase): def _run_cluster(self, model_file, envs): worker_endpoints = self._ps_endpoints.split(",") w0_ep, w1_ep = worker_endpoints - #print("w0_ep:",w0_ep," w1_ep:",w1_ep) + # print("w0_ep:",w0_ep," w1_ep:",w1_ep) env0 = { "FLAGS_selected_mlus": "0", "PADDLE_TRAINER_ID": "0", "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, - "PADDLE_CURRENT_ENDPOINT": w0_ep + "PADDLE_CURRENT_ENDPOINT": w0_ep, } env1 = { @@ -135,9 +136,9 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINER_ID": "1", "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, - "PADDLE_CURRENT_ENDPOINT": w1_ep + "PADDLE_CURRENT_ENDPOINT": w1_ep, } - #update environment + # update environment env0.update(envs) env1.update(envs) if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': @@ -148,16 +149,20 @@ class TestDistBase(unittest.TestCase): tr1_cmd = tr_cmd % (self._python_interp, model_file) tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w") tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w") - #print(tr0_cmd) - tr0_proc = subprocess.Popen(tr0_cmd.strip().split(), - stdout=subprocess.PIPE, - stderr=tr0_pipe, - env=env0) - - tr1_proc = subprocess.Popen(tr0_cmd.strip().split(), - stdout=subprocess.PIPE, - stderr=tr1_pipe, - env=env1) + # print(tr0_cmd) + tr0_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr0_pipe, + env=env0, + ) + + tr1_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr1_pipe, + env=env1, + ) tr0_out, tr0_err = tr0_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate() @@ -170,17 +175,23 @@ class TestDistBase(unittest.TestCase): sys.stderr.write('trainer 0 stderr file: %s\n' % f.read()) with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f: sys.stderr.write('trainer 1 stderr file: %s\n' % f.read()) - return pickle.loads(tr0_out), pickle.loads( - tr1_out), tr0_proc.pid, tr1_proc.pid - - def check_with_place(self, - model_file, - col_type, - data_type, - path_id="0", - static_mode="1", - check_error_log=False, - need_envs={}): + return ( + pickle.loads(tr0_out), + pickle.loads(tr1_out), + tr0_proc.pid, + tr1_proc.pid, + ) + + def check_with_place( + self, + model_file, + col_type, + data_type, + path_id="0", + static_mode="1", + check_error_log=False, + need_envs={}, + ): required_envs = { "FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_eager_delete_tensor_gb": "0.0", @@ -194,7 +205,7 @@ class TestDistBase(unittest.TestCase): "PADDLE_WITH_GLOO": '0', "BACKEND": "cncl", "PATH_ID": path_id, - "DATA_TYPE": data_type + "DATA_TYPE": data_type, } required_envs.update(need_envs) if check_error_log: @@ -202,7 +213,8 @@ class TestDistBase(unittest.TestCase): required_envs["GLOG_logtostderr"] = "1" required_envs["GLOO_LOG_LEVEL"] = "TRACE" tr0_out, tr1_out, pid0, pid1 = self._run_cluster( - model_file, required_envs) + model_file, required_envs + ) np_data_type = DataTypeCast(data_type) np.random.seed(pid0) input1 = np.random.random((10, 1000)).astype(np_data_type) @@ -210,21 +222,19 @@ class TestDistBase(unittest.TestCase): input2 = np.random.random((10, 1000)).astype(np_data_type) if col_type == "broadcast": need_result = input2 - np.testing.assert_allclose(tr0_out, need_result) - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr0_out[0], need_result) + np.testing.assert_allclose(tr1_out[0], need_result) elif col_type == "allreduce": need_result = input1 + input2 - np.testing.assert_allclose(tr0_out, - need_result, - rtol=1e-05, - atol=1e-05) - np.testing.assert_allclose(tr1_out, - need_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + tr0_out[0], need_result, rtol=1e-05, atol=1e-05 + ) + np.testing.assert_allclose( + tr1_out[0], need_result, rtol=1e-05, atol=1e-05 + ) elif col_type == "reduce": need_result = input1 + input2 - np.testing.assert_allclose(tr0_out, need_result) + np.testing.assert_allclose(tr0_out[0], need_result) elif col_type == "allgather": need_result = np.vstack((input1, input2)) tr_out0 = np.vstack((tr0_out[0], tr0_out[1])) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py index 47fb3a1a230..cb6444056b0 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_collective_base_mlu.py @@ -53,10 +53,10 @@ def DataTypeCast(date_type): class TestCollectiveRunnerBase(object): - def get_model(self, train_prog, startup_prog, col_type): raise NotImplementedError( - "get model should be implemented by child class.") + "get model should be implemented by child class." + ) def wait_server_ready(self, endpoints): while True: @@ -64,13 +64,15 @@ class TestCollectiveRunnerBase(object): not_ready_endpoints = [] for ep in endpoints: ip_port = ep.split(":") - with closing(socket.socket(socket.AF_INET, - socket.SOCK_STREAM)) as sock: + with closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as sock: sock.settimeout(2) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(socket, 'SO_REUSEPORT'): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, - 1) + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT, 1 + ) result = sock.connect_ex((ip_port[0], int(ip_port[1]))) if result != 0: @@ -78,44 +80,51 @@ class TestCollectiveRunnerBase(object): not_ready_endpoints.append(ep) if not all_ok: sys.stderr.write("server not ready, wait 3 sec to retry...\n") - sys.stderr.write("not ready endpoints:" + - str(not_ready_endpoints) + "\n") + sys.stderr.write( + "not ready endpoints:" + str(not_ready_endpoints) + "\n" + ) sys.stderr.flush() time.sleep(3) else: break + # endpoints should be ["ip1:port1","ip2:port2"] -#endpoints should be ["ip1:port1","ip2:port2"] - - def initCommunicator(self, program, rank, nranks, wait_port, - current_endpoint, endpoints): + def initCommunicator( + self, program, rank, nranks, wait_port, current_endpoint, endpoints + ): other_endpoints = endpoints[:] other_endpoints.remove(current_endpoint) if rank == 0 and wait_port: self.wait_server_ready(other_endpoints) block = program.global_block() - cncl_id_var = block.create_var(name=nameGen.generate('cncl_id'), - persistable=True, - type=core.VarDesc.VarType.RAW) - - block.append_op(type='c_gen_cncl_id', - inputs={}, - outputs={'Out': cncl_id_var}, - attrs={ - 'rank': rank, - 'endpoint': current_endpoint, - 'other_endpoints': other_endpoints - }) - - block.append_op(type='c_comm_init', - inputs={'X': cncl_id_var}, - outputs={}, - attrs={ - 'nranks': nranks, - 'rank': rank, - 'ring_id': self.global_ring_id - }) + cncl_id_var = block.create_var( + name=nameGen.generate('cncl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW, + ) + + block.append_op( + type='c_gen_cncl_id', + inputs={}, + outputs={'Out': cncl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + }, + ) + + block.append_op( + type='c_comm_init', + inputs={'X': cncl_id_var}, + outputs={}, + attrs={ + 'nranks': nranks, + 'rank': rank, + 'ring_id': self.global_ring_id, + }, + ) def run_trainer(self, args): train_prog = fluid.Program() @@ -124,8 +133,9 @@ class TestCollectiveRunnerBase(object): rank = args["trainerid"] current_endpoint = args["currentendpoint"] nranks = 2 - self.initCommunicator(startup_prog, rank, nranks, True, - current_endpoint, endpoints) + self.initCommunicator( + startup_prog, rank, nranks, True, current_endpoint, endpoints + ) self.rank = rank result = self.get_model(train_prog, startup_prog, args["col_type"]) device_id = int(os.getenv("FLAGS_selected_mlus", "0")) @@ -135,9 +145,9 @@ class TestCollectiveRunnerBase(object): np.random.seed(os.getpid()) np_data_type = DataTypeCast(args["data_type"]) indata = np.random.random((10, 1000)).astype(np_data_type) - out = exe.run(train_prog, - feed={'tindata': indata}, - fetch_list=[result.name]) + out = exe.run( + train_prog, feed={'tindata': indata}, fetch_list=[result.name] + ) sys.stdout.buffer.write(pickle.dumps(out)) @@ -160,19 +170,20 @@ from contextlib import closing class TestDistBase(unittest.TestCase): - def setUp(self): self._port_set = set() self._trainers = 2 self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( - self._find_free_port(), self._find_free_port()) + self._find_free_port(), + self._find_free_port(), + ) self._python_interp = sys.executable def _find_free_port(self): - def __free_port(): - with closing(socket.socket(socket.AF_INET, - socket.SOCK_STREAM)) as s: + with closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as s: s.bind(('', 0)) return s.getsockname()[1] @@ -191,7 +202,7 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINER_ID": "0", "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, - "PADDLE_CURRENT_ENDPOINT": w0_ep + "PADDLE_CURRENT_ENDPOINT": w0_ep, } env1 = { @@ -199,9 +210,9 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINER_ID": "1", "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, - "PADDLE_CURRENT_ENDPOINT": w1_ep + "PADDLE_CURRENT_ENDPOINT": w1_ep, } - #update environment + # update environment env0.update(envs) env1.update(envs) tr_cmd = "%s %s" @@ -210,15 +221,19 @@ class TestDistBase(unittest.TestCase): tr0_pipe = open("/tmp/tr0_err.log", "wb") tr1_pipe = open("/tmp/tr1_err.log", "wb") - tr0_proc = subprocess.Popen(tr0_cmd.strip().split(), - stdout=subprocess.PIPE, - stderr=tr0_pipe, - env=env0) + tr0_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr0_pipe, + env=env0, + ) - tr1_proc = subprocess.Popen(tr0_cmd.strip().split(), - stdout=subprocess.PIPE, - stderr=tr1_pipe, - env=env1) + tr1_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr1_pipe, + env=env1, + ) tr0_out, tr0_err = tr0_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate() @@ -227,15 +242,21 @@ class TestDistBase(unittest.TestCase): # close trainer file tr0_pipe.close() tr1_pipe.close() - return pickle.loads(tr0_out), pickle.loads( - tr1_out), tr0_proc.pid, tr1_proc.pid - - def check_with_place(self, - model_file, - col_type, - data_type, - check_error_log=False, - need_envs={}): + return ( + pickle.loads(tr0_out), + pickle.loads(tr1_out), + tr0_proc.pid, + tr1_proc.pid, + ) + + def check_with_place( + self, + model_file, + col_type, + data_type, + check_error_log=False, + need_envs={}, + ): required_envs = { "FLAGS_eager_delete_tensor_gb": "0.0", "PATH": os.getenv("PATH"), @@ -251,7 +272,8 @@ class TestDistBase(unittest.TestCase): required_envs["GLOG_v"] = "3" required_envs["GLOG_logtostderr"] = "1" tr0_out, tr1_out, pid0, pid1 = self._run_cluster( - model_file, required_envs) + model_file, required_envs + ) np_data_type = DataTypeCast(data_type) np.random.seed(pid0) input1 = np.random.random((10, 1000)).astype(np_data_type) @@ -259,63 +281,55 @@ class TestDistBase(unittest.TestCase): input2 = np.random.random((10, 1000)).astype(np_data_type) if col_type == "broadcast": need_result = input2 - np.testing.assert_allclose(tr0_out, need_result) - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr0_out[0], need_result) + np.testing.assert_allclose(tr1_out[0], need_result) elif col_type == "allreduce_sum": need_result = input1 + input2 - np.testing.assert_allclose(tr0_out, - need_result, - rtol=1e-05, - atol=1e-05) - np.testing.assert_allclose(tr1_out, - need_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + tr0_out[0], need_result, rtol=1e-05, atol=1e-05 + ) + np.testing.assert_allclose( + tr1_out[0], need_result, rtol=1e-05, atol=1e-05 + ) elif col_type == "allreduce_prod": need_result = input1 * input2 - np.testing.assert_allclose(tr0_out, - need_result, - rtol=1e-05, - atol=1e-05) - np.testing.assert_allclose(tr1_out, - need_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + tr0_out[0], need_result, rtol=1e-05, atol=1e-05 + ) + np.testing.assert_allclose( + tr1_out[0], need_result, rtol=1e-05, atol=1e-05 + ) elif col_type == "allreduce_max": need_result = np.maximum(input1, input2) - np.testing.assert_allclose(tr0_out, - need_result, - rtol=1e-05, - atol=1e-05) - np.testing.assert_allclose(tr1_out, - need_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + tr0_out[0], need_result, rtol=1e-05, atol=1e-05 + ) + np.testing.assert_allclose( + tr1_out[0], need_result, rtol=1e-05, atol=1e-05 + ) elif col_type == "allreduce_min": need_result = np.minimum(input1, input2) - np.testing.assert_allclose(tr0_out, - need_result, - rtol=1e-05, - atol=1e-05) - np.testing.assert_allclose(tr1_out, - need_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + tr0_out[0], need_result, rtol=1e-05, atol=1e-05 + ) + np.testing.assert_allclose( + tr1_out[0], need_result, rtol=1e-05, atol=1e-05 + ) elif col_type == "reduce_sum": need_result = input1 + input2 - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr1_out[0], need_result) elif col_type == "reduce_prod": need_result = input1 * input2 - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr1_out[0], need_result) elif col_type == "reduce_max": need_result = np.maximum(input1, input2) - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr1_out[0], need_result) elif col_type == "reduce_min": need_result = np.minimum(input1, input2) - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr1_out[0], need_result) elif col_type == "allgather": need_result = np.vstack((input1, input2)) - np.testing.assert_allclose(tr0_out, need_result) - np.testing.assert_allclose(tr1_out, need_result) + np.testing.assert_allclose(tr0_out[0], need_result) + np.testing.assert_allclose(tr1_out[0], need_result) else: pass diff --git a/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py index 8497853561d..bedcde9def6 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_dropout_op_mlu.py @@ -29,26 +29,44 @@ SEED = 2022 class TestDropoutOp(OpTest): - def setUp(self): - self.op_type = "dropout" self.set_mlu() self.init_dtype() - self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} + self.init_inputs_shape() + self.init_attrs() + self.op_type = 'dropout' + self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)} self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64)).astype('uint8') + 'dropout_prob': self.dropout_prob, + 'fix_seed': self.fix_seed, + 'is_test': self.is_test, + 'dropout_implementation': self.dropout_implementation, } + out = self.inputs['X'] * (1.0 - self.dropout_prob) + if self.is_test == False: + mask = None + if self.dropout_prob == 0.0: + mask = np.ones(self.shape).astype('uint8') + elif self.dropout_prob == 1.0: + mask = np.zeros(self.shape).astype('uint8') + self.outputs = {'Out': out, 'Mask': mask} + else: + self.outputs = {'Out': out} + def init_dtype(self): self.dtype = np.float32 + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.__class__.no_need_check_grad = False + self.dropout_prob = 0.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" + def set_mlu(self): self.__class__.use_mlu = True self.place = paddle.device.MLUPlace(0) @@ -57,84 +75,107 @@ class TestDropoutOp(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): + if ( + hasattr(self.__class__, "no_need_check_grad") + and self.__class__.no_need_check_grad == True + ): + return + self.check_grad_with_place(self.place, ['X'], 'Out') class TestDropoutOpInput1d(TestDropoutOp): - # change input shape - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((3, 62)).astype(self.dtype)} - self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((3, 62)).astype('uint8') - } - - -class TestDropoutOpInput1d_1(TestDropoutOp): - # the input is 1-D - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((2000)).astype(self.dtype)} - self.attrs = { - 'dropout_prob': 0.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((2000)).astype('uint8') - } + def init_inputs_shape(self): + self.shape = [2000] class TestDropoutOp2(TestDropoutOp): - # the dropout_prob is 1.0 - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} - self.attrs = { - 'dropout_prob': 1.0, - 'fix_seed': True, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': np.zeros((32, 64)).astype('float32'), - 'Mask': np.zeros((32, 64)).astype('uint8') - } + def init_inputs_shape(self): + self.shape = [32, 64] + + def init_attrs(self): + self.dropout_prob = 1.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "upscale_in_train" class TestDropoutOp3(TestDropoutOp): - # the input dim is 3 + def init_inputs_shape(self): + self.shape = [32, 64, 2] + + +class TestDropoutOp4(TestDropoutOp): + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.35 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + +class TestDropoutOp5(TestDropoutOp): + def init_inputs_shape(self): + self.shape = [32, 64, 3] + + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.75 + self.fix_seed = True + self.is_test = True + self.dropout_implementation = "downgrade_in_infer" + + +class TestDropoutOp6(TestDropoutOp): + def init_attrs(self): + self.__class__.no_need_check_grad = True + self.dropout_prob = 0.0 + self.fix_seed = True + self.is_test = False + self.dropout_implementation = "downgrade_in_infer" + + +class TestDropoutOpWithSeed(TestDropoutOp): + # the seed is a Tensor def setUp(self): self.op_type = "dropout" self.set_mlu() - self.init_dtype() - self.inputs = {'X': np.random.random((32, 64, 2)).astype(self.dtype)} + self.dtype = np.float32 + self.inputs = { + "X": np.random.random((32, 64)).astype(self.dtype), + "Seed": np.asarray([125], dtype="int32"), + } self.attrs = { 'dropout_prob': 0.0, - 'fix_seed': True, 'is_test': False, - 'dropout_implementation': 'upscale_in_train' + 'dropout_implementation': 'upscale_in_train', } self.outputs = { 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64, 2)).astype('uint8') + 'Mask': np.ones((32, 64)).astype('uint8'), } + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestDropoutOpFp16(TestDropoutOp): + # float16 + def init_dtype(self): + self.dtype = np.float16 + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + self.__class__.no_need_check_grad = True + @skip_check_grad_ci(reason="For inference, check_grad is not required.") class TestDropoutOpInference(OpTest): @@ -148,7 +189,7 @@ class TestDropoutOpInference(OpTest): 'dropout_prob': 0.35, 'fix_seed': True, 'is_test': True, - 'dropout_implementation': 'upscale_in_train' + 'dropout_implementation': 'upscale_in_train', } self.outputs = {'Out': self.inputs['X']} @@ -165,7 +206,6 @@ class TestDropoutOpInference(OpTest): @skip_check_grad_ci(reason="For inference, check_grad is not required.") class TestDropoutOpInference2(TestDropoutOpInference): - def setUp(self): self.op_type = "dropout" self.set_mlu() @@ -174,45 +214,12 @@ class TestDropoutOpInference2(TestDropoutOpInference): self.attrs = { 'dropout_prob': 0.75, 'is_test': True, - 'dropout_implementation': 'upscale_in_train' + 'dropout_implementation': 'upscale_in_train', } self.outputs = {'Out': self.inputs['X']} -class TestDropoutOpWithSeed(TestDropoutOp): - # the seed is a Tensor - def setUp(self): - self.op_type = "dropout" - self.set_mlu() - self.init_dtype() - self.inputs = { - "X": np.random.random((32, 64)).astype(self.dtype), - "Seed": np.asarray([125], dtype="int32") - } - self.attrs = { - 'dropout_prob': 0.0, - 'is_test': False, - 'dropout_implementation': 'upscale_in_train' - } - self.outputs = { - 'Out': self.inputs['X'], - 'Mask': np.ones((32, 64)).astype('uint8') - } - - -class TestDropoutOpFp16(TestDropoutOp): - # float16 - def init_dtype(self): - self.dtype = np.float16 - - def set_mlu(self): - self.__class__.use_mlu = True - self.place = paddle.device.MLUPlace(0) - self.__class__.no_need_check_grad = True - - class TestDropoutAPI(unittest.TestCase): - def setUp(self): np.random.seed(123) self.places = [fluid.CPUPlace(), paddle.device.MLUPlace(0)] @@ -220,43 +227,44 @@ class TestDropoutAPI(unittest.TestCase): def check_static_result(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): input = fluid.data(name="input", shape=[40, 40], dtype="float32") - res1 = paddle.nn.functional.dropout(x=input, - p=0., - training=False, - mode='upscale_in_train') - res2 = paddle.nn.functional.dropout(x=input, - p=0., - axis=0, - training=True, - mode='upscale_in_train') - res3 = paddle.nn.functional.dropout(x=input, - p=0., - axis=0, - training=False, - mode='upscale_in_train') - res4 = paddle.nn.functional.dropout(x=input, - p=0., - axis=[0, 1], - training=True, - mode='upscale_in_train') - res5 = paddle.nn.functional.dropout(x=input, - p=0., - axis=[0, 1], - training=False, - mode='upscale_in_train') - res6 = paddle.nn.functional.dropout(x=input, - p=1., - training=True, - mode='upscale_in_train') + res1 = paddle.nn.functional.dropout( + x=input, p=0.0, training=False, mode='upscale_in_train' + ) + res2 = paddle.nn.functional.dropout( + x=input, p=0.0, axis=0, training=True, mode='upscale_in_train' + ) + res3 = paddle.nn.functional.dropout( + x=input, p=0.0, axis=0, training=False, mode='upscale_in_train' + ) + res4 = paddle.nn.functional.dropout( + x=input, + p=0.0, + axis=[0, 1], + training=True, + mode='upscale_in_train', + ) + res5 = paddle.nn.functional.dropout( + x=input, + p=0.0, + axis=[0, 1], + training=False, + mode='upscale_in_train', + ) + res6 = paddle.nn.functional.dropout( + x=input, p=1.0, training=True, mode='upscale_in_train' + ) res7 = paddle.fluid.layers.dropout( x=input, - dropout_prob=0., - dropout_implementation='upscale_in_train') - res8 = paddle.nn.functional.dropout(x=input, - p=0., - axis=(0, 1), - training=False, - mode='upscale_in_train') + dropout_prob=0.0, + dropout_implementation='upscale_in_train', + ) + res8 = paddle.nn.functional.dropout( + x=input, + p=0.0, + axis=(0, 1), + training=False, + mode='upscale_in_train', + ) in_np = np.random.random([40, 40]).astype("float32") res_np = in_np @@ -265,13 +273,17 @@ class TestDropoutAPI(unittest.TestCase): exe = fluid.Executor(place) res_list = [res1, res2, res3, res4, res5, res7, res8] for res in res_list: - fetches = exe.run(fluid.default_main_program(), - feed={"input": in_np}, - fetch_list=[res]) + fetches = exe.run( + fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res], + ) np.testing.assert_allclose(fetches[0], res_np) - fetches2 = exe.run(fluid.default_main_program(), - feed={"input": in_np}, - fetch_list=[res6]) + fetches2 = exe.run( + fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res6], + ) np.testing.assert_allclose(fetches2[0], res_np2) def test_static(self): diff --git a/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py index 032c2e9a506..df173ebf18c 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py @@ -28,12 +28,15 @@ def AffineGrid(theta, grid_shape): n = grid_shape[0] h = grid_shape[1] w = grid_shape[2] - h_idx = np.repeat(np.linspace(-1, 1, h)[np.newaxis, :], w, - axis=0).T[:, :, np.newaxis] - w_idx = np.repeat(np.linspace(-1, 1, w)[np.newaxis, :], h, - axis=0)[:, :, np.newaxis] - grid = np.concatenate([w_idx, h_idx, np.ones([h, w, 1])], - axis=2) # h * w * 3 + h_idx = np.repeat(np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[ + :, :, np.newaxis + ] + w_idx = np.repeat(np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[ + :, :, np.newaxis + ] + grid = np.concatenate( + [w_idx, h_idx, np.ones([h, w, 1])], axis=2 + ) # h * w * 3 grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3 ret = np.zeros([n, h * w, 2]) @@ -53,13 +56,17 @@ def getGridPointValue(data, x, y): out_H = x.shape[1] out_W = x.shape[2] - #out = np.zeros(data_shape, dtype='float32') + # out = np.zeros(data_shape, dtype='float32') out = np.zeros([N, C, out_H, out_W], dtype='float32') for i in range(N): for j in range(out_H): for k in range(out_W): - if y[i, j, k] < 0 or y[i, j, k] > in_H - 1 or x[ - i, j, k] < 0 or x[i, j, k] > in_W - 1: + if ( + y[i, j, k] < 0 + or y[i, j, k] > in_H - 1 + or x[i, j, k] < 0 + or x[i, j, k] > in_W - 1 + ): out[i, :, j, k] = 0 else: out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]] @@ -75,27 +82,28 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): if align_corners: grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * max_val) else: - grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * - (max_val + 1)) - 0.5 + grid_slice = ( + 0.5 * ((grid_slice.astype('float32') + 1.0) * (max_val + 1)) - 0.5 + ) if padding_mode == "border": grid_slice = clip(grid_slice, 0, max_val) elif padding_mode == "reflection": double_range = 2 * max_val if align_corners else (max_val + 1) * 2 - grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + - 0.5) + grid_abs = ( + np.abs(grid_slice) if align_corners else np.abs(grid_slice + 0.5) + ) extra = grid_abs - np.floor(grid_abs / double_range) * double_range grid_slice = np.minimum(extra, double_range - extra) - grid_slice = grid_slice if align_corners else clip( - grid_slice - 0.5, 0, max_val) + grid_slice = ( + grid_slice if align_corners else clip(grid_slice - 0.5, 0, max_val) + ) return grid_slice -def GridSampler(data, - grid, - align_corners=True, - mode="bilinear", - padding_mode="zeros"): +def GridSampler( + data, grid, align_corners=True, mode="bilinear", padding_mode="zeros" +): dims = data.shape N = dims[0] in_C = dims[1] @@ -119,14 +127,18 @@ def GridSampler(data, y0 = np.floor(y).astype('int32') y1 = y0 + 1 - wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), - (1, in_C, 1, 1)) - wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), - (1, in_C, 1, 1)) - wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), - (1, in_C, 1, 1)) - wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), - (1, in_C, 1, 1)) + wa = np.tile( + ((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1) + ) + wb = np.tile( + ((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1) + ) + wc = np.tile( + ((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1) + ) + wd = np.tile( + ((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1) + ) va = getGridPointValue(data, x0, y0) vb = getGridPointValue(data, x0, y1) @@ -142,7 +154,6 @@ def GridSampler(data, class TestGridSamplerOp(OpTest): - def setUp(self): self.place = paddle.device.MLUPlace(0) self.__class__.use_mlu = True @@ -166,12 +177,12 @@ class TestGridSamplerOp(OpTest): 'use_cudnn': False, "align_corners": self.align_corners, "padding_mode": self.padding_mode, - "mode": self.mode + "mode": self.mode, } self.outputs = { - 'Output': - GridSampler(x, grid, self.align_corners, self.mode, - self.padding_mode) + 'Output': GridSampler( + x, grid, self.align_corners, self.mode, self.padding_mode + ) } def test_check_output(self): @@ -186,20 +197,17 @@ class TestGridSamplerOp(OpTest): self.mode = "bilinear" -# TODO(fwg): Test this case when cnnl support align_corners = True. -# class Case1(TestGridSamplerOp): -# -# def initTestCase(self): -# self.x_shape = (2, 3, 5, 6) -# self.grid_shape = (2, 8, 9, 2) -# self.theta_shape = (2, 2, 3) -# self.align_corners = True -# self.padding_mode = "zeros" -# self.mode = "bilinear" +class Case1(TestGridSamplerOp): + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" class LargeInputCase(TestGridSamplerOp): - def initTestCase(self): self.x_shape = (2, 3, 128, 128) self.grid_shape = (2, 130, 130, 2) @@ -209,16 +217,15 @@ class LargeInputCase(TestGridSamplerOp): self.mode = "bilinear" -# TODO(fwg): Test this case when cnnl support align_corners = True. -# class Case2(LargeInputCase): -# -# def initTestCase(self): -# self.x_shape = (2, 3, 128, 128) -# self.grid_shape = (2, 130, 130, 2) -# self.theta_shape = (2, 2, 3) -# self.align_corners = True -# self.padding_mode = "zeros" -# self.mode = "bilinear" +class Case2(LargeInputCase): + def initTestCase(self): + self.x_shape = (2, 3, 128, 128) + self.grid_shape = (2, 130, 130, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py new file mode 100644 index 00000000000..2f2c10be7b6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py @@ -0,0 +1,132 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys + +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + +paddle.enable_static() + + +def huber_loss_forward(val, delta): + abs_val = abs(val) + if abs_val <= delta: + return 0.5 * val * val + else: + return delta * (abs_val - 0.5 * delta) + + +class TestHuberLossOp(OpTest): + def setUp(self): + self.op_type = 'huber_loss' + self.set_mlu() + self.python_api = paddle.fluid.layers.huber_loss + self.python_out_sig = ["Out"] + self.delta = 1.0 + self.init_input() + shape = self.set_shape() + residual = self.inputs['Y'] - self.inputs['X'] + loss = np.vectorize(huber_loss_forward)(residual, self.delta).astype( + 'float32' + ) + self.attrs = {'delta': self.delta} + self.outputs = {'Residual': residual, 'Out': loss.reshape(shape)} + + def init_input(self): + shape = self.set_shape() + self.inputs = { + 'X': np.random.uniform(0, 1.0, shape).astype('float32'), + 'Y': np.random.uniform(0, 1.0, shape).astype('float32'), + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def set_shape(self): + return (100, 1) + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + self.check_grad_with_place( + self.place, + ['Y'], + 'Out', + max_relative_error=0.008, + no_grad_set=set("residual"), + ) + + def test_check_grad_ingore_y(self): + self.check_grad_with_place( + self.place, + ['X'], + 'Out', + max_relative_error=0.008, + no_grad_set=set('residual'), + ) + + +def TestHuberLossOp1(TestHuberLossOp): + def set_shape(self): + return 64 + + +def TestHuberLossOp2(TestHuberLossOp): + def set_shape(self): + return (6, 6) + + +def TestHuberLossOp3(TestHuberLossOp): + def set_shape(self): + return (6, 6, 1) + + +class TestHuberLossOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # the input and label must be Variable + xw = np.random.random((6, 6)).astype("float32") + xr = fluid.data(name='xr', shape=[None, 6], dtype="float32") + lw = np.random.random((6, 6)).astype("float32") + lr = fluid.data(name='lr', shape=[None, 6], dtype="float32") + delta = 1.0 + self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw, delta) + self.assertRaises(TypeError, fluid.layers.huber_loss, xw, lr, delta) + + # the dtype of input and label must be float32 or float64 + xw2 = fluid.data(name='xw2', shape=[None, 6], dtype="int32") + lw2 = fluid.data(name='lw2', shape=[None, 6], dtype="int32") + self.assertRaises( + TypeError, fluid.layers.huber_loss, xw2, lr, delta + ) + self.assertRaises( + TypeError, fluid.layers.huber_loss, xr, lw2, delta + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py new file mode 100644 index 00000000000..242e1c8e663 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_merged_adam_op_mlu.py @@ -0,0 +1,228 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +sys.path.append('..') +import unittest +import paddle +import numpy as np +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid.framework import in_dygraph_mode + + +def run_adam_op( + params, + grads, + lrs, + moment1s, + moment2s, + beta1_pows, + beta2_pows, + master_params, + epsilon, + beta1, + beta2, + place, + multi_precision=False, + use_merged=False, +): + assert len(params) == len(grads) + assert len(params) == len(lrs) + assert len(params) == len(moment1s) + assert len(params) == len(moment2s) + assert len(params) == len(beta1_pows) + assert len(params) == len(beta1_pows) + assert len(params) == len(master_params) + paddle.disable_static() + # paddle.set_device(place) + + param_vars = [paddle.fluid.dygraph.to_variable(p) for p in params] + grad_vars = [paddle.fluid.dygraph.to_variable(g) for g in grads] + lr_vars = [paddle.fluid.dygraph.to_variable(l) for l in lrs] + moment1_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment1s] + moment2_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment2s] + beta1_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta1_pows] + beta2_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta2_pows] + master_param_vars = [ + paddle.fluid.dygraph.to_variable(m_p) for m_p in master_params + ] + + if not use_merged: + for i in range(len(param_vars)): + _, _, _, _, _, _ = _legacy_C_ops.adam( + param_vars[i], + grad_vars[i], + lr_vars[i], + moment1_vars[i], + moment2_vars[i], + beta1_pow_vars[i], + beta2_pow_vars[i], + master_param_vars[i], + param_vars[i], + moment1_vars[i], + moment2_vars[i], + beta1_pow_vars[i], + beta2_pow_vars[i], + master_param_vars[i], + 'epsilon', + epsilon, + 'beta1', + beta1, + 'beta2', + beta2, + 'multi_precision', + multi_precision, + ) + else: + if in_dygraph_mode(): + _, _, _, _, _, _ = _C_ops.merged_adam_( + param_vars, + grad_vars, + lr_vars, + moment1_vars, + moment2_vars, + beta1_pow_vars, + beta2_pow_vars, + master_param_vars, + beta1, + beta2, + epsilon, + multi_precision, + False, + ) + else: + _, _, _, _, _, _ = _legacy_C_ops.merged_adam( + param_vars, + grad_vars, + lr_vars, + moment1_vars, + moment2_vars, + beta1_pow_vars, + beta2_pow_vars, + master_param_vars, + param_vars, + moment1_vars, + moment2_vars, + beta1_pow_vars, + beta2_pow_vars, + master_param_vars, + 'epsilon', + epsilon, + 'beta1', + beta1, + 'beta2', + beta2, + 'multi_precision', + multi_precision, + ) + + outputs = { + 'ParamOut': param_vars, + 'Moment1Out': moment1_vars, + 'Moment2Out': moment2_vars, + 'Beta1PowOut': beta1_pow_vars, + 'Beta2PowOut': beta2_pow_vars, + 'MasterParamOut': master_param_vars, + } + + return outputs + + +class TestMergedAdam(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + + def gen_rand_data(self, shapes, dtype): + return [np.random.random(s).astype(dtype) for s in shapes] + + def prepare_data(self, shapes, multi_precision, seed, place): + np.random.seed(seed) + mp_dtype = np.float32 + # dtype = np.float16 if multi_precision and place == 'mlu' else np.float32 + dtype = np.float32 + params = self.gen_rand_data(shapes, dtype) + grads = self.gen_rand_data(shapes, dtype) + lrs = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) + moment1s = self.gen_rand_data(shapes, mp_dtype) + moment2s = self.gen_rand_data(shapes, mp_dtype) + beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) + beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) + master_params = [p.astype(mp_dtype) for p in params] + return ( + params, + grads, + lrs, + moment1s, + moment2s, + beta1_pows, + beta2_pows, + master_params, + ) + + def check_with_place(self, place, multi_precision): + ( + params, + grads, + lrs, + moment1s, + moment2s, + beta1_pows, + beta2_pows, + master_params, + ) = self.prepare_data(self.shapes, multi_precision, self.seed, place) + + def run_op(use_merged): + return run_adam_op( + params=params, + grads=grads, + lrs=lrs, + moment1s=moment1s, + moment2s=moment2s, + beta1_pows=beta1_pows, + beta2_pows=beta2_pows, + master_params=master_params, + epsilon=0.9, + beta1=0.9, + beta2=0.99, + place=place, + multi_precision=multi_precision, + use_merged=use_merged, + ) + + outs1 = run_op(True) + outs2 = run_op(False) + self.assertEqual(len(outs1), len(outs2)) + + for key in outs1.keys(): + value1 = outs1[key] + value2 = outs2[key] + for i in range(len(value1)): + if place == 'mlu': + np.testing.assert_array_equal(value1[i], value2[i]) + else: + np.testing.assert_allclose( + value1[i], value2[i], rtol=1e-05, atol=1e-07 + ) + + def test_main(self): + for multi_precision in [False, True]: + self.check_with_place(self.place, multi_precision) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py new file mode 100644 index 00000000000..68df3067f0c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_prior_box_op_mlu.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import sys + +sys.path.append('..') +import numpy as np +from op_test import OpTest +import paddle.fluid as fluid +import paddle +import math + +paddle.enable_static() + + +class TestMLUPriorBox(OpTest): + def setUp(self): + self.op_type = "prior_box" + self.set_mlu() + self.init_dtype() + self.set_data() + + def test_check_output(self): + self.check_output_with_place(self.place) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'offset': self.offset, + } + if len(self.max_sizes) > 0: + self.attrs['max_sizes'] = self.max_sizes + + self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} + + def set_max_sizes(self): + max_sizes = [5, 10] + self.max_sizes = np.array(max_sizes).astype('float32').tolist() + + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = True + + def init_test_params(self): + self.layer_w = 32 + self.layer_h = 32 + + self.image_w = 40 + self.image_h = 40 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() + self.set_max_sizes() + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.set_min_max_aspect_ratios_order() + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array( + self.aspect_ratios, dtype=np.float64 + ).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float64).flatten() + + self.clip = True + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 0: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, self.image_h) + ).astype('float32') + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, self.layer_h) + ).astype('float32') + + def init_test_output(self): + out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) + out_boxes = np.zeros(out_dim).astype('float32') + out_var = np.zeros(out_dim).astype('float32') + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h + idx = 0 + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + if not self.min_max_aspect_ratios_order: + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h, + ] + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h, + ] + idx += 1 + else: + c_w = c_h = min_size / 2.0 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h, + ] + idx += 1 + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h, + ] + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if abs(ar - 1.0) < 1e-6: + continue + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, + (c_y - c_h) / self.image_h, + (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h, + ] + idx += 1 + + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + out_boxes = np.clip(out_boxes, 0.0, 1.0) + # set the variance. + out_var = np.tile( + self.variances, (self.layer_h, self.layer_w, self.num_priors, 1) + ) + self.out_boxes = out_boxes.astype('float32') + self.out_var = out_var.astype('float32') + + +class TestMLUPriorBoxWithoutMaxSize(TestMLUPriorBox): + def set_max_sizes(self): + self.max_sizes = [] + + +class TestMLUPriorBoxWithoutSpecifiedOutOrder(TestMLUPriorBox): + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = False + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py index ab984187443..af0882be46a 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_reduce_sum_op_mlu.py @@ -26,7 +26,6 @@ paddle.enable_static() class TestMLUReduceSumOp(OpTest): - def setUp(self): self.init_op_type() self.initTestCase() @@ -34,16 +33,16 @@ class TestMLUReduceSumOp(OpTest): self.attrs = { 'dim': self.axis, 'keep_dim': self.keep_dim, - 'reduce_all': self.reduce_all + 'reduce_all': self.reduce_all, } self.inputs = {'X': np.random.random(self.shape).astype("float32")} if self.attrs['reduce_all']: self.outputs = {'Out': self.inputs['X'].sum()} else: self.outputs = { - 'Out': - self.inputs['X'].sum(axis=self.axis, - keepdims=self.attrs['keep_dim']) + 'Out': self.inputs['X'].sum( + axis=self.axis, keepdims=self.attrs['keep_dim'] + ) } def set_mlu(self): @@ -64,100 +63,92 @@ class TestMLUReduceSumOp(OpTest): def initTestCase(self): self.shape = (5, 6, 10) - self.axis = (0, ) + self.axis = (0,) class TestSumOp5D(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (1, 2, 5, 6, 10) - self.axis = (0, ) + self.axis = (0,) class TestSumOp6D(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (1, 1, 2, 5, 6, 10) - self.axis = (0, ) + self.axis = (0,) class TestSumOp8D(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (1, 3, 1, 2, 1, 4, 3, 10) self.axis = (0, 3) class Test1DReduce(TestMLUReduceSumOp): - def initTestCase(self): self.shape = 120 - self.axis = (0, ) + self.axis = (0,) class Test2DReduce0(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (20, 10) - self.axis = (0, ) + self.axis = (0,) class Test2DReduce1(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (20, 10) - self.axis = (1, ) + self.axis = (1,) class Test3DReduce0(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (5, 6, 7) - self.axis = (1, ) + self.axis = (1,) class Test3DReduce1(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (5, 6, 7) - self.axis = (2, ) + self.axis = (2,) class Test3DReduce2(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (5, 6, 7) - self.axis = (-2, ) + self.axis = (-2,) class Test3DReduce3(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (5, 6, 7) self.axis = (1, 2) class TestKeepDimReduce(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (5, 6, 10) - self.axis = (1, ) + self.axis = (1,) self.keep_dim = True class TestKeepDim8DReduce(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (2, 5, 3, 2, 2, 3, 4, 2) self.axis = (3, 4, 5) self.keep_dim = True + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.03 + ) + class TestReduceAll(TestMLUReduceSumOp): - def initTestCase(self): self.shape = (5, 6, 2, 10) - self.axis = (0, ) + self.axis = (0,) self.reduce_all = True diff --git a/python/paddle/fluid/tests/unittests/mlu/test_slice_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_slice_op_mlu.py index 71116b4d3ce..3d61fd3fc1f 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_slice_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_slice_op_mlu.py @@ -31,7 +31,6 @@ paddle.enable_static() # Situation 1: starts(list, no tensor), ends(list, no tensor) # 1.1 without attr(decrease) class TestSliceOp(OpTest): - def setUp(self): self.op_type = "slice" self.set_mlu() @@ -42,7 +41,7 @@ class TestSliceOp(OpTest): 'axes': self.axes, 'starts': self.starts, 'ends': self.ends, - 'infer_flags': self.infer_flags + 'infer_flags': self.infer_flags, } def config(self): @@ -57,9 +56,9 @@ class TestSliceOp(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) def set_mlu(self): self.__class__.use_mlu = True @@ -67,7 +66,6 @@ class TestSliceOp(OpTest): class TestCase1(TestSliceOp): - def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [-3, 0, 2] @@ -78,7 +76,6 @@ class TestCase1(TestSliceOp): class TestCase2(TestSliceOp): - def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [-3, 0, 2] @@ -90,7 +87,6 @@ class TestCase2(TestSliceOp): # 1.2 with attr(decrease) class TestSliceOp_decs_dim(OpTest): - def setUp(self): self.op_type = "slice" self.set_mlu() @@ -118,9 +114,9 @@ class TestSliceOp_decs_dim(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) def set_mlu(self): self.__class__.use_mlu = True @@ -128,7 +124,6 @@ class TestSliceOp_decs_dim(OpTest): class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): - def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [1, 0, 2] @@ -140,7 +135,6 @@ class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim): - def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [-1, 0, 2] @@ -152,7 +146,6 @@ class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim): - def config(self): self.input = np.random.random([3, 4, 5, 7]).astype("float32") self.starts = [0, 1, 2, 3] @@ -164,7 +157,6 @@ class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): - def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [-1] @@ -176,7 +168,6 @@ class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): - def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [0, 1, 2, 3] @@ -190,7 +181,6 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): # Situation 2: starts(list, have tensor), ends(list, no tensor) # without attr(decrease) class TestSliceOp_starts_ListTensor(OpTest): - def setUp(self): self.op_type = "slice" self.set_mlu() @@ -198,8 +188,9 @@ class TestSliceOp_starts_ListTensor(OpTest): starts_tensor = [] for index, ele in enumerate(self.starts): - starts_tensor.append(("x" + str(index), np.ones( - (1)).astype('int64') * ele)) + starts_tensor.append( + ("x" + str(index), np.ones((1)).astype('int64') * ele) + ) self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} self.outputs = {'Out': self.out} @@ -207,7 +198,7 @@ class TestSliceOp_starts_ListTensor(OpTest): 'axes': self.axes, 'starts': self.starts_infer, 'ends': self.ends, - 'infer_flags': self.infer_flags + 'infer_flags': self.infer_flags, } def config(self): @@ -224,9 +215,9 @@ class TestSliceOp_starts_ListTensor(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) def set_mlu(self): self.__class__.use_mlu = True @@ -236,7 +227,6 @@ class TestSliceOp_starts_ListTensor(OpTest): # Situation 2: starts(list, have tensor), ends(list, no tensor) # with attr(decrease) class TestSliceOp_decs_dim_starts_ListTensor(OpTest): - def setUp(self): self.op_type = "slice" self.set_mlu() @@ -244,8 +234,9 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): starts_tensor = [] for index, ele in enumerate(self.starts): - starts_tensor.append(("x" + str(index), np.ones( - (1)).astype('int32') * ele)) + starts_tensor.append( + ("x" + str(index), np.ones((1)).astype('int32') * ele) + ) self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} @@ -273,9 +264,9 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) def set_mlu(self): self.__class__.use_mlu = True @@ -283,8 +274,8 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): class TestSliceOp_decs_dim_5_starts_ListTensor( - TestSliceOp_decs_dim_starts_ListTensor): - + TestSliceOp_decs_dim_starts_ListTensor +): def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.starts = [-1] @@ -300,7 +291,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor( # Situation 3: starts(tensor), ends(list, no tensor) # with attr(decrease) class TestSliceOp_decs_dim_starts_OneTensor(OpTest): - def setUp(self): self.op_type = "slice" self.__class__.use_mlu = True @@ -308,7 +298,7 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest): self.config() self.inputs = { 'Input': self.input, - "StartsTensor": np.array(self.starts, dtype="int32") + "StartsTensor": np.array(self.starts, dtype="int32"), } self.outputs = {'Out': self.out} self.attrs = { @@ -332,15 +322,14 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) # Situation 4: starts(tensor), ends(tensor) # without attr(decrease) class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): - def setUp(self): self.op_type = "slice" self.__class__.use_mlu = True @@ -350,14 +339,14 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): self.inputs = { 'Input': self.input, "StartsTensor": np.array(self.starts, dtype="int64"), - "EndsTensor": np.array(self.ends, dtype="int32") + "EndsTensor": np.array(self.ends, dtype="int32"), } self.outputs = {'Out': self.out} self.attrs = { 'axes': self.axes, #'starts': self.starts, #'ends': self.ends_infer, - 'infer_flags': self.infer_flags + 'infer_flags': self.infer_flags, } def config(self): @@ -372,15 +361,14 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) # Situation 5: starts(tensor), ends(tensor) # with attr(decrease) class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): - def setUp(self): self.op_type = "slice" self.__class__.use_mlu = True @@ -389,7 +377,7 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): self.inputs = { 'Input': self.input, "StartsTensor": np.array(self.starts, dtype="int32"), - "EndsTensor": np.array(self.ends, dtype="int32") + "EndsTensor": np.array(self.ends, dtype="int32"), } self.outputs = {'Out': self.out} self.attrs = { @@ -413,15 +401,14 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) # Situation 6: starts(tensor), ends(list, have tensor) # without attr(decrease) class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): - def setUp(self): self.op_type = "slice" self.__class__.use_mlu = True @@ -430,20 +417,21 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): ends_tensor = [] for index, ele in enumerate(self.ends): - ends_tensor.append(("y" + str(index), np.ones( - (1)).astype('int32') * ele)) + ends_tensor.append( + ("y" + str(index), np.ones((1)).astype('int32') * ele) + ) self.inputs = { 'Input': self.input, "StartsTensor": np.array(self.starts, dtype="int32"), - 'EndsTensorList': ends_tensor + 'EndsTensorList': ends_tensor, } self.outputs = {'Out': self.out} self.attrs = { 'axes': self.axes, #'starts': self.starts, 'ends': self.ends_infer, - 'infer_flags': self.infer_flags + 'infer_flags': self.infer_flags, } def config(self): @@ -460,14 +448,13 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): self.check_output_with_place(self.place) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) # Test float16 class TestFP16(OpTest): - def setUp(self): self.op_type = "slice" self.__class__.use_mlu = True @@ -479,7 +466,7 @@ class TestFP16(OpTest): 'axes': self.axes, 'starts': self.starts, 'ends': self.ends, - 'infer_flags': self.infer_flags + 'infer_flags': self.infer_flags, } def config(self): @@ -495,13 +482,12 @@ class TestFP16(OpTest): self.check_output_with_place(self.place, atol=1e-5) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006) + self.check_grad_with_place( + self.place, ['Input'], 'Out', max_relative_error=0.006 + ) class TestFP16_2(OpTest): - def setUp(self): self.op_type = "slice" self.__class__.use_mlu = True @@ -513,7 +499,7 @@ class TestFP16_2(OpTest): 'axes': self.axes, 'starts': self.starts, 'ends': self.ends, - 'infer_flags': self.infer_flags + 'infer_flags': self.infer_flags, } def config(self): @@ -529,24 +515,28 @@ class TestFP16_2(OpTest): self.check_output_with_place(self.place, atol=1e-5) def test_check_grad_normal(self): - self.check_grad_with_place(self.place, ['Input'], - 'Out', - max_relative_error=0.006, - numeric_grad_delta=0.5) + self.check_grad_with_place( + self.place, + ['Input'], + 'Out', + max_relative_error=0.006, + numeric_grad_delta=0.5, + ) class TestSliceApiWithTensor(unittest.TestCase): - def test_starts_ends_is_tensor(self): with paddle.fluid.dygraph.guard(): a = paddle.rand(shape=[4, 5, 6], dtype='float32') axes = [0, 1, 2] starts = [-3, 0, 2] ends = [3, 2, 4] - a_1 = paddle.slice(a, - axes=axes, - starts=paddle.to_tensor(starts, dtype='int32'), - ends=paddle.to_tensor(ends, dtype='int32')) + a_1 = paddle.slice( + a, + axes=axes, + starts=paddle.to_tensor(starts, dtype='int32'), + ends=paddle.to_tensor(ends, dtype='int32'), + ) a_2 = paddle.slice(a, axes=axes, starts=starts, ends=ends) np.testing.assert_allclose(a_1.numpy(), a_2.numpy()) @@ -569,24 +559,22 @@ class TestSliceApiWithTensor(unittest.TestCase): class TestImperativeVarBaseGetItem(unittest.TestCase): - def test_getitem_with_long(self): with fluid.dygraph.guard(): data = np.random.random((2, 80, 16128)).astype('float32') var = fluid.dygraph.to_variable(data) - sliced = var[:, 10:, :var.shape[1]] # var.shape[1] is 80L here + sliced = var[:, 10:, : var.shape[1]] # var.shape[1] is 80L here self.assertEqual(sliced.shape, [2, 70, 80]) - sliced = var[:, var.shape[0]:, var.shape[0]:var.shape[1]] + sliced = var[:, var.shape[0] :, var.shape[0] : var.shape[1]] self.assertEqual(sliced.shape, [2, 78, 78]) def test_getitem_with_float(self): - def test_float_in_slice_item(): with fluid.dygraph.guard(): data = np.random.random((2, 80, 16128)).astype('float32') var = fluid.dygraph.to_variable(data) - sliced = var[:, 1.1:, :var.shape[1]] + sliced = var[:, 1.1:, : var.shape[1]] self.assertRaises(Exception, test_float_in_slice_item) @@ -600,15 +588,6 @@ class TestImperativeVarBaseGetItem(unittest.TestCase): class TestInferShape(unittest.TestCase): - - def test(self): - x = paddle.ones(shape=[3, 4, 5]) - x.desc.set_shape([3, -1, 5]) - self.assertEqual(x.shape, (3, -1, 5)) - - out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3]) - self.assertEqual(out0.shape, (3, 3, 5)) - def test_axis_less_than_zero(self): # Using paddle.disable_static will make other unittests fail. @@ -616,13 +595,18 @@ class TestInferShape(unittest.TestCase): x_arr = np.arange(0, 24, dtype=np.float32).reshape([2, 3, 4]) x = paddle.to_tensor(x_arr) - pp_slice = paddle.slice(x, [ - 100, - ], [0], [1]) + pp_slice = paddle.slice( + x, + [ + 100, + ], + [0], + [1], + ) np_slice = x_arr[:, :, 0:1] np.testing.assert_allclose(pp_slice, np_slice) - pp_slice = paddle.slice(x, (-100, ), [0], [1]) + pp_slice = paddle.slice(x, (-100,), [0], [1]) np_slice = x_arr[0:1] np.testing.assert_allclose(pp_slice, np_slice) @@ -630,9 +614,11 @@ class TestInferShape(unittest.TestCase): x = paddle.to_tensor(np.reshape(x_arr, (0, 0, 0))) starts = paddle.to_tensor( - np.reshape(np.array([], dtype=np.int32), (0, ))) + np.reshape(np.array([], dtype=np.int32), (0,)) + ) ends = paddle.to_tensor( - np.reshape(np.array([], dtype=np.int32), (0, ))) + np.reshape(np.array([], dtype=np.int32), (0,)) + ) with self.assertRaises(ValueError): paddle.slice(x, [-1000000], starts, ends) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py index 25dbbbd028e..8ef5c5dc5df 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py @@ -30,7 +30,6 @@ SEED = 2021 class TestSoftmaxWithCrossEntropyOp(OpTest): - def set_mlu(self): self.__class__.use_mlu = True @@ -53,8 +52,10 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): self.initParams() logits = getattr( - self, "logits", - np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) + self, + "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype), + ) softmax = np.apply_along_axis(stable_softmax, self.axis, logits) if self.soft_label: @@ -65,8 +66,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): self.shape[self.axis] = 1 labels = np.random.randint(0, axis_dim, self.shape, dtype="int64") - loss = cross_entropy(softmax, labels, self.soft_label, self.axis, - self.ignore_index) + loss = cross_entropy( + softmax, labels, self.soft_label, self.axis, self.ignore_index + ) one_hot_label = np.eye(axis_dim)[labels.reshape(-1)] @@ -74,7 +76,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): self.outputs = { "Backprop": (softmax - one_hot_label).astype(self.dtype), "Softmax": softmax.astype(self.dtype), - "Loss": loss.astype(self.dtype) + "Loss": loss.astype(self.dtype), } self.attrs = { "numeric_stable_mode": self.numeric_stable_mode, @@ -92,14 +94,16 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): if self.dtype == np.float16: return # fp32 has low precision, cpu and mlu both need to relax the max_relative_error if using fp32 - self.check_grad_with_place(self.place, ['Logits'], - 'Loss', - numeric_grad_delta=0.001, - max_relative_error=0.5) + self.check_grad_with_place( + self.place, + ['Logits'], + 'Loss', + numeric_grad_delta=0.001, + max_relative_error=0.5, + ) class TestPowNet(unittest.TestCase): - def _test(self, run_mlu=True): main_prog = paddle.static.Program() startup_prog = paddle.static.Program() @@ -114,9 +118,9 @@ class TestPowNet(unittest.TestCase): with paddle.static.program_guard(main_prog, startup_prog): a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') - label = paddle.static.data(name="label", - shape=[32, 1], - dtype='int64') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64' + ) sum = paddle.add(a, b) z = paddle.pow(sum, 2.0) @@ -140,16 +144,17 @@ class TestPowNet(unittest.TestCase): print("Start run on {}".format(place)) for epoch in range(100): - pred_res, loss_res = exe.run(main_prog, - feed={ - "a": a_np, - "b": b_np, - "label": label_np - }, - fetch_list=[prediction, loss]) + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, "b": b_np, "label": label_np}, + fetch_list=[prediction, loss], + ) if epoch % 10 == 0: - print("Epoch {} | Prediction[0]: {}, Loss: {}".format( - epoch, pred_res[0], loss_res)) + print( + "Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res + ) + ) return pred_res, loss_res @@ -157,7 +162,7 @@ class TestPowNet(unittest.TestCase): cpu_pred, cpu_loss = self._test(False) mlu_pred, mlu_loss = self._test(True) - np.testing.assert_allclose(mlu_pred, cpu_pred, rtol=1e-5) + np.testing.assert_allclose(mlu_pred, cpu_pred, rtol=2e-5) np.testing.assert_allclose(mlu_loss, cpu_loss) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py index 3b8dd2c1922..a2b59048462 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_base_mlu.py @@ -44,17 +44,19 @@ SEED = 10 class TestSyncBatchNormRunnerBase(object): - - def get_model(self, - main, - startup, - place, - layout, - seed, - sync_bn=False, - only_forward=False): + def get_model( + self, + main, + startup, + place, + layout, + seed, + sync_bn=False, + only_forward=False, + ): raise NotImplementedError( - "get model should be implemented by child class.") + "get model should be implemented by child class." + ) def wait_server_ready(self, endpoints): assert not isinstance(endpoints, string_types) @@ -63,13 +65,15 @@ class TestSyncBatchNormRunnerBase(object): not_ready_endpoints = [] for ep in endpoints: ip_port = ep.split(":") - with closing(socket.socket(socket.AF_INET, - socket.SOCK_STREAM)) as sock: + with closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as sock: sock.settimeout(2) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(socket, 'SO_REUSEPORT'): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, - 1) + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT, 1 + ) result = sock.connect_ex((ip_port[0], int(ip_port[1]))) if result != 0: @@ -77,39 +81,47 @@ class TestSyncBatchNormRunnerBase(object): not_ready_endpoints.append(ep) if not all_ok: sys.stderr.write("server not ready, wait 3 sec to retry...\n") - sys.stderr.write("not ready endpoints:" + - str(not_ready_endpoints) + "\n") + sys.stderr.write( + "not ready endpoints:" + str(not_ready_endpoints) + "\n" + ) sys.stderr.flush() time.sleep(3) else: break - def initCommunicator(self, program, rank, nranks, wait_port, - current_endpoint, endpoints): + def initCommunicator( + self, program, rank, nranks, wait_port, current_endpoint, endpoints + ): other_endpoints = endpoints[:] other_endpoints.remove(current_endpoint) if rank == 0 and wait_port: self.wait_server_ready(other_endpoints) block = program.global_block() - cncl_id_var = block.create_var(name=nameGen.generate('cncl_id'), - persistable=True, - type=core.VarDesc.VarType.RAW) - block.append_op(type='c_gen_cncl_id', - inputs={}, - outputs={'Out': cncl_id_var}, - attrs={ - 'rank': rank, - 'endpoint': current_endpoint, - 'other_endpoints': other_endpoints - }) - block.append_op(type='c_comm_init', - inputs={'X': cncl_id_var}, - outputs={}, - attrs={ - 'nranks': nranks, - 'rank': rank, - 'ring_id': self.global_ring_id - }) + cncl_id_var = block.create_var( + name=nameGen.generate('cncl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW, + ) + block.append_op( + type='c_gen_cncl_id', + inputs={}, + outputs={'Out': cncl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + }, + ) + block.append_op( + type='c_comm_init', + inputs={'X': cncl_id_var}, + outputs={}, + attrs={ + 'nranks': nranks, + 'rank': rank, + 'ring_id': self.global_ring_id, + }, + ) def run_trainer(self, args): device_id = int(os.getenv("FLAGS_selected_mlus", "0")) @@ -127,8 +139,8 @@ class TestSyncBatchNormRunnerBase(object): self._compare(args, place, layout, True) # Test FP16 - @TODO - self.dtype = np.float16 - self.atol = 1e-2 + self.bn_dtype = np.float16 + self.atol = 3e-3 # Test training for place in places: @@ -142,24 +154,30 @@ class TestSyncBatchNormRunnerBase(object): sys.stdout.buffer.write( pickle.dumps( - 'training, inference, fp32, fp16, NCHW, NHWC all passed')) + 'training, inference, fp32, fp16, NCHW, NHWC all passed' + ) + ) def _compare(self, args, place, layout, only_forward): scope = core.Scope() np.random.seed(SEED) - data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2 + data = np.random.random(size=self.dshape).astype(self.dtype) * 4.0 - 2 sys.stderr.write("data: " + str(data) + "\n") - data = create_or_get_tensor(scope, "input", - OpTest.np_dtype_to_fluid_dtype(data), place) + data = create_or_get_tensor( + scope, "input", OpTest.np_dtype_to_fluid_dtype(data), place + ) - bn_fetches = self._cal_single_card(args, data, place, layout, - only_forward) + bn_fetches = self._cal_single_card( + args, data, place, layout, only_forward + ) fetch_names, sync_bn_fetches = self._cal_multiple_cards( - args, data, place, layout, only_forward) + args, data, place, layout, only_forward + ) - sys.stderr.write("len(sync_bn_fetches): " + str(len(sync_bn_fetches)) + - "\n") + sys.stderr.write( + "len(sync_bn_fetches): " + str(len(sync_bn_fetches)) + "\n" + ) for i in six.moves.xrange(0, len(sync_bn_fetches)): sys.stderr.write("i: " + str(i) + "\n") sys.stderr.write("fetch_names[i]): " + fetch_names[i] + "\n") @@ -167,13 +185,14 @@ class TestSyncBatchNormRunnerBase(object): bn_val = bn_fetches[i] sync_bn_val = sync_bn_fetches[i] if sync_bn_val.shape != bn_val.shape: - sync_bn_val = sync_bn_val[:bn_val.shape[0]] + sync_bn_val = sync_bn_val[: bn_val.shape[0]] # i = 0 if fetch_names[i] == 'reduce_sum_0.tmp_0': # sys.stderr.write("skip reduce_sum_0.tmp_0 (Out of reduce_sum op)" + "\n") - sys.stderr.write("reduce_sum_0.tmp_0 (Out of reduce_sum op)" + - "\n") + sys.stderr.write( + "reduce_sum_0.tmp_0 (Out of reduce_sum op)" + "\n" + ) sys.stderr.write("bn_val: " + str(bn_val) + "\n") sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") @@ -201,7 +220,8 @@ class TestSyncBatchNormRunnerBase(object): if fetch_names[i] == 'batch_norm_0.tmp_2': # sys.stderr.write("skip batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n") sys.stderr.write( - "batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n") + "batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n" + ) sys.stderr.write("bn_val: " + str(bn_val) + "\n") sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") @@ -234,8 +254,9 @@ class TestSyncBatchNormRunnerBase(object): # i = 8 if fetch_names[i] == 'batch_norm_0.tmp_1': - sys.stderr.write("skip batch_norm_0.tmp_1 (SavedVariance)" + - "\n") + sys.stderr.write( + "skip batch_norm_0.tmp_1 (SavedVariance)" + "\n" + ) sys.stderr.write("bn_val: " + str(bn_val) + "\n") sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") @@ -281,10 +302,16 @@ class TestSyncBatchNormRunnerBase(object): if fetch_names[i] == 'conv2d_0.tmp_0@GRAD': atol = 1e-2 - assert np.allclose( - bn_val, sync_bn_val, atol=atol), "Output (" + fetch_names[ - i] + ") has diff. \n" + "\nBN " + str( - bn_val) + "\n" + "Sync BN " + str(sync_bn_val) + assert np.allclose(bn_val, sync_bn_val, atol=atol), ( + "Output (" + + fetch_names[i] + + ") has diff. \n" + + "\nBN " + + str(bn_val) + + "\n" + + "Sync BN " + + str(sync_bn_val) + ) def _cal_single_card(self, args, data, place, layout, only_forward): # Single-MLU, N = 32 per MLU @@ -294,23 +321,31 @@ class TestSyncBatchNormRunnerBase(object): startup_prog.global_seed(SEED) paddle.seed(SEED) - outs = self.get_model(train_prog, startup_prog, place, layout, SEED, - False, only_forward) + outs = self.get_model( + train_prog, startup_prog, place, layout, SEED, False, only_forward + ) exe = fluid.Executor(place) exe.run(startup_prog) fetch_names = [v.name for v in outs] + [ - 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + 'bn_moving_mean', + 'bn_moving_variance', + 'bn_scale', + 'bn_bias', ] if not only_forward: others = [ - 'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', - 'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD' + 'batch_norm_0.tmp_0', + 'batch_norm_0.tmp_1', + 'bn_scale@GRAD', + 'bn_bias@GRAD', + 'batch_norm_0.tmp_3@GRAD', + 'conv2d_0.tmp_0@GRAD', ] fetch_names += others - bn_fetches = exe.run(program=train_prog, - feed={'input': data}, - fetch_list=fetch_names) + bn_fetches = exe.run( + program=train_prog, feed={'input': data}, fetch_list=fetch_names + ) return bn_fetches @@ -331,8 +366,9 @@ class TestSyncBatchNormRunnerBase(object): current_endpoint = args["currentendpoint"] nranks = 2 - self.initCommunicator(startup_prog, rank, nranks, True, - current_endpoint, endpoints) + self.initCommunicator( + startup_prog, rank, nranks, True, current_endpoint, endpoints + ) # sys.stderr.write("after init, startup_prog: " + # startup_prog.to_string(True) + "\n") train_prog.global_seed(SEED) @@ -342,8 +378,9 @@ class TestSyncBatchNormRunnerBase(object): paddle.seed(SEED) self.rank = rank - outs = self.get_model(train_prog, startup_prog, place, layout, SEED, - True, only_forward) + outs = self.get_model( + train_prog, startup_prog, place, layout, SEED, True, only_forward + ) # sys.stderr.write("after get_model, train_prog: " + # train_prog.to_string(True) + "\n") # sys.stderr.write("after get_model, startup_prog: " + @@ -366,17 +403,24 @@ class TestSyncBatchNormRunnerBase(object): exe = fluid.Executor(place) exe.run(startup_prog) fetch_names = [v.name for v in outs] + [ - 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + 'bn_moving_mean', + 'bn_moving_variance', + 'bn_scale', + 'bn_bias', ] if not only_forward: others = [ - 'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', - 'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD' + 'batch_norm_0.tmp_0', + 'batch_norm_0.tmp_1', + 'bn_scale@GRAD', + 'bn_bias@GRAD', + 'batch_norm_0.tmp_3@GRAD', + 'conv2d_0.tmp_0@GRAD', ] fetch_names += others - sync_bn_fetches = exe.run(program=train_prog, - feed={'input': data}, - fetch_list=fetch_names) + sync_bn_fetches = exe.run( + program=train_prog, feed={'input': data}, fetch_list=fetch_names + ) return fetch_names, sync_bn_fetches @@ -399,19 +443,20 @@ from contextlib import closing class TestDistBase(unittest.TestCase): - def setUp(self): self._port_set = set() self._trainers = 2 self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( - self._find_free_port(), self._find_free_port()) + self._find_free_port(), + self._find_free_port(), + ) self._python_interp = sys.executable def _find_free_port(self): - def __free_port(): - with closing(socket.socket(socket.AF_INET, - socket.SOCK_STREAM)) as s: + with closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as s: s.bind(('', 0)) return s.getsockname()[1] @@ -440,7 +485,7 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_CURRENT_ENDPOINT": w1_ep, } - #update environment + # update environment env0.update(envs) env1.update(envs) @@ -451,15 +496,19 @@ class TestDistBase(unittest.TestCase): tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w") print("tr0_cmd: {}, env: {}\n".format(tr0_cmd, env0)) print("tr1_cmd: {}, env: {}\n".format(tr1_cmd, env1)) - tr0_proc = subprocess.Popen(tr0_cmd.strip().split(), - stdout=subprocess.PIPE, - stderr=tr0_pipe, - env=env0) - - tr1_proc = subprocess.Popen(tr0_cmd.strip().split(), - stdout=subprocess.PIPE, - stderr=tr1_pipe, - env=env1) + tr0_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr0_pipe, + env=env0, + ) + + tr1_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr1_pipe, + env=env1, + ) tr0_out, tr0_err = tr0_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate() @@ -473,14 +522,16 @@ class TestDistBase(unittest.TestCase): sys.stderr.write('trainer 0 stderr file: %s\n' % f.read()) with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f: sys.stderr.write('trainer 1 stderr file: %s\n' % f.read()) - return pickle.loads(tr0_out), pickle.loads( - tr1_out), tr0_proc.pid, tr1_proc.pid - - def check_with_place(self, - model_file, - col_type, - check_error_log=False, - need_envs={}): + return ( + pickle.loads(tr0_out), + pickle.loads(tr1_out), + tr0_proc.pid, + tr1_proc.pid, + ) + + def check_with_place( + self, model_file, col_type, check_error_log=False, need_envs={} + ): required_envs = { "FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_eager_delete_tensor_gb": "0.0", @@ -491,7 +542,7 @@ class TestDistBase(unittest.TestCase): "FLAGS_call_stack_level": "2", "GLOG_v": "3", "PADDLE_WITH_GLOO": '0', - "BACKEND": "cncl" + "BACKEND": "cncl", } required_envs.update(need_envs) if check_error_log: @@ -499,8 +550,11 @@ class TestDistBase(unittest.TestCase): required_envs["GLOG_logtostderr"] = "1" required_envs["GLOO_LOG_LEVEL"] = "TRACE" tr0_out, tr1_out, pid0, pid1 = self._run_cluster( - model_file, required_envs) + model_file, required_envs + ) self.assertEqual( - tr0_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed') + tr0_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed' + ) self.assertEqual( - tr1_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed') + tr1_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed' + ) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu_baseline.py b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu_baseline.py index f524e47b54a..925eec94dac 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu_baseline.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_sync_batch_norm_op_mlu_baseline.py @@ -29,14 +29,17 @@ paddle.enable_static() class TestSyncBatchNormOp(TestDistBase): - def _setup_config(self): pass def test_identity(self, col_type="identity"): - self.check_with_place("sync_batch_norm_op_mlu.py", - col_type, - check_error_log=True) + envs = {"CNCL_MEM_POOL_MULTI_CLIQUE_ENABLE": "1"} + self.check_with_place( + "sync_batch_norm_op_mlu.py", + col_type, + check_error_log=True, + need_envs=envs, + ) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py new file mode 100644 index 00000000000..d4bdf876076 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py @@ -0,0 +1,299 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +import sys + +sys.path.append("..") +import unittest +import numpy as np +from op_test import OpTest +import paddle +from paddle.fluid import core +import paddle.fluid as fluid +from paddle.fluid.op import Operator +from paddle.fluid.executor import Executor +from paddle.fluid.framework import _test_eager_guard + +paddle.enable_static() + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(((-1.0) * x))) + + +def YoloBox(x, img_size, attrs): + (n, c, h, w) = x.shape + anchors = attrs['anchors'] + an_num = int((len(anchors) // 2)) + class_num = attrs['class_num'] + conf_thresh = attrs['conf_thresh'] + downsample = attrs['downsample_ratio'] + clip_bbox = attrs['clip_bbox'] + scale_x_y = attrs['scale_x_y'] + iou_aware = attrs['iou_aware'] + iou_aware_factor = attrs['iou_aware_factor'] + bias_x_y = (-0.5) * (scale_x_y - 1.0) + input_h = downsample * h + input_w = downsample * w + if iou_aware: + ioup = x[:, :an_num, :, :] + ioup = np.expand_dims(ioup, axis=(-1)) + x = x[:, an_num:, :, :] + x = x.reshape((n, an_num, (5 + class_num), h, w)).transpose((0, 1, 3, 4, 2)) + pred_box = x[:, :, :, :, :4].copy() + grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) + grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) + pred_box[:, :, :, :, 0] = ( + (grid_x + (sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y)) + bias_x_y + ) / w + pred_box[:, :, :, :, 1] = ( + (grid_y + (sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y)) + bias_x_y + ) / h + anchors = [ + (anchors[i], anchors[(i + 1)]) for i in range(0, len(anchors), 2) + ] + anchors_s = np.array( + [((an_w / input_w), (an_h / input_h)) for (an_w, an_h) in anchors] + ) + anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) + anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) + pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w + pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h + if iou_aware: + pred_conf = (sigmoid(x[:, :, :, :, 4:5]) ** (1 - iou_aware_factor)) * ( + sigmoid(ioup) ** iou_aware_factor + ) + else: + pred_conf = sigmoid(x[:, :, :, :, 4:5]) + pred_conf[(pred_conf < conf_thresh)] = 0.0 + pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf + pred_box = pred_box * (pred_conf > 0.0).astype('float32') + pred_box = pred_box.reshape((n, (-1), 4)) + (pred_box[:, :, :2], pred_box[:, :, 2:4]) = ( + (pred_box[:, :, :2] - (pred_box[:, :, 2:4] / 2.0)), + (pred_box[:, :, :2] + (pred_box[:, :, 2:4] / 2.0)), + ) + pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis] + pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis] + pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis] + pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis] + if clip_bbox: + for i in range(len(pred_box)): + pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf) + pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf) + pred_box[i, :, 2] = np.clip( + pred_box[i, :, 2], (-np.inf), (img_size[(i, 1)] - 1) + ) + pred_box[i, :, 3] = np.clip( + pred_box[i, :, 3], (-np.inf), (img_size[(i, 0)] - 1) + ) + return (pred_box, pred_score.reshape((n, (-1), class_num))) + + +class TestYoloBoxOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'yolo_box' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + self.python_api = paddle.vision.ops.yolo_box + x = np.random.random(self.x_shape).astype('float32') + img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32') + self.attrs = { + 'anchors': self.anchors, + 'class_num': self.class_num, + 'conf_thresh': self.conf_thresh, + 'downsample_ratio': self.downsample, + 'clip_bbox': self.clip_bbox, + 'scale_x_y': self.scale_x_y, + 'iou_aware': self.iou_aware, + 'iou_aware_factor': self.iou_aware_factor, + } + self.inputs = {'X': x, 'ImgSize': img_size} + (boxes, scores) = YoloBox(x, img_size, self.attrs) + self.outputs = {'Boxes': boxes, 'Scores': scores} + + def test_check_output(self): + self.check_output_with_place(self.place, check_eager=False, atol=1e-5) + + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int((len(self.anchors) // 2)) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = True + self.x_shape = ( + self.batch_size, + (an_num * (5 + self.class_num)), + 13, + 13, + ) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.0 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + +class TestYoloBoxOpNoClipBbox(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int((len(self.anchors) // 2)) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = False + self.x_shape = ( + self.batch_size, + (an_num * (5 + self.class_num)), + 13, + 13, + ) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.0 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + +class TestYoloBoxOpScaleXY(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int((len(self.anchors) // 2)) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = True + self.x_shape = ( + self.batch_size, + (an_num * (5 + self.class_num)), + 13, + 13, + ) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.2 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + +class TestYoloBoxOpIoUAware(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int((len(self.anchors) // 2)) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = True + self.x_shape = ( + self.batch_size, + (an_num * (6 + self.class_num)), + 13, + 13, + ) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.0 + self.iou_aware = True + self.iou_aware_factor = 0.5 + + +class TestYoloBoxDygraph(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + img_size = np.ones((2, 2)).astype('int32') + img_size = paddle.to_tensor(img_size) + x1 = np.random.random([2, 14, 8, 8]).astype('float32') + x1 = paddle.to_tensor(x1) + (boxes, scores) = paddle.vision.ops.yolo_box( + x1, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1.0, + ) + assert (boxes is not None) and (scores is not None) + x2 = np.random.random([2, 16, 8, 8]).astype('float32') + x2 = paddle.to_tensor(x2) + (boxes, scores) = paddle.vision.ops.yolo_box( + x2, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1.0, + iou_aware=True, + iou_aware_factor=0.5, + ) + paddle.enable_static() + + +class TestYoloBoxStatic(unittest.TestCase): + def test_static(self): + x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32') + img_size = paddle.static.data('img_size', [2, 2], 'int32') + (boxes, scores) = paddle.vision.ops.yolo_box( + x1, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1.0, + ) + assert (boxes is not None) and (scores is not None) + x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32') + (boxes, scores) = paddle.vision.ops.yolo_box( + x2, + img_size=img_size, + anchors=[10, 13, 16, 30], + class_num=2, + conf_thresh=0.01, + downsample_ratio=8, + clip_bbox=True, + scale_x_y=1.0, + iou_aware=True, + iou_aware_factor=0.5, + ) + assert (boxes is not None) and (scores is not None) + + +class TestYoloBoxOpHW(TestYoloBoxOp): + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int((len(self.anchors) // 2)) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = False + self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, 9) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.0 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/tools/dockerfile/Dockerfile.mlu b/tools/dockerfile/Dockerfile.mlu index b3edb25fd54..65ab49dd775 100644 --- a/tools/dockerfile/Dockerfile.mlu +++ b/tools/dockerfile/Dockerfile.mlu @@ -1,15 +1,17 @@ # A image for building paddle binaries -# Update CNTOOLKIT_VERSION, CNNL_VERSION and CNCL_VERSION if using other versions +# Update CNTOOLKIT_VERSION, CNNL_VERSION, CNCL_VERSION and MLUOPS_VERSION if using other versions # # Build: -# - CNTOOLKIT_VERSION 2.8.5 -# - CNNL_VERSION 1.10.5 -# - CNCL_VERSION 1.1.2 +# - CNTOOLKIT_VERSION 3.0.2-1 +# - CNNL_VERSION 1.13.0-1 +# - CNCL_VERSION 1.2.1-1 +# - MLUOPS_VERSION 0.2.0-1 # # Download three packages from FTP (need to connect cambricon AE to get FTP url) -# - cntoolkit_2.8.5.ubuntu18.04_amd64.deb -# - cnnl_1.10.5.ubuntu18.04_amd64.deb -# - cncl_1.1.2.ubuntu18.04_amd64.deb +# - cntoolkit_3.0.2-1.ubuntu18.04_amd64.deb +# - cnnl_1.13.0-1.ubuntu18.04_amd64.deb +# - cncl_1.2.1-1.ubuntu18.04_amd64.deb +# - mluops_0.2.0-1.ubuntu18.04_amd64.deb # copy them to current directory first, then run build commands # # For example: @@ -19,11 +21,13 @@ # (get cntoolkit pkg) # (get cnnl pkg) # (get cncl pkg) +# (get mluops pkg) # # docker build -f Dockerfile.mlu \ -# --build-arg CNTOOLKIT_VERSION=2.8.5 \ -# --build-arg CNNL_VERSION=1.10.5 \ -# --build-arg CNCL_VERSION=1.1.2 \ +# --build-arg CNTOOLKIT_VERSION=3.0.2-1 \ +# --build-arg CNNL_VERSION=1.13.0-1 \ +# --build-arg CNCL_VERSION=1.2.1-1 \ +# --build-arg MLUOPS_VERSION=0.2.0-1 \ # -t paddlepaddle/paddle:latest-dev-mlu . # # without mlu device: @@ -40,12 +44,14 @@ MAINTAINER PaddlePaddle Authors ENV WITH_GPU=OFF -ARG CNTOOLKIT_VERSION=2.8.5 -ARG CNNL_VERSION=1.10.5 -ARG CNCL_VERSION=1.1.2 +ARG CNTOOLKIT_VERSION=3.0.2-1 +ARG CNNL_VERSION=1.13.0-1 +ARG CNCL_VERSION=1.2.1-1 +ARG MLUOPS_VERSION=0.2.0-1 ARG CNTOOLKIT_PKG=cntoolkit_$CNTOOLKIT_VERSION.ubuntu18.04_amd64.deb ARG CNNL_PKG=cnnl_$CNNL_VERSION.ubuntu18.04_amd64.deb ARG CNCL_PKG=cncl_$CNCL_VERSION.ubuntu18.04_amd64.deb +ARG MLUOPS_PKG=mluops_$MLUOPS_VERSION.ubuntu18.04_amd64.deb # install cntoolkit COPY $CNTOOLKIT_PKG ./ @@ -67,6 +73,11 @@ COPY $CNCL_PKG ./ RUN dpkg -i $CNCL_PKG && \ rm -f $CNCL_PKG +# install mluops +COPY $MLUOPS_PKG ./ +RUN dpkg -i $MLUOPS_PKG && \ + rm -f $MLUOPS_PKG + # Clean RUN apt-get clean -y -- GitLab