未验证 提交 6e154fc6 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] cherry-pick from develop to release/2.4 (#48313)

* [MLU] fix compute error of dropout op (#45923)

* [MLU] add mergedAdam kernel. (#45965)

* [MLU] add int64 support for mlu one_hot_v2 (#46313)

* [MLU] fix profiler compile failure (#46208)

* [MLU] add barrier_op kernel. (#46417)

* [MLU] fluid: add mluop (#46429)

* [MLU] add huber_loss kernel. (#46455)

* [MLU] add mlu kernel for add_reduce_max_grad (#45651)
Co-authored-by: liupeiyu <liupeiyu@cambricon.com>

* [MLU] add_fluid_mluop_yolo_box (#46573)

* [MLU] fix phi::Tensor compile error of mlu. (#46649)

* [MLU] add fluid MLUOps prior_box (#46585)

* [MLU] fix cmake error (#46772)

* [MLU]fix unittest of sync_bn (#46797)

* [MLU] add masterparam support for mlu adamw. (#46804)

* [MLU] add int64 support for allgather. (#46830)

* [MLU] fix compile error & add mlu blacklist function. (#47439)

* [MLU] fix softmax_with_cross_entropy failed in 370-X8.

* [MLU] fix cncl stuck caused by multiple initializations.

* [MLU] fix code style check.
Co-authored-by: qipengh <huangqipeng@cambricon.com>
Co-authored-by: cifar10 <41565156+cifar10@users.noreply.github.com>
Co-authored-by: Lux et Veritas <1004239791@qq.com>
Co-authored-by: liupeiyu <liupeiyu@cambricon.com>
Co-authored-by: ronnywang <ronny1996@163.com>
上级 96e974a0
...@@ -15,12 +15,14 @@ set(NEUWARE_LIB_DIR ${NEUWARE_HOME}/lib64) ...@@ -15,12 +15,14 @@ set(NEUWARE_LIB_DIR ${NEUWARE_HOME}/lib64)
include_directories(${NEUWARE_INCLUDE_DIR}) include_directories(${NEUWARE_INCLUDE_DIR})
set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so) set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so)
set(MLUOP_LIB ${NEUWARE_LIB_DIR}/libmluops.so)
set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so) set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so) set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so)
set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so) set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so)
generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake") generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake")
set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB}) set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${MLUOP_LIB} ${CNRT_LIB} ${CNDRV_LIB}
${CNPAPI_LIB})
if(WITH_CNCL) if(WITH_CNCL)
message(STATUS "Compile with CNCL!") message(STATUS "Compile with CNCL!")
......
...@@ -146,6 +146,48 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, ...@@ -146,6 +146,48 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
kernel_signature_(std::move(kernel_signature)), kernel_signature_(std::move(kernel_signature)),
phi_kernel_(phi_kernel) {} phi_kernel_(phi_kernel) {}
#ifdef PADDLE_WITH_MLU
// Splits `ops` on `delim` and inserts every non-empty token into `op_set`.
//
//   ops:    delimiter-separated list, e.g. "matmul,softmax".
//   delim:  the single separator character.
//   op_set: destination set; must not be null. Existing entries are kept.
//
// Empty tokens (produced by an empty string or by leading, trailing or
// doubled delimiters) are skipped, so the set never ends up holding "".
static void tokenize(const std::string& ops,
                     char delim,
                     std::unordered_set<std::string>* op_set) {
  // Use the string's own size_type so the comparison against npos is exact
  // on every platform.
  std::string::size_type beg = 0;
  std::string::size_type end = 0;
  while ((end = ops.find(delim, beg)) != std::string::npos) {
    if (end > beg) {
      op_set->insert(ops.substr(beg, end - beg));
    }
    beg = end + 1;
  }
  // Whatever follows the last delimiter (or the whole string if there was
  // none) is the final token.
  if (beg < ops.size()) {
    op_set->insert(ops.substr(beg));
  }
}
// Returns true when `op_name` appears in the comma-separated list held by the
// MLU_BLACK_LIST environment variable.
//
// The variable is read and parsed exactly once, on the first call; later
// changes to the environment are not observed. Initialization relies on the
// language guarantee that a function-local static is initialized exactly once
// even under concurrent first calls, which replaces the previous hand-rolled
// double-checked locking (that version read a plain `bool` outside the lock,
// i.e. a data race).
static bool is_in_mlu_black_list(const std::string& op_name) {
  static const std::unordered_set<std::string> mlu_black_list = [] {
    std::unordered_set<std::string> parsed;
    // Call getenv once and reuse the pointer instead of querying twice.
    const char* env = std::getenv("MLU_BLACK_LIST");
    if (env != nullptr) {
      std::string ops(env);
      tokenize(ops, ',', &parsed);
    }
    VLOG(3) << "MLU Black List: ";
    for (const auto& name : parsed) {
      VLOG(3) << name << " ";
    }
    return parsed;
  }();
  return mlu_black_list.find(op_name) != mlu_black_list.end();
}
#endif
template <typename VarType> template <typename VarType>
PreparedOp PrepareImpl( PreparedOp PrepareImpl(
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins,
...@@ -194,6 +236,12 @@ PreparedOp PrepareImpl( ...@@ -194,6 +236,12 @@ PreparedOp PrepareImpl(
#endif #endif
#ifdef PADDLE_WITH_MLU
if (is_in_mlu_black_list(op.Type())) {
expected_kernel_key.place_ = platform::CPUPlace();
}
#endif
bool has_phi_kernel = false; bool has_phi_kernel = false;
const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type()); const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/barrier_op.h"
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/cncl_helper.h"
#endif
namespace paddle {
namespace operators {
// MLU kernel for the `barrier` op.
//
// With CNCL enabled, runs a sum all-reduce of input "X" into output "Out" on
// the communicator selected by the `ring_id` attribute, then blocks the host
// until that communicator's queue has been synchronized. Without CNCL the
// kernel throws an Unavailable error.
template <typename T>
class BarrierOpMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CNCL)
    auto src = ctx.Input<phi::DenseTensor>("X");
    auto dst = ctx.Output<phi::DenseTensor>("Out");
    auto place = ctx.GetPlace();

    // Describe the payload: element type, element count and both buffers.
    cnclDataType_t cncl_dtype = platform::ToCNCLDataType(
        framework::TransToProtoVarType(src->dtype()));
    int64_t count = src->numel();
    const void* send_ptr = src->data();
    void* recv_ptr = dst->mutable_data<T>(place);

    // Look up the communicator (and its queue) registered for this ring.
    int ring_id = ctx.Attr<int>("ring_id");
    auto ring = platform::CNCLCommContext::Instance().Get(ring_id, place);
    auto* cncl_handle = ring->comm();
    auto queue = ring->stream();

    // Wait on the device context before issuing the collective.
    ctx.template device_context<paddle::platform::MLUDeviceContext>().Wait();

    PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(
        send_ptr, recv_ptr, count, cncl_dtype, cnclSum, cncl_handle, queue));
    // The barrier only completes once the collective itself has finished.
    PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(queue));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should compile with CNCL."));
#endif
  }
};
} // namespace operators
} // namespace paddle
// Register the MLU kernel for the `barrier` op; only the `int` instantiation
// is provided.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(barrier, ops::BarrierOpMLUKernel<int>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h" #include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#if defined(PADDLE_WITH_CNCL) #if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
...@@ -27,15 +28,14 @@ template <typename T> ...@@ -27,15 +28,14 @@ template <typename T>
class CAllGatherOpMLUKernel : public framework::OpKernel<T> { class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
#if defined(PADDLE_WITH_CNCL) #if defined(PADDLE_WITH_CNCL)
auto x = ctx.Input<framework::Tensor>("X"); auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<phi::DenseTensor>("Out");
cnclDataType_t dtype =
platform::ToCNCLDataType(framework::TransToProtoVarType(x->dtype()));
int nranks = ctx.Attr<int>("nranks"); int nranks = ctx.Attr<int>("nranks");
int rid = ctx.Attr<int>("ring_id"); int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::CNCLCommContext::Instance().Get(rid, place); auto comm = platform::CNCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
nranks, nranks,
...@@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel<T> { ...@@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
out->mutable_data<T>(out_dims, place); out->mutable_data<T>(out_dims, place);
uint32_t send_numel = x->numel(); uint32_t send_numel = x->numel();
void* send_buff = reinterpret_cast<void*>(const_cast<T*>(x->data<T>())); void* send_buff;
void* recv_buff = reinterpret_cast<void*>(out->data<T>()); void* recv_buff;
phi::DenseTensor in_tensor, out_tensor;
if (framework::TransToProtoVarType(x->dtype()) ==
framework::proto::VarType::INT64) {
// cast from int64 to int32 since cncl do not support int64
in_tensor.mutable_data<int32_t>(x->dims(), place);
out_tensor.mutable_data<int32_t>(out->dims(), place);
MLUCnnlTensorDesc x_int64_desc(*x);
MLUCnnlTensorDesc x_int32_desc(in_tensor);
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32);
MLUCnnl::Cast(ctx,
cast_type,
x_int64_desc.get(),
GetBasePtr(x),
x_int32_desc.get(),
GetBasePtr(&in_tensor));
send_buff = reinterpret_cast<void*>(in_tensor.data<int32_t>());
recv_buff = reinterpret_cast<void*>(out_tensor.data<int32_t>());
} else {
in_tensor.ShareDataWith(*x);
out_tensor.ShareDataWith(*out);
send_buff = reinterpret_cast<void*>(in_tensor.data<T>());
recv_buff = reinterpret_cast<void*>(out_tensor.data<T>());
}
mluStream stream = nullptr; mluStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) { if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream(); stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream();
} else { } else {
stream = comm->stream(); stream = comm->stream();
} }
cnclDataType_t dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(in_tensor.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather( PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(
send_buff, recv_buff, send_numel, dtype, comm->comm(), stream)); send_buff, recv_buff, send_numel, dtype, comm->comm(), stream));
if (framework::TransToProtoVarType(x->dtype()) ==
framework::proto::VarType::INT64) {
// cast back from int64 out_tensor to out
MLUCnnlTensorDesc out_int64_desc(*out);
MLUCnnlTensorDesc out_int32_desc(out_tensor);
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
MLUCnnl::Cast(ctx,
cast_type,
out_int32_desc.get(),
GetBasePtr(&out_tensor),
out_int64_desc.get(),
GetBasePtr(out));
}
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with MLU.")); "PaddlePaddle should compile with MLU."));
...@@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather, ...@@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather,
ops::CAllGatherOpMLUKernel<int>, ops::CAllGatherOpMLUKernel<int>,
ops::CAllGatherOpMLUKernel<int8_t>, ops::CAllGatherOpMLUKernel<int8_t>,
ops::CAllGatherOpMLUKernel<int16_t>, ops::CAllGatherOpMLUKernel<int16_t>,
ops::CAllGatherOpMLUKernel<int64_t>,
ops::CAllGatherOpMLUKernel<plat::float16>); ops::CAllGatherOpMLUKernel<plat::float16>);
...@@ -42,19 +42,23 @@ if(WITH_XPU) ...@@ -42,19 +42,23 @@ if(WITH_XPU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_xpu.cc) iou_similarity_op_xpu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
elseif(WITH_MLU) elseif(WITH_MLU)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_mlu.cc) iou_similarity_op_mlu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_mlu.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op_mlu.cc)
elseif(WITH_ASCEND_CL) elseif(WITH_ASCEND_CL)
detection_library(iou_similarity_op SRCS iou_similarity_op.cc detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op_npu.cc) iou_similarity_op_npu.cc)
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_npu.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op_npu.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
else() else()
detection_library(iou_similarity_op SRCS iou_similarity_op.cc detection_library(iou_similarity_op SRCS iou_similarity_op.cc
iou_similarity_op.cu) iou_similarity_op.cu)
detection_library(prior_box_op SRCS prior_box_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
# detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc) # detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
endif() endif()
...@@ -73,7 +77,6 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc) ...@@ -73,7 +77,6 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc)
detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc) detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc
box_decoder_and_assign_op.cu) box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/detection/prior_box_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
// MLU kernel for the `prior_box` op.
//
// Reads the "Input" and "Image" tensors plus the op attributes, uploads the
// attribute vectors (expanded aspect ratios, min sizes, max sizes, variances)
// to device tensors, and hands everything to MLUOP::OpPriorBox, which writes
// the "Boxes" and "Variances" outputs.
template <typename T>
class PriorBoxMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* feature = ctx.Input<phi::DenseTensor>("Input");
    auto* image = ctx.Input<phi::DenseTensor>("Image");
    auto* boxes = ctx.Output<phi::DenseTensor>("Boxes");
    auto* variances = ctx.Output<phi::DenseTensor>("Variances");

    // Scalar attributes forwarded unchanged to the mluOps call.
    float step_w = ctx.Attr<float>("step_w");
    float step_h = ctx.Attr<float>("step_h");
    float offset = ctx.Attr<float>("offset");
    bool clip = ctx.Attr<bool>("clip");
    bool min_max_aspect_ratios_order =
        ctx.Attr<bool>("min_max_aspect_ratios_order");

    // Dimension 2 is read as height and dimension 3 as width for both inputs.
    int im_height = image->dims()[2];
    int im_width = image->dims()[3];
    int height = feature->dims()[2];
    int width = feature->dims()[3];

    auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
    // Copies a host float vector into a device tensor on this context.
    auto upload = [&dev_ctx](const std::vector<float>& host,
                             phi::DenseTensor* device) {
      paddle::framework::TensorFromVector(host, dev_ctx, device);
    };

    // Aspect ratios are expanded (honouring `flip`) before being uploaded.
    auto aspect_ratios = ctx.Attr<std::vector<float>>("aspect_ratios");
    bool flip = ctx.Attr<bool>("flip");
    std::vector<float> expanded_ratios;
    ExpandAspectRatios(aspect_ratios, flip, &expanded_ratios);
    phi::DenseTensor ratios_tensor;
    upload(expanded_ratios, &ratios_tensor);
    MLUOpTensorDesc ratios_desc(ratios_tensor);

    auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
    phi::DenseTensor min_sizes_tensor;
    upload(min_sizes, &min_sizes_tensor);
    MLUOpTensorDesc min_sizes_desc(min_sizes_tensor);

    auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
    phi::DenseTensor max_sizes_tensor;
    upload(max_sizes, &max_sizes_tensor);
    MLUOpTensorDesc max_sizes_desc(max_sizes_tensor);

    auto variances_attr = ctx.Attr<std::vector<float>>("variances");
    phi::DenseTensor variances_attr_tensor;
    upload(variances_attr, &variances_attr_tensor);
    MLUOpTensorDesc variances_attr_desc(variances_attr_tensor);

    // Allocate both outputs on the kernel's place before describing them.
    auto place = ctx.GetPlace();
    boxes->mutable_data<T>(place);
    variances->mutable_data<T>(place);
    MLUOpTensorDesc variances_out_desc(*variances);
    MLUOpTensorDesc boxes_out_desc(*boxes);

    MLUOP::OpPriorBox(ctx,
                      min_sizes_desc.get(),
                      GetBasePtr(&min_sizes_tensor),
                      ratios_desc.get(),
                      GetBasePtr(&ratios_tensor),
                      variances_attr_desc.get(),
                      GetBasePtr(&variances_attr_tensor),
                      max_sizes_desc.get(),
                      GetBasePtr(&max_sizes_tensor),
                      height,
                      width,
                      im_height,
                      im_width,
                      step_h,
                      step_w,
                      offset,
                      clip,
                      min_max_aspect_ratios_order,
                      boxes_out_desc.get(),
                      GetBasePtr(boxes),
                      variances_out_desc.get(),
                      GetBasePtr(variances));
  }
};
} // namespace operators
} // namespace paddle
// Register the MLU kernel for the `prior_box` op; only the `float`
// instantiation is provided.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(prior_box, ops::PriorBoxMLUKernel<float>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
// MLU kernel for the `yolo_box` op.
//
// Calls MLUOP::OpYoloBox to write boxes and scores into two temporary tensors
// laid out as [N, anchor_num, 4, H*W] and [N, anchor_num, class_num, H*W],
// then transposes the last two axes of each into the "Boxes" and "Scores"
// outputs.
template <typename T>
class YoloBoxMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<phi::DenseTensor>("X");
auto* img_size = ctx.Input<phi::DenseTensor>("ImgSize");
auto* boxes = ctx.Output<phi::DenseTensor>("Boxes");
auto* scores = ctx.Output<phi::DenseTensor>("Scores");
// Op attributes, forwarded unchanged to MLUOP::OpYoloBox below.
const std::vector<int> anchors = ctx.Attr<std::vector<int>>("anchors");
auto class_num = ctx.Attr<int>("class_num");
auto conf_thresh = ctx.Attr<float>("conf_thresh");
auto downsample_ratio = ctx.Attr<int>("downsample_ratio");
auto clip_bbox = ctx.Attr<bool>("clip_bbox");
auto scale = ctx.Attr<float>("scale_x_y");
auto iou_aware = ctx.Attr<bool>("iou_aware");
auto iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
// `anchors` is consumed two entries per anchor.
// NOTE(review): presumably (width, height) pairs — confirm against the op
// definition.
int anchor_num = anchors.size() / 2;
// Total number of anchor entries; used to size the device-side anchor tensor.
int64_t size = anchors.size();
auto dim_x = x->dims();
// Dimension 0 of X is used as N, dimensions 2 and 3 as H and W.
int n = dim_x[0];
int s = anchor_num;
int h = dim_x[2];
int w = dim_x[3];
// The output of mluOpYoloBox: A 4-D tensor with shape [N, anchor_num, 4,
// H*W], the coordinates of boxes, and a 4-D tensor with shape [N,
// anchor_num, :attr:`class_num`, H*W], the classification scores of boxes.
std::vector<int64_t> boxes_dim_mluops({n, s, 4, h * w});
std::vector<int64_t> scores_dim_mluops({n, s, class_num, h * w});
// In Paddle framework: A 3-D tensor with shape [N, M, 4], the coordinates
// of boxes, and a 3-D tensor with shape [N, M, :attr:`class_num`], the
// classification scores of boxes.
// The descriptors built from the two vectors below are 4-D ([N, anchor_num,
// H*W, 4] and [N, anchor_num, H*W, class_num]).
// NOTE(review): presumably M == anchor_num * H * W so both views cover the
// same buffer — confirm.
std::vector<int64_t> boxes_out_dim({n, s, h * w, 4});
std::vector<int64_t> scores_out_dim({n, s, h * w, class_num});
auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
// Temporaries that receive the raw mluOpYoloBox results before transposing.
phi::DenseTensor boxes_tensor_mluops =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({n, s, 4, h * w}, dev_ctx);
phi::DenseTensor scores_tensor_mluops =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({n, s, class_num, h * w},
dev_ctx);
// The same temporary shapes are described twice: once for the mluOps call
// and once for the CNNL transpose.
MLUOpTensorDesc boxes_trans_desc_mluops(
4, boxes_dim_mluops.data(), ToMluOpDataType<T>());
MLUCnnlTensorDesc boxes_trans_desc_cnnl(
4, boxes_dim_mluops.data(), ToCnnlDataType<T>());
MLUOpTensorDesc scores_trans_desc_mluops(
4, scores_dim_mluops.data(), ToMluOpDataType<T>());
MLUCnnlTensorDesc scores_trans_desc_cnnl(
4, scores_dim_mluops.data(), ToCnnlDataType<T>());
boxes->mutable_data<T>(ctx.GetPlace());
scores->mutable_data<T>(ctx.GetPlace());
// Zero the final outputs.
// NOTE(review): both outputs are the destination of the transposes at the
// end of this function, while the temporaries handed to OpYoloBox are not
// zeroed here — confirm whether the zero-fill was meant for the temporaries.
FillMLUTensorWithHostValue(ctx, static_cast<T>(0), boxes);
FillMLUTensorWithHostValue(ctx, static_cast<T>(0), scores);
MLUOpTensorDesc x_desc(*x, MLUOP_LAYOUT_ARRAY, ToMluOpDataType<T>());
// ImgSize is described to mluOps as int32 data.
MLUOpTensorDesc img_size_desc(
*img_size, MLUOP_LAYOUT_ARRAY, ToMluOpDataType<int32_t>());
// Copy the anchor list into an int32 tensor through the op's device context.
Tensor anchors_temp(framework::TransToPhiDataType(VT::INT32));
anchors_temp.Resize({size});
paddle::framework::TensorFromVector(
anchors, ctx.device_context(), &anchors_temp);
MLUOpTensorDesc anchors_desc(anchors_temp);
// CNNL descriptors for the transposed (final) layouts.
MLUCnnlTensorDesc boxes_desc_cnnl(
4, boxes_out_dim.data(), ToCnnlDataType<T>());
MLUCnnlTensorDesc scores_desc_cnnl(
4, scores_out_dim.data(), ToCnnlDataType<T>());
MLUOP::OpYoloBox(ctx,
x_desc.get(),
GetBasePtr(x),
img_size_desc.get(),
GetBasePtr(img_size),
anchors_desc.get(),
GetBasePtr(&anchors_temp),
class_num,
conf_thresh,
downsample_ratio,
clip_bbox,
scale,
iou_aware,
iou_aware_factor,
boxes_trans_desc_mluops.get(),
GetBasePtr(&boxes_tensor_mluops),
scores_trans_desc_mluops.get(),
GetBasePtr(&scores_tensor_mluops));
// Swap the last two axes; the same permutation serves both outputs.
const std::vector<int> perm = {0, 1, 3, 2};
// transpose the boxes from [N, S, 4, H*W] to [N, S, H*W, 4]
MLUCnnl::Transpose(ctx,
perm,
4,
boxes_trans_desc_cnnl.get(),
GetBasePtr(&boxes_tensor_mluops),
boxes_desc_cnnl.get(),
GetBasePtr(boxes));
// transpose the scores from [N, S, class_num, H*W] to [N, S, H*W,
// class_num]
MLUCnnl::Transpose(ctx,
perm,
4,
scores_trans_desc_cnnl.get(),
GetBasePtr(&scores_tensor_mluops),
scores_desc_cnnl.get(),
GetBasePtr(scores));
}
};
} // namespace operators
} // namespace paddle
// Register the MLU kernel for the `yolo_box` op; only the `float`
// instantiation is provided.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(yolo_box, ops::YoloBoxMLUKernel<float>);
...@@ -39,8 +39,17 @@ class DropoutMLUKernel : public framework::OpKernel<T> { ...@@ -39,8 +39,17 @@ class DropoutMLUKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc x_desc(*x); MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc out_desc(*out); MLUCnnlTensorDesc out_desc(*out);
if (!is_test) { if (is_test && is_upscale) {
// exec dropout op for training only. // dropout op for inference: out = input.
framework::TensorCopy(
*x,
ctx.GetPlace(),
ctx.template device_context<platform::MLUDeviceContext>(),
out);
return;
} else if (!is_test) {
// dropout op for training: out = input * mask / ( 1.0 - dropout_prob ) or
// out = input * mask.
int seed_data = 0; int seed_data = 0;
if (seed_tensor) { if (seed_tensor) {
if (platform::is_mlu_place(seed_tensor->place())) { if (platform::is_mlu_place(seed_tensor->place())) {
...@@ -79,50 +88,44 @@ class DropoutMLUKernel : public framework::OpKernel<T> { ...@@ -79,50 +88,44 @@ class DropoutMLUKernel : public framework::OpKernel<T> {
const int device_id = ctx.GetPlace().GetDeviceId(); const int device_id = ctx.GetPlace().GetDeviceId();
auto mlu_gen_random = GetMLURandomGenerator(ctx, device_id, seed_data); auto mlu_gen_random = GetMLURandomGenerator(ctx, device_id, seed_data);
const float prob = is_upscale ? dropout_prob : 0.0f; // compute out = input * mask / ( 1.0 - dropout_prob )
MLUCnnl::FusedDropout(ctx, MLUCnnl::FusedDropout(ctx,
mlu_gen_random->get(), mlu_gen_random->get(),
x_desc.get(), x_desc.get(),
GetBasePtr(x), GetBasePtr(x),
prob, dropout_prob,
GetBasePtr(&(mlu_gen_random->get_state())), GetBasePtr(&(mlu_gen_random->get_state())),
mask_desc.get(), mask_desc.get(),
GetBasePtr(mask), GetBasePtr(mask),
out_desc.get(), out_desc.get(),
GetBasePtr(out)); GetBasePtr(out));
} else {
// exec dropout op for inference only.
if (is_upscale) { if (is_upscale) {
framework::TensorCopy( return;
*x,
ctx.GetPlace(),
ctx.template device_context<platform::MLUDeviceContext>(),
out);
} else {
auto scale = static_cast<T>(1.0f - dropout_prob);
Tensor scale_tensor(x->dtype());
scale_tensor.mutable_data<T>({1}, ctx.GetPlace());
MLUCnnlTensorDesc scale_desc(scale_tensor);
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&scale,
scale_desc.get(),
GetBasePtr(&scale_tensor));
auto data_type = ToCnnlDataType<T>();
MLUCnnlOpTensorDesc op_tensor_desc(
CNNL_OP_TENSOR_MUL, data_type, CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx,
op_tensor_desc.get(),
x_desc.get(),
GetBasePtr(x),
scale_desc.get(),
GetBasePtr(&scale_tensor),
out_desc.get(),
GetBasePtr(out),
data_type);
} }
} }
// In downgrade_in_infer mode, need to multiply (1.0f - dropout_prob).
Tensor scale_tensor(x->dtype());
Tensor bias_tensor(x->dtype());
scale_tensor.mutable_data<T>({1}, ctx.GetPlace());
bias_tensor.mutable_data<T>({1}, ctx.GetPlace());
MLUCnnlTensorDesc scale_desc(scale_tensor);
MLUCnnlTensorDesc bias_desc(bias_tensor);
FillMLUTensorWithHostValue(
ctx, static_cast<T>(1.0f - dropout_prob), &scale_tensor);
FillMLUTensorWithHostValue(ctx, static_cast<T>(0.0f), &bias_tensor);
MLUCnnl::Scale(ctx,
0,
is_test ? x_desc.get() : out_desc.get(),
is_test ? GetBasePtr(x) : GetBasePtr(out),
scale_desc.get(),
GetBasePtr(&scale_tensor),
bias_desc.get(),
GetBasePtr(&bias_tensor),
out_desc.get(),
GetBasePtr(out));
} }
}; };
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
// Forward MLU kernel for `huber_loss`.
//
// Writes two outputs with the shape of X:
//   Residual = Y - X                       (CNNL element-wise subtraction)
//   Out      = delta * SmoothL1(X, Y)      (CNNL SmoothL1 without reduction,
//                                           using `delta` as its threshold,
//                                           followed by a Scale by `delta`)
template <typename T>
class HuberLossMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = GetDevCtxFromCTX(ctx);
    auto* input = ctx.Input<Tensor>("X");
    auto* label = ctx.Input<Tensor>("Y");
    auto* residual = ctx.Output<Tensor>("Residual");
    auto* loss = ctx.Output<Tensor>("Out");
    auto delta = ctx.Attr<float>("delta");
    auto place = ctx.GetPlace();

    // Step 1: Residual = Y - X. The descriptor built from X is reused for
    // every operand of the element-wise calls below.
    cnnlDataType_t cnnl_dtype = ToCnnlDataType<T>();
    residual->mutable_data<T>(input->dims(), place);
    MLUCnnlTensorDesc elem_desc(*input);
    MLUCnnlOpTensorDesc subtract_desc(
        CNNL_OP_TENSOR_SUB, cnnl_dtype, CNNL_NOT_PROPAGATE_NAN);
    MLUCnnl::OpTensor(ctx,
                      subtract_desc.get(),
                      elem_desc.get(),
                      GetBasePtr(label),
                      elem_desc.get(),
                      GetBasePtr(input),
                      elem_desc.get(),
                      GetBasePtr(residual),
                      cnnl_dtype);

    // Step 2: element-wise SmoothL1 loss, no reduction.
    loss->mutable_data<T>(input->dims(), place);
    cnnlSmoothL1LossAlgorithm_t reduction = CNNL_SMOOTHL1LOSS_REDUCTION_NONE;
    MLUCnnl::SmoothL1LossForward(ctx,
                                 elem_desc.get(),
                                 GetBasePtr(input),
                                 elem_desc.get(), /* label shaped like X */
                                 GetBasePtr(label),
                                 delta,
                                 reduction,
                                 elem_desc.get(), /* loss shaped like X */
                                 GetBasePtr(loss));

    // Step 3: multiply the loss in place by delta (scale = delta, bias = 0).
    Tensor delta_scale =
        ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
    Tensor zero_bias =
        ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
    FillMLUTensorWithHostValue(ctx, static_cast<T>(delta), &delta_scale);
    FillMLUTensorWithHostValue(ctx, static_cast<T>(0.f), &zero_bias);

    // Scale along the last axis (axis 0 when the tensor has no dimensions).
    const int axis = std::max(loss->dims().size() - 1, 0);
    MLUCnnlTensorDesc delta_scale_desc(delta_scale);
    MLUCnnlTensorDesc zero_bias_desc(zero_bias);
    MLUCnnlTensorDesc loss_desc(*loss);
    MLUCnnl::Scale(ctx,
                   axis,
                   loss_desc.get(),
                   GetBasePtr(loss),
                   delta_scale_desc.get(),
                   GetBasePtr(&delta_scale),
                   zero_bias_desc.get(),
                   GetBasePtr(&zero_bias),
                   loss_desc.get(),
                   GetBasePtr(loss));
  }
};
// Backward MLU kernel for `huber_loss`.
//
// Inputs:  "Residual" (saved by the forward kernel) and the gradient of "Out".
// Outputs: the gradients of "X" and "Y"; either may be absent.
//
// The gradient with respect to the residual is obtained from CNNL's SmoothL1
// backward (residual against an all-zero target, threshold `delta`, no
// reduction). It is then scaled by -delta to produce dX and by +delta to
// produce dY.
//
// When neither gradient is requested the kernel returns immediately, so no
// temporaries are allocated and no device work is launched.
template <typename T>
class HuberLossGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = GetDevCtxFromCTX(ctx);
    auto* residual = ctx.Input<Tensor>("Residual");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    auto delta = ctx.Attr<float>("delta");
    auto place = ctx.GetPlace();

    // Nothing to produce: skip every allocation and device call.
    if (!dx && !dy) {
      return;
    }

    // Gradient with respect to the residual, shared by dX and dY.
    Tensor t_grad_rd =
        ctx.AllocateTmpTensor<T, MLUDeviceContext>(residual->dims(), dev_ctx);
    MLUCnnlTensorDesc t_grad_rd_desc(t_grad_rd);
    {
      // SmoothL1 backward needs a target; use zeros so that
      // (input - target) equals the residual itself.
      Tensor t_zero =
          ctx.AllocateTmpTensor<T, MLUDeviceContext>(residual->dims(), dev_ctx);
      FillMLUTensorWithHostValue(ctx, static_cast<T>(0.f), &t_zero);

      MLUCnnlTensorDesc residual_desc(*residual);
      MLUCnnlTensorDesc dout_desc(*dout);
      cnnlSmoothL1LossAlgorithm_t smoothl1_algo =
          CNNL_SMOOTHL1LOSS_REDUCTION_NONE;  // element-wise, no reduction
      MLUCnnl::SmoothL1LossBackward(ctx,
                                    residual_desc.get(),
                                    GetBasePtr(residual),
                                    residual_desc.get(),
                                    GetBasePtr(&t_zero),
                                    dout_desc.get(),
                                    GetBasePtr(dout),
                                    delta,
                                    smoothl1_algo,
                                    t_grad_rd_desc.get(),
                                    GetBasePtr(&t_grad_rd));
    }

    // One-element scale/bias tensors reused for both outputs; bias stays 0.
    Tensor scale_tensor =
        ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
    Tensor bias_tensor =
        ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
    FillMLUTensorWithHostValue(ctx, static_cast<T>(0.f), &bias_tensor);

    // Scale along the last axis (axis 0 when the tensor has no dimensions).
    const int axis = std::max(t_grad_rd.dims().size() - 1, 0);
    MLUCnnlTensorDesc scale_desc(scale_tensor);
    MLUCnnlTensorDesc bias_desc(bias_tensor);

    // Writes `factor * t_grad_rd` into `grad`, allocating it first.
    auto write_scaled_grad = [&](Tensor* grad, float factor) {
      grad->mutable_data<T>(place);
      FillMLUTensorWithHostValue(ctx, static_cast<T>(factor), &scale_tensor);
      MLUCnnlTensorDesc grad_desc(*grad);
      MLUCnnl::Scale(ctx,
                     axis,
                     t_grad_rd_desc.get(),
                     GetBasePtr(&t_grad_rd),
                     scale_desc.get(),
                     GetBasePtr(&scale_tensor),
                     bias_desc.get(),
                     GetBasePtr(&bias_tensor),
                     grad_desc.get(),
                     GetBasePtr(grad));
    };

    // Residual = Y - X, hence the opposite signs for the two gradients.
    if (dx) {
      write_scaled_grad(dx, -delta);
    }
    if (dy) {
      write_scaled_grad(dy, delta);
    }
  }
};
} // namespace operators
} // namespace paddle
// Register the MLU forward and backward kernels for `huber_loss`, each for
// float and float16.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(huber_loss,
ops::HuberLossMLUKernel<float>,
ops::HuberLossMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(huber_loss_grad,
ops::HuberLossGradMLUKernel<float>,
ops::HuberLossGradMLUKernel<plat::float16>);
...@@ -256,6 +256,186 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() { ...@@ -256,6 +256,186 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() {
} }
} }
class MLUOpTensorDescPool {
public:
mluOpTensorDescriptor_t Pop() {
mluOpTensorDescriptor_t raw_desc;
if (q_.try_dequeue(raw_desc)) {
return raw_desc;
} else {
mluOpCreateTensorDescriptor(&raw_desc);
return raw_desc;
}
}
void Recycle(mluOpTensorDescriptor_t desc) {
mluOpResetTensorDescriptor(desc);
q_.enqueue(desc);
}
~MLUOpTensorDescPool() {
auto size = q_.size_approx();
if (size > 0) {
std::vector<mluOpTensorDescriptor_t> vec(size);
q_.try_dequeue_bulk(vec.data(), size);
for (auto desc : vec) {
mluOpDestroyTensorDescriptor(desc);
}
}
}
private:
moodycamel::ConcurrentQueue<mluOpTensorDescriptor_t> q_;
};
static MLUOpTensorDescPool g_mluop_tensor_desc_pool;
// Move assignment: returns the currently held descriptor (if any) to the pool
// and takes over the descriptor owned by `rhs`, leaving `rhs` empty.
//
// Self-assignment is a no-op. Without the guard the object would recycle its
// own descriptor and end up holding nullptr.
MLUOpTensorDesc& MLUOpTensorDesc::operator=(MLUOpTensorDesc&& rhs) {
  if (this == &rhs) {
    return *this;
  }
  if (raw_tensor_desc) {
    g_mluop_tensor_desc_pool.Recycle(raw_tensor_desc);
  }
  raw_tensor_desc = rhs.raw_tensor_desc;
  rhs.raw_tensor_desc = nullptr;
  return *this;
}
// Takes a descriptor from the pool and configures it with the given rank,
// int dimension sizes and data type, using MLUOP_LAYOUT_ARRAY as the layout.
MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim,
                                 const int dim_sizes[],
                                 const mluOpDataType_t tensor_dtype)
    : raw_tensor_desc(g_mluop_tensor_desc_pool.Pop()) {
  const auto status = mluOpSetTensorDescriptor(
      raw_tensor_desc, MLUOP_LAYOUT_ARRAY, tensor_dtype, tensor_dim, dim_sizes);
  PADDLE_ENFORCE_MLU_SUCCESS(status);
}
// Describes a tensor with `tensor_dim` int32 extents and an explicit layout.
MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim,
                                 const int dim_sizes[],
                                 const mluOpDataType_t tensor_dtype,
                                 const mluOpTensorLayout_t layout)
    : raw_tensor_desc(g_mluop_tensor_desc_pool.Pop()) {
  // The descriptor comes from the pool; only its contents are filled in here.
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc,
                                                      layout,
                                                      tensor_dtype,
                                                      tensor_dim,
                                                      dim_sizes));
}
// Builds the descriptor through the (tensor_dim, int dims, dtype) constructor
// and then sets `position` on it with mluOpSetTensorDescriptorPosition.
MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim,
const int dim_sizes[],
const mluOpDataType_t tensor_dtype,
int position)
: MLUOpTensorDesc(tensor_dim, dim_sizes, tensor_dtype) {
PADDLE_ENFORCE_MLU_SUCCESS(
mluOpSetTensorDescriptorPosition(raw_tensor_desc, position));
}
// Describes a tensor with `tensor_dim` int64 extents, using the ARRAY layout.
MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim,
                                 const int64_t dim_sizes[],
                                 const mluOpDataType_t tensor_dtype) {
  // mluOpSetTensorDescriptor is fed int extents here, so narrow every extent
  // with CheckedNarrowing. Raw pointers are already valid iterators for
  // std::transform; wrapping them in std::vector<int64_t>::const_iterator
  // relies on a non-portable iterator constructor.
  std::vector<int> dim_sizes_int32(tensor_dim);
  std::transform(dim_sizes,
                 dim_sizes + tensor_dim,
                 dim_sizes_int32.begin(),
                 &CheckedNarrowing<int64_t, int>);
  raw_tensor_desc = g_mluop_tensor_desc_pool.Pop();
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc,
                                                      MLUOP_LAYOUT_ARRAY,
                                                      tensor_dtype,
                                                      tensor_dim,
                                                      dim_sizes_int32.data()));
}
// Describes a tensor with `tensor_dim` int64 extents and an explicit layout.
MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim,
                                 const int64_t dim_sizes[],
                                 const mluOpDataType_t tensor_dtype,
                                 const mluOpTensorLayout_t layout) {
  // mluOpSetTensorDescriptor is fed int extents here, so narrow every extent
  // with CheckedNarrowing. Raw pointers are already valid iterators for
  // std::transform; wrapping them in std::vector<int64_t>::const_iterator
  // relies on a non-portable iterator constructor.
  std::vector<int> dim_sizes_int32(tensor_dim);
  std::transform(dim_sizes,
                 dim_sizes + tensor_dim,
                 dim_sizes_int32.begin(),
                 &CheckedNarrowing<int64_t, int>);
  raw_tensor_desc = g_mluop_tensor_desc_pool.Pop();
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc,
                                                      layout,
                                                      tensor_dtype,
                                                      tensor_dim,
                                                      dim_sizes_int32.data()));
}
// Describes a tensor with `tensor_dim` int64 extents (ARRAY layout) and then
// sets `position` on the descriptor.
MLUOpTensorDesc::MLUOpTensorDesc(const int tensor_dim,
                                 const int64_t dim_sizes[],
                                 const mluOpDataType_t tensor_dtype,
                                 int position) {
  // mluOpSetTensorDescriptor is fed int extents here, so narrow every extent
  // with CheckedNarrowing. Raw pointers are already valid iterators for
  // std::transform; wrapping them in std::vector<int64_t>::const_iterator
  // relies on a non-portable iterator constructor.
  std::vector<int> dim_sizes_int32(tensor_dim);
  std::transform(dim_sizes,
                 dim_sizes + tensor_dim,
                 dim_sizes_int32.begin(),
                 &CheckedNarrowing<int64_t, int>);
  raw_tensor_desc = g_mluop_tensor_desc_pool.Pop();
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptor(raw_tensor_desc,
                                                      MLUOP_LAYOUT_ARRAY,
                                                      tensor_dtype,
                                                      tensor_dim,
                                                      dim_sizes_int32.data()));
  PADDLE_ENFORCE_MLU_SUCCESS(
      mluOpSetTensorDescriptorPosition(raw_tensor_desc, position));
}
// Describes `tensor` with the given layout and data type, taking the extents
// from tensor.dims().
MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor,
                                 const mluOpTensorLayout_t layout,
                                 const mluOpDataType_t tensor_dtype) {
  std::vector<int> dims = phi::vectorize<int>(tensor.dims());
  // A tensor with no dimensions is described as a 1-D tensor of extent 1.
  if (dims.empty()) {
    dims.push_back(1);
  }
  raw_tensor_desc = g_mluop_tensor_desc_pool.Pop();
  // `dims` is already a std::vector<int>; pass it straight through instead of
  // copying it into a second vector.
  PADDLE_ENFORCE_MLU_SUCCESS(
      mluOpSetTensorDescriptor(raw_tensor_desc,
                               layout,
                               tensor_dtype,
                               static_cast<int>(dims.size()),
                               dims.data()));
}
// Describes `tensor` with the ARRAY layout and the mluOp data type that
// ToMluOpDataType derives from tensor.dtype().
MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor)
: MLUOpTensorDesc(
tensor, MLUOP_LAYOUT_ARRAY, ToMluOpDataType(tensor.dtype())) {}
// Builds the descriptor through the (tensor, layout, dtype) constructor and
// then sets `position` on it with mluOpSetTensorDescriptorPosition.
MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor,
mluOpTensorLayout_t layout,
const mluOpDataType_t tensor_dtype,
int position)
: MLUOpTensorDesc(tensor, layout, tensor_dtype) {
PADDLE_ENFORCE_MLU_SUCCESS(
mluOpSetTensorDescriptorPosition(raw_tensor_desc, position));
}
// Builds the descriptor through the (tensor, layout, dtype) constructor and
// then sets both `position` and `scale` on it with
// mluOpSetTensorDescriptorPositionAndScale.
MLUOpTensorDesc::MLUOpTensorDesc(const Tensor& tensor,
mluOpTensorLayout_t layout,
const mluOpDataType_t tensor_dtype,
int position,
float scale)
: MLUOpTensorDesc(tensor, layout, tensor_dtype) {
PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetTensorDescriptorPositionAndScale(
raw_tensor_desc, position, scale));
}
// Hands the owned descriptor back to the pool for reuse. An empty
// (default-constructed or moved-from) object owns nothing and does nothing.
MLUOpTensorDesc::~MLUOpTensorDesc() {
if (raw_tensor_desc) {
g_mluop_tensor_desc_pool.Recycle(raw_tensor_desc);
}
}
MLUCnnlActivationDesc::MLUCnnlActivationDesc( MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof) { const cnnlActivationMode_t act_mode, const float ceof) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
...@@ -1563,17 +1743,35 @@ MLURNNDesc::~MLURNNDesc() { ...@@ -1563,17 +1743,35 @@ MLURNNDesc::~MLURNNDesc() {
void* indices_out) { void* indices_out) {
cnnlHandle_t handle = GetHandleFromCTX(ctx); cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlTopKTensor(handle, size_t workspace_size;
input_desc, PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetTopKTensorWorkspaceSize(handle,
input, input_desc,
k, k,
dim, dim,
largest, largest,
sorted, values_output_desc,
values_output_desc, indices_output_desc,
values_out, &workspace_size));
indices_output_desc,
indices_out)); auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlTopKTensor_v3(handle,
input_desc,
input,
k,
dim,
largest,
sorted,
false /*lower_index_first*/,
workspace_ptr,
workspace_size,
values_output_desc,
values_out,
indices_output_desc,
indices_out));
} }
/* static */ void MLUCnnl::StridedSlice( /* static */ void MLUCnnl::StridedSlice(
...@@ -4527,6 +4725,78 @@ MLURNNDesc::~MLURNNDesc() { ...@@ -4527,6 +4725,78 @@ MLURNNDesc::~MLURNNDesc() {
output)); output));
} }
/* static */ void MLUCnnl::SmoothL1LossForward(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t t_desc,
const void* target,
const float beta,
const cnnlSmoothL1LossAlgorithm_t algorithm,
const cnnlTensorDescriptor_t y_desc,
void* y) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetSmoothL1LossForwardWorkspaceSize(
handle, x_desc, algorithm, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSmoothL1LossForward_v2(handle,
x_desc,
x,
t_desc,
target,
beta,
algorithm,
workspace_ptr,
workspace_size,
y_desc,
y));
}
/* static */ void MLUCnnl::SmoothL1LossBackward(
const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t target_desc,
const void* target,
const cnnlTensorDescriptor_t dy_desc,
const void* dy,
const float beta,
const cnnlSmoothL1LossAlgorithm_t algorithm,
const cnnlTensorDescriptor_t dx_desc,
void* dx) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetSmoothL1LossBackwardWorkspaceSize(
handle, x_desc, algorithm, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSmoothL1LossBackward_v2(handle,
x_desc,
x,
target_desc,
target,
dy_desc,
dy,
beta,
algorithm,
workspace_ptr,
workspace_size,
dx_desc,
dx));
}
/* static */ void MLUCnnl::EmbeddingForward( /* static */ void MLUCnnl::EmbeddingForward(
const ExecutionContext& ctx, const ExecutionContext& ctx,
const int padding_idx, const int padding_idx,
...@@ -5148,5 +5418,94 @@ MLURNNDesc::~MLURNNDesc() { ...@@ -5148,5 +5418,94 @@ MLURNNDesc::~MLURNNDesc() {
diff_x)); diff_x));
} }
/* static */ void MLUOP::OpYoloBox(const ExecutionContext& ctx,
const mluOpTensorDescriptor_t x_desc,
const void* x,
const mluOpTensorDescriptor_t img_size_desc,
const void* img_size,
const mluOpTensorDescriptor_t anchors_desc,
const void* anchors,
const int class_num,
const float conf_thresh,
const int downsample_ratio,
const bool clip_bbox,
const float scale,
const bool iou_aware,
const float iou_aware_factor,
const mluOpTensorDescriptor_t boxes_desc,
void* boxes,
const mluOpTensorDescriptor_t scores_desc,
void* scores) {
mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(mluOpYoloBox(handle,
x_desc,
x,
img_size_desc,
img_size,
anchors_desc,
anchors,
class_num,
conf_thresh,
downsample_ratio,
clip_bbox,
scale,
iou_aware,
iou_aware_factor,
boxes_desc,
boxes,
scores_desc,
scores));
}
/* static */ void MLUOP::OpPriorBox(
const ExecutionContext& ctx,
const mluOpTensorDescriptor_t min_sizes_desc,
const void* min_sizes,
const mluOpTensorDescriptor_t aspect_ratios_desc,
const void* aspect_ratios,
const mluOpTensorDescriptor_t variances_desc,
const void* variances,
const mluOpTensorDescriptor_t max_sizes_desc,
const void* max_sizes,
const int height,
const int width,
const int im_height,
const int im_width,
const float step_h,
const float step_w,
const float offset,
const bool clip,
const bool min_max_aspect_ratios_order,
const mluOpTensorDescriptor_t output_desc,
void* output,
const mluOpTensorDescriptor_t var_desc,
void* var) {
mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(mluOpPriorBox(handle,
min_sizes_desc,
min_sizes,
aspect_ratios_desc,
aspect_ratios,
variances_desc,
variances,
max_sizes_desc,
max_sizes,
height,
width,
im_height,
im_width,
step_h,
step_w,
offset,
clip,
min_max_aspect_ratios_order,
output_desc,
output,
var_desc,
var));
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <cn_api.h> #include <cn_api.h>
#include <cnnl.h> #include <cnnl.h>
#include <concurrentqueue.h> #include <concurrentqueue.h>
#include <mlu_op.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -138,6 +139,54 @@ inline cnnlDataType_t ToCnnlDataType() { ...@@ -138,6 +139,54 @@ inline cnnlDataType_t ToCnnlDataType() {
return ToCnnlDataType(type); return ToCnnlDataType(type);
} }
// Maps a Paddle DataType to the matching mluOp data type. Any DataType that
// has no case below maps to MLUOP_DTYPE_FLOAT.
inline mluOpDataType_t ToMluOpDataType(
    const paddle::experimental::DataType& dtype) {
  switch (dtype) {
    case DataType::FLOAT16:
      return MLUOP_DTYPE_HALF;
    case DataType::FLOAT64:
      return MLUOP_DTYPE_DOUBLE;
    case DataType::INT8:
      return MLUOP_DTYPE_INT8;
    case DataType::INT16:
      return MLUOP_DTYPE_INT16;
    case DataType::INT32:
      return MLUOP_DTYPE_INT32;
    case DataType::INT64:
      return MLUOP_DTYPE_INT64;
    case DataType::BOOL:
      return MLUOP_DTYPE_BOOL;
    case DataType::UINT8:
      return MLUOP_DTYPE_UINT8;
    case DataType::FLOAT32:
    default:
      // FLOAT32 and every unlisted type share the float fallback.
      return MLUOP_DTYPE_FLOAT;
  }
}
// Overload for proto VarType: converts it to a phi DataType with
// framework::TransToPhiDataType and reuses the DataType overload.
inline mluOpDataType_t ToMluOpDataType(
const paddle::framework::proto::VarType::Type& type) {
return ToMluOpDataType(framework::TransToPhiDataType(type));
}
// Compile-time-typed variant: resolves the framework data type of T via
// framework::ToDataType and converts that to the mluOp data type.
template <typename T>
inline mluOpDataType_t ToMluOpDataType() {
  return ToMluOpDataType(framework::ToDataType(std::type_index(typeid(T))));
}
// Converts (via narrowing) a type T value to a type U, and checks that the // Converts (via narrowing) a type T value to a type U, and checks that the
// value has no value change due to the conversion. // value has no value change due to the conversion.
template <typename WideT, typename NarrowT> template <typename WideT, typename NarrowT>
...@@ -152,6 +201,10 @@ inline static cnnlHandle_t GetHandleFromCTX(const ExecutionContext& ctx) { ...@@ -152,6 +201,10 @@ inline static cnnlHandle_t GetHandleFromCTX(const ExecutionContext& ctx) {
return ctx.template device_context<MLUDeviceContext>().cnnl_handle(); return ctx.template device_context<MLUDeviceContext>().cnnl_handle();
} }
// Returns the mluOp handle held by the MLUDeviceContext of `ctx`.
inline static mluOpHandle_t GetMLUOpHandleFromCTX(const ExecutionContext& ctx) {
return ctx.template device_context<MLUDeviceContext>().mluOp_handle();
}
inline static const MLUDeviceContext& GetDevCtxFromCTX( inline static const MLUDeviceContext& GetDevCtxFromCTX(
const ExecutionContext& ctx) { const ExecutionContext& ctx) {
return ctx.template device_context<MLUDeviceContext>(); return ctx.template device_context<MLUDeviceContext>();
...@@ -281,6 +334,74 @@ class MLUCnnlTensorDesc { ...@@ -281,6 +334,74 @@ class MLUCnnlTensorDesc {
cnnlTensorDescriptor_t raw_tensor_desc = nullptr; cnnlTensorDescriptor_t raw_tensor_desc = nullptr;
}; };
// Move-only owner of an mluOpTensorDescriptor_t. A default-constructed or
// moved-from object owns nothing (raw_tensor_desc == nullptr).
class MLUOpTensorDesc {
 public:
  MLUOpTensorDesc() = default;

  // SE_DISALLOW_COPY_AND_ASSIGN
  MLUOpTensorDesc(const MLUOpTensorDesc& desc) = delete;
  MLUOpTensorDesc& operator=(const MLUOpTensorDesc&) = delete;

  // Takes over the descriptor of `rhs` and leaves `rhs` empty. Only copies
  // and clears a handle, so it cannot throw.
  MLUOpTensorDesc(MLUOpTensorDesc&& rhs) noexcept
      : raw_tensor_desc(rhs.raw_tensor_desc) {
    rhs.raw_tensor_desc = nullptr;
  }

  MLUOpTensorDesc& operator=(MLUOpTensorDesc&& rhs);

  // Constructors taking explicit int32 extents.
  MLUOpTensorDesc(const int tensor_dim,
                  const int dim_sizes[],
                  const mluOpDataType_t tensor_dtype);

  MLUOpTensorDesc(const int tensor_dim,
                  const int dim_sizes[],
                  const mluOpDataType_t tensor_dtype,
                  const mluOpTensorLayout_t layout);

  MLUOpTensorDesc(const int tensor_dim,
                  const int dim_sizes[],
                  const mluOpDataType_t tensor_dtype,
                  int position);

  // Constructors taking explicit int64 extents.
  MLUOpTensorDesc(const int tensor_dim,
                  const int64_t dim_sizes[],
                  const mluOpDataType_t tensor_dtype);

  MLUOpTensorDesc(const int tensor_dim,
                  const int64_t dim_sizes[],
                  const mluOpDataType_t tensor_dtype,
                  const mluOpTensorLayout_t layout);

  MLUOpTensorDesc(const int tensor_dim,
                  const int64_t dim_sizes[],
                  const mluOpDataType_t tensor_dtype,
                  int position);

  // Constructors taking the extents from an existing Tensor.
  MLUOpTensorDesc(const Tensor& tensor,
                  const mluOpTensorLayout_t layout,
                  const mluOpDataType_t tensor_dtype);

  explicit MLUOpTensorDesc(const Tensor& tensor);

  MLUOpTensorDesc(const Tensor& tensor,
                  mluOpTensorLayout_t layout,
                  const mluOpDataType_t tensor_dtype,
                  int position);

  MLUOpTensorDesc(const Tensor& tensor,
                  mluOpTensorLayout_t layout,
                  const mluOpDataType_t tensor_dtype,
                  int position,
                  float scale);

  ~MLUOpTensorDesc();

  // Returns the raw handle without giving up ownership; nullptr when empty.
  const mluOpTensorDescriptor_t get() const { return raw_tensor_desc; }

 private:
  mluOpTensorDescriptor_t raw_tensor_desc = nullptr;
};
class MLUCnnlActivationDesc { class MLUCnnlActivationDesc {
public: public:
MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete; MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete;
...@@ -1921,6 +2042,28 @@ class MLUCnnl { ...@@ -1921,6 +2042,28 @@ class MLUCnnl {
const cnnlTensorDescriptor_t output_desc, const cnnlTensorDescriptor_t output_desc,
void* output); void* output);
// Wrapper around cnnlSmoothL1LossForward_v2: takes inputs x / target with
// their descriptors plus `beta` and `algorithm`, and writes the result to y.
static void SmoothL1LossForward(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t t_desc,
const void* target,
const float beta,
const cnnlSmoothL1LossAlgorithm_t algorithm,
const cnnlTensorDescriptor_t y_desc,
void* y);
// Wrapper around cnnlSmoothL1LossBackward_v2: takes x / target / dy with
// their descriptors plus `beta` and `algorithm`, and writes the result to dx.
static void SmoothL1LossBackward(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t target_desc,
const void* target,
const cnnlTensorDescriptor_t dy_desc,
const void* dy,
const float beta,
const cnnlSmoothL1LossAlgorithm_t algorithm,
const cnnlTensorDescriptor_t dx_desc,
void* dx);
static void EmbeddingForward(const ExecutionContext& ctx, static void EmbeddingForward(const ExecutionContext& ctx,
const int padding_idx, const int padding_idx,
const cnnlTensorDescriptor_t weight_desc, const cnnlTensorDescriptor_t weight_desc,
...@@ -2149,6 +2292,50 @@ class MLUCnnl { ...@@ -2149,6 +2292,50 @@ class MLUCnnl {
void* diff_x); void* diff_x);
}; };
// Static wrappers around mluOp library calls. Each wrapper obtains the mluOp
// handle from the ExecutionContext and enforces that the call succeeds.
class MLUOP {
public:
// Wrapper for mluOpYoloBox; writes to `boxes` and `scores`.
static void OpYoloBox(const ExecutionContext& ctx,
const mluOpTensorDescriptor_t x_desc,
const void* x,
const mluOpTensorDescriptor_t img_size_desc,
const void* img_size,
const mluOpTensorDescriptor_t anchors_desc,
const void* anchors,
const int class_num,
const float conf_thresh,
const int downsample_ratio,
const bool clip_bbox,
const float scale,
const bool iou_aware,
const float iou_aware_factor,
const mluOpTensorDescriptor_t boxes_desc,
void* boxes,
const mluOpTensorDescriptor_t scores_desc,
void* scores);
// Wrapper for mluOpPriorBox; writes to `output` and `var`.
static void OpPriorBox(const ExecutionContext& ctx,
const mluOpTensorDescriptor_t min_sizes_desc,
const void* min_sizes,
const mluOpTensorDescriptor_t aspect_ratios_desc,
const void* aspect_ratios,
const mluOpTensorDescriptor_t variances_desc,
const void* variances,
const mluOpTensorDescriptor_t max_sizes_desc,
const void* max_sizes,
const int height,
const int width,
const int im_height,
const int im_width,
const float step_h,
const float step_w,
const float offset,
const bool clip,
const bool min_max_aspect_ratios_order,
const mluOpTensorDescriptor_t output_desc,
void* output,
const mluOpTensorDescriptor_t var_desc,
void* var);
};
const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>> const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>>
TransPermMap = { TransPermMap = {
// trans_mode, (forward_perm, backward_perm) // trans_mode, (forward_perm, backward_perm)
......
...@@ -97,4 +97,6 @@ class OneHotV2MLUKernel : public framework::OpKernel<T> { ...@@ -97,4 +97,6 @@ class OneHotV2MLUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(one_hot_v2, ops::OneHotV2MLUKernel<int32_t>); REGISTER_OP_MLU_KERNEL(one_hot_v2,
ops::OneHotV2MLUKernel<int32_t>,
ops::OneHotV2MLUKernel<int64_t>);
...@@ -291,11 +291,38 @@ class AdamWMLUKernel : public AdamMLUKernel<T> { ...@@ -291,11 +291,38 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
skip_update = skip_update_vec[0]; skip_update = skip_update_vec[0];
} }
bool with_decay = ctx.Attr<bool>("with_decay"); bool with_decay = ctx.Attr<bool>("with_decay");
const bool multi_precision = ctx.Attr<bool>("multi_precision");
auto* param_out = ctx.Output<LoDTensor>("ParamOut");
auto* master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
const auto* master_param = ctx.Input<LoDTensor>("MasterParam");
VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay; VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay;
if (!skip_update && with_decay) { if (!skip_update && with_decay) {
if (ctx.HasInput("MasterParam")) { auto* param = ctx.Input<LoDTensor>("Param");
PADDLE_THROW(platform::errors::Unimplemented( MLUCnnlTensorDesc param_desc(*param);
"Master Param is not supported on MLU")); if (multi_precision) {
VLOG(3) << "[adamw] multi_precision, cast masterparam to param.";
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(
has_master,
true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
// cast masterparam (fp32) to param (fp16), then paramout (fp16) to
// masterparamout (fp32)
MLUCnnlTensorDesc master_param_desc(*master_param);
cnnlCastDataType_t cast_type = GetCastDataType(
framework::TransToProtoVarType(master_param->dtype()),
framework::TransToProtoVarType(param->dtype()));
MLUCnnl::Cast(ctx,
cast_type,
master_param_desc.get(),
GetBasePtr(master_param),
param_desc.get(),
const_cast<void*>(GetBasePtr(param)));
} else { } else {
const auto* param_var = ctx.InputVar("Param"); const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(),
...@@ -305,13 +332,12 @@ class AdamWMLUKernel : public AdamMLUKernel<T> { ...@@ -305,13 +332,12 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
"but the received is %s", "but the received is %s",
ctx.InputNames("Param").front(), ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type()))); framework::ToTypeName(param_var->Type())));
auto* param = ctx.Input<LoDTensor>("Param");
auto* lr = ctx.Input<LoDTensor>("LearningRate"); auto* lr = ctx.Input<LoDTensor>("LearningRate");
float coeff = ctx.Attr<float>("coeff"); float coeff = ctx.Attr<float>("coeff");
// update param with decay coeff: mul(-1 * lr, coeff * param) + param // update param with decay coeff: mul(-1 * lr, coeff * param) + param
MLUCnnlTensorDesc lr_desc(*lr); MLUCnnlTensorDesc lr_desc(*lr);
MLUCnnlTensorDesc param_desc(*param);
MLUCnnlOpTensorDesc mul_op_desc( MLUCnnlOpTensorDesc mul_op_desc(
CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN); CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
...@@ -330,9 +356,244 @@ class AdamWMLUKernel : public AdamMLUKernel<T> { ...@@ -330,9 +356,244 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
} }
} }
AdamMLUKernel<T>::Compute(ctx); AdamMLUKernel<T>::Compute(ctx);
if (multi_precision) {
VLOG(3) << "[adamw] multi_precision, cast paramout to masterparamout.";
// cast paramout to masterparamout
master_param_out->mutable_data<float>(ctx.GetPlace());
cnnlCastDataType_t cast_type = GetCastDataType(
framework::TransToProtoVarType(param_out->dtype()),
framework::TransToProtoVarType(master_param_out->dtype()));
MLUCnnlTensorDesc param_out_desc(*param_out);
MLUCnnlTensorDesc master_param_out_desc(*master_param_out);
MLUCnnl::Cast(ctx,
cast_type,
param_out_desc.get(),
GetBasePtr(param_out),
master_param_out_desc.get(),
GetBasePtr(master_param_out));
}
} }
}; };
// MLU kernel of the merged_adam op: applies one Adam step to every
// (Param, Grad, Moment1, Moment2, Beta1Pow, Beta2Pow) group of the input
// lists. The optional MasterParam input is not consumed by this kernel.
template <typename T>
class MergedAdamMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Get inputs and outputs
    auto params = ctx.MultiInput<framework::Tensor>("Param");
    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
    auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
    auto mom1s = ctx.MultiInput<framework::Tensor>("Moment1");
    auto mom2s = ctx.MultiInput<framework::Tensor>("Moment2");
    auto beta1_pows = ctx.MultiInput<framework::Tensor>("Beta1Pow");
    auto beta2_pows = ctx.MultiInput<framework::Tensor>("Beta2Pow");
    auto param_outs = ctx.MultiOutput<framework::Tensor>("ParamOut");
    auto mom1_outs = ctx.MultiOutput<framework::Tensor>("Moment1Out");
    auto mom2_outs = ctx.MultiOutput<framework::Tensor>("Moment2Out");
    auto beta1_pow_outs = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
    auto beta2_pow_outs = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");

    // Check validation of inputs and outputs
    const size_t param_num = params.size();
    PADDLE_ENFORCE_EQ(param_num,
                      param_outs.size(),
                      platform::errors::InvalidArgument(
                          "The size of Output(ParamOut) must be equal to "
                          "Input(Param), but got the size of Output(ParamOut) "
                          "is %d, the size of Input(Param) is %d.",
                          param_outs.size(),
                          param_num));

    // Every list is indexed with i in [0, param_num) below, so a shorter
    // list would be read out of bounds. Reject mismatching sizes up front.
    const auto check_list_size = [param_num](const char* name, size_t actual) {
      PADDLE_ENFORCE_EQ(
          actual,
          param_num,
          platform::errors::InvalidArgument(
              "The size of %s must be equal to Input(Param), but got the "
              "size of %s is %d, the size of Input(Param) is %d.",
              name,
              name,
              actual,
              param_num));
    };
    check_list_size("Input(Moment1)", mom1s.size());
    check_list_size("Input(Moment2)", mom2s.size());
    check_list_size("Input(Beta1Pow)", beta1_pows.size());
    check_list_size("Input(Beta2Pow)", beta2_pows.size());
    check_list_size("Output(Moment1Out)", mom1_outs.size());
    check_list_size("Output(Moment2Out)", mom2_outs.size());
    check_list_size("Output(Beta1PowOut)", beta1_pow_outs.size());
    check_list_size("Output(Beta2PowOut)", beta2_pow_outs.size());

    bool skip_update = false;
    if (ctx.HasInput("SkipUpdate")) {
      auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
      PADDLE_ENFORCE_EQ(skip_update_tensor->numel(),
                        1,
                        platform::errors::InvalidArgument(
                            "Input(SkipUpdate) size must be 1, but get %d",
                            skip_update_tensor->numel()));
      std::vector<bool> skip_update_vec;
      paddle::framework::TensorToVector(
          *skip_update_tensor, ctx.device_context(), &skip_update_vec);
      // Wait for the copy to finish before reading the value on the host.
      ctx.device_context().Wait();
      skip_update = skip_update_vec[0];
    }

    // skip_update=true, just copy input to output, and TensorCopy will call
    // mutable_data
    if (skip_update) {
      VLOG(4) << "MergedAdam skip update";
      const auto& dev_ctx =
          ctx.template device_context<platform::MLUDeviceContext>();
      for (size_t i = 0; i < param_num; ++i) {
        framework::TensorCopy(
            *params[i], ctx.GetPlace(), dev_ctx, param_outs[i]);
        framework::TensorCopy(*mom1s[i], ctx.GetPlace(), dev_ctx, mom1_outs[i]);
        framework::TensorCopy(*mom2s[i], ctx.GetPlace(), dev_ctx, mom2_outs[i]);
        // The beta pow tensors stay on whatever place they already live on.
        framework::TensorCopy(*beta1_pows[i],
                              beta1_pows[i]->place(),
                              dev_ctx,
                              beta1_pow_outs[i]);
        framework::TensorCopy(*beta2_pows[i],
                              beta2_pows[i]->place(),
                              dev_ctx,
                              beta2_pow_outs[i]);
      }
      return;
    }

    // Grad and LearningRate are only read on the update path.
    check_list_size("Input(Grad)", grads.size());
    check_list_size("Input(LearningRate)", lrs.size());

    bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
    VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;

    // Get beta1, beta2 and epsilon from attribute and upload each of them
    // into a one-element device tensor.
    Tensor beta1_tmp(experimental::DataType::FLOAT32);
    Tensor beta2_tmp(experimental::DataType::FLOAT32);
    Tensor epsilon_tmp(experimental::DataType::FLOAT32);

    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    beta1_tmp.mutable_data<T>({1}, ctx.GetPlace());
    beta2_tmp.mutable_data<T>({1}, ctx.GetPlace());
    epsilon_tmp.mutable_data<T>({1}, ctx.GetPlace());
    MLUCnnlTensorDesc beta1_tmp_desc(beta1_tmp);
    MLUCnnlTensorDesc beta2_tmp_desc(beta2_tmp);
    MLUCnnlTensorDesc epsilon_tmp_desc(epsilon_tmp);
    MLUCnnl::Fill(ctx,
                  CNNL_POINTER_MODE_HOST,
                  &beta1,
                  beta1_tmp_desc.get(),
                  GetBasePtr(&beta1_tmp));
    MLUCnnl::Fill(ctx,
                  CNNL_POINTER_MODE_HOST,
                  &beta2,
                  beta2_tmp_desc.get(),
                  GetBasePtr(&beta2_tmp));
    MLUCnnl::Fill(ctx,
                  CNNL_POINTER_MODE_HOST,
                  &epsilon,
                  epsilon_tmp_desc.get(),
                  GetBasePtr(&epsilon_tmp));

    // Loop to compute
    for (size_t i = 0; i < param_num; ++i) {
      VLOG(4) << "[MergedAdam] loop: " << i;
      // The update is done in place: outputs alias the input buffers.
      param_outs[i]->ShareDataWith(*params[i]);
      mom1_outs[i]->ShareDataWith(*mom1s[i]);
      mom2_outs[i]->ShareDataWith(*mom2s[i]);

      // Beta pow tensors that live on the CPU are uploaded into a
      // one-element device tensor first. The temporaries must outlive the
      // kernel calls below, hence they are declared at loop scope.
      const framework::Tensor* beta1_pow = beta1_pows[i];
      const framework::Tensor* beta2_pow = beta2_pows[i];
      LoDTensor beta1_pow_tmp;
      LoDTensor beta2_pow_tmp;
      if (beta1_pow->place() == platform::CPUPlace()) {
        T beta1_pow_value = *beta1_pow->data<T>();
        beta1_pow_tmp.mutable_data<T>({1}, ctx.GetPlace());
        MLUCnnlTensorDesc beta1_pow_tmp_desc(beta1_pow_tmp);
        MLUCnnl::Fill(ctx,
                      CNNL_POINTER_MODE_HOST,
                      &beta1_pow_value,
                      beta1_pow_tmp_desc.get(),
                      GetBasePtr(&beta1_pow_tmp));
        beta1_pow = &beta1_pow_tmp;
      }
      if (beta2_pow->place() == platform::CPUPlace()) {
        T beta2_pow_value = *beta2_pow->data<T>();
        beta2_pow_tmp.mutable_data<T>({1}, ctx.GetPlace());
        MLUCnnlTensorDesc beta2_pow_tmp_desc(beta2_pow_tmp);
        MLUCnnl::Fill(ctx,
                      CNNL_POINTER_MODE_HOST,
                      &beta2_pow_value,
                      beta2_pow_tmp_desc.get(),
                      GetBasePtr(&beta2_pow_tmp));
        beta2_pow = &beta2_pow_tmp;
      }

      VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
              << "beta2_pow.numel() : " << beta2_pow->numel();
      VLOG(3) << "param.numel(): " << params[i]->numel();
      PADDLE_ENFORCE_EQ(beta1_pow_outs[i]->numel(),
                        1,
                        platform::errors::InvalidArgument(
                            "beta1 pow output size should be 1, but received "
                            "value is:%d.",
                            beta1_pow_outs[i]->numel()));

      PADDLE_ENFORCE_EQ(beta2_pow_outs[i]->numel(),
                        1,
                        platform::errors::InvalidArgument(
                            "beta2 pow output size should be 1, but received "
                            "value is:%d.",
                            beta2_pow_outs[i]->numel()));
      MLUCnnlTensorDesc param_desc(*params[i]);
      MLUCnnlTensorDesc mom1_desc(*mom1s[i]);
      MLUCnnlTensorDesc mom2_desc(*mom2s[i]);
      MLUCnnlTensorDesc grad_desc(*grads[i]);
      MLUCnnl::ApplyAdam(ctx,
                         param_desc.get(),
                         GetBasePtr(param_outs[i]),
                         mom1_desc.get(),
                         GetBasePtr(mom1_outs[i]),
                         mom2_desc.get(),
                         GetBasePtr(mom2_outs[i]),
                         grad_desc.get(),
                         GetBasePtr(grads[i]),
                         GetBasePtr(lrs[i]),
                         GetBasePtr(&beta1_tmp),
                         GetBasePtr(&beta2_tmp),
                         GetBasePtr(beta1_pow),
                         GetBasePtr(beta2_pow),
                         GetBasePtr(&epsilon_tmp),
                         /*use_nesterov*/ false);
      if (!use_global_beta_pow) {
        // beta_pow_out = beta_pow * beta. All operands are one-element
        // tensors of type T, so the descriptor of beta1_tmp describes every
        // one of them.
        beta1_pow_outs[i]->mutable_data<T>(ctx.GetPlace());
        beta2_pow_outs[i]->mutable_data<T>(ctx.GetPlace());

        MLUCnnlOpTensorDesc mul_op_desc(
            CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);

        MLUCnnl::OpTensor(ctx,
                          mul_op_desc.get(),
                          beta1_tmp_desc.get(),
                          GetBasePtr(beta1_pow),
                          beta1_tmp_desc.get(),
                          GetBasePtr(&beta1_tmp),
                          beta1_tmp_desc.get(),
                          GetBasePtr(beta1_pow_outs[i]),
                          ToCnnlDataType<T>());

        MLUCnnl::OpTensor(ctx,
                          mul_op_desc.get(),
                          beta1_tmp_desc.get(),
                          GetBasePtr(beta2_pow),
                          beta1_tmp_desc.get(),
                          GetBasePtr(&beta2_tmp),
                          beta1_tmp_desc.get(),
                          GetBasePtr(beta2_pow_outs[i]),
                          ToCnnlDataType<T>());
      }
    }
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -346,3 +607,7 @@ REGISTER_OP_MLU_KERNEL(adam, ...@@ -346,3 +607,7 @@ REGISTER_OP_MLU_KERNEL(adam,
REGISTER_OP_MLU_KERNEL(adamw, REGISTER_OP_MLU_KERNEL(adamw,
ops::AdamWMLUKernel<float>, ops::AdamWMLUKernel<float>,
ops::AdamWMLUKernel<plat::float16>); ops::AdamWMLUKernel<plat::float16>);
// Register the MLU kernel of merged_adam for float and float16.
REGISTER_OP_MLU_KERNEL(merged_adam,
ops::MergedAdamMLUKernel<float>,
ops::MergedAdamMLUKernel<plat::float16>);
...@@ -141,10 +141,9 @@ class MLUPoolOpKernel : public framework::OpKernel<T> { ...@@ -141,10 +141,9 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
handle, pool_mode, out_w, out_h, &extra_input_size); handle, pool_mode, out_w, out_h, &extra_input_size);
if (extra_input_size > 0) { if (extra_input_size > 0) {
phi::CPUContext cpu_ctx; framework::Tensor extra_host_tensor;
framework::Tensor extra_host_tensor = extra_host_tensor.mutable_data<int8_t>(
ctx.AllocateTmpTensor<int8_t, phi::CPUContext>( {static_cast<int64_t>(extra_input_size)}, platform::CPUPlace());
{static_cast<int64_t>(extra_input_size)}, cpu_ctx);
cnnlInitPoolingExtraInput(handle, cnnlInitPoolingExtraInput(handle,
pool_desc.get(), pool_desc.get(),
trans_in_x_desc.get(), trans_in_x_desc.get(),
......
...@@ -92,6 +92,112 @@ class ReduceMaxMLUKernel : public framework::OpKernel<T> { ...@@ -92,6 +92,112 @@ class ReduceMaxMLUKernel : public framework::OpKernel<T> {
} }
}; };
// MLU kernel of reduce_max_grad. Out and Out@GRAD are broadcast back to the
// shape of X; X@GRAD receives the broadcast gradient wherever X equals the
// broadcast Out and zero everywhere else.
template <typename T>
class ReduceMaxGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Input<Tensor>("Out");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto axes = context.Attr<std::vector<int>>("dim");
    bool reduce_all = context.Attr<bool>("reduce_all");
    int in_dtype = context.Attr<int>("in_dtype");

    PADDLE_ENFORCE_EQ(
        in_dtype == -1,
        true,
        platform::errors::InvalidArgument(
            "MLU only support in_dtype == -1 in reduce_max_grad op."));
    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    x_grad->mutable_data<T>(context.GetPlace());

    auto place = context.GetPlace();

    // Shape of Out / Out@GRAD with every reduced axis kept as extent 1.
    auto x_shape = phi::vectorize(x->dims());
    const int rank = static_cast<int>(x_shape.size());
    if (reduce_all) {
      axes.clear();
      for (int axis = 0; axis < rank; ++axis) {
        axes.push_back(axis);
      }
    }

    Tensor out_keepdim, out_grad_keepdim;
    auto keepdim_shape = x_shape;
    for (const int axis : axes) {
      // Negative axes count from the back.
      const int pos = axis < 0 ? axis + rank : axis;
      keepdim_shape[pos] = 1;
    }

    out_keepdim.ShareDataWith(*out);
    out_keepdim.Resize(phi::make_ddim(keepdim_shape));
    out_grad_keepdim.ShareDataWith(*out_grad);
    out_grad_keepdim.Resize(phi::make_ddim(keepdim_shape));

    // Broadcast Out to the shape of X.
    Tensor out_full(x->type());
    out_full.Resize(phi::make_ddim(x_shape));
    out_full.mutable_data<T>(place);

    MLUCnnlTensorDesc out_keepdim_desc(out_keepdim);
    MLUCnnlTensorDesc out_full_desc(out_full);
    MLUCnnl::BroadcastTo(context,
                         out_keepdim_desc.get(),
                         GetBasePtr(&out_keepdim),
                         out_full_desc.get(),
                         GetBasePtr(&out_full));

    // Broadcast Out@GRAD to the shape of X.
    Tensor out_grad_full(x->type());
    out_grad_full.Resize(phi::make_ddim(x_shape));
    out_grad_full.mutable_data<T>(place);

    MLUCnnlTensorDesc out_grad_keepdim_desc(out_grad_keepdim);
    MLUCnnlTensorDesc out_grad_full_desc(out_grad_full);
    MLUCnnl::BroadcastTo(context,
                         out_grad_keepdim_desc.get(),
                         GetBasePtr(&out_grad_keepdim),
                         out_grad_full_desc.get(),
                         GetBasePtr(&out_grad_full));

    // Mask of the positions where X equals the broadcast Out.
    Tensor is_max;
    is_max.mutable_data<bool>(x_grad->dims(), place);

    MLUCnnlTensorDesc x_desc(*x);
    MLUCnnlTensorDesc is_max_desc(is_max);

    MLUCnnl::Logic(context,
                   CNNL_LOGIC_OP_EQ,
                   x_desc.get(),
                   GetBasePtr(x),
                   out_full_desc.get(),
                   GetBasePtr(&out_full),
                   is_max_desc.get(),
                   GetBasePtr(&is_max));

    // X@GRAD = is_max ? broadcast Out@GRAD : 0.
    Tensor zeros;
    zeros.mutable_data<T>(x_grad->dims(), place);
    FillMLUTensorWithHostValue<T>(context, static_cast<T>(0), &zeros);
    zeros.Resize(x_grad->dims());

    MLUCnnlTensorDesc zeros_desc(zeros);
    MLUCnnlTensorDesc x_grad_desc(*x_grad);

    MLUCnnl::Select(context,
                    is_max_desc.get(),
                    GetBasePtr(&is_max),
                    out_grad_full_desc.get(),
                    GetBasePtr(&out_grad_full),
                    zeros_desc.get(),
                    GetBasePtr(&zeros),
                    x_grad_desc.get(),
                    GetBasePtr(x_grad));
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -102,3 +208,7 @@ REGISTER_OP_MLU_KERNEL(reduce_max, ...@@ -102,3 +208,7 @@ REGISTER_OP_MLU_KERNEL(reduce_max,
ops::ReduceMaxMLUKernel<float>, ops::ReduceMaxMLUKernel<float>,
ops::ReduceMaxMLUKernel<plat::float16>, ops::ReduceMaxMLUKernel<plat::float16>,
ops::ReduceMaxMLUKernel<int>); ops::ReduceMaxMLUKernel<int>);
// Register the MLU kernel for the reduce_max_grad op, instantiating
// ReduceMaxGradMLUKernel for float, float16 and int element types.
REGISTER_OP_MLU_KERNEL(reduce_max_grad,
ops::ReduceMaxGradMLUKernel<float>,
ops::ReduceMaxGradMLUKernel<plat::float16>,
ops::ReduceMaxGradMLUKernel<int>);
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = phi::DenseTensor;
using Variable = framework::Variable;
using LoDTensorArray = framework::LoDTensorArray;
using DDim = framework::DDim;
static void ProcessStridedSliceParams( static void ProcessStridedSliceParams(
const std::vector<int>& axes, const std::vector<int>& axes,
const DDim& input_dims, const DDim& input_dims,
......
...@@ -28,11 +28,13 @@ MLUContext::MLUContext(const MLUPlace& place, const int priority) { ...@@ -28,11 +28,13 @@ MLUContext::MLUContext(const MLUPlace& place, const int priority) {
MLUDeviceGuard guard(place_.device); MLUDeviceGuard guard(place_.device);
stream_.reset(new stream::MLUStream(place_, priority)); stream_.reset(new stream::MLUStream(place_, priority));
InitCNNLContext(); InitCNNLContext();
InitMLUOPContext();
} }
MLUContext::~MLUContext() { MLUContext::~MLUContext() {
MLUDeviceGuard guard(place_.device); MLUDeviceGuard guard(place_.device);
DestoryCNNLContext(); DestoryCNNLContext();
DestoryMLUOPContext();
} }
MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
...@@ -41,6 +43,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { ...@@ -41,6 +43,7 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
driver_version_ = GetMLUDriverVersion(place_.device); driver_version_ = GetMLUDriverVersion(place_.device);
runtime_version_ = GetMLURuntimeVersion(place_.device); runtime_version_ = GetMLURuntimeVersion(place_.device);
cnnl_version_ = GetMLUCnnlVersion(place_.device); cnnl_version_ = GetMLUCnnlVersion(place_.device);
mluOp_version_ = GetMLUOpVersion(place_.device);
LOG_FIRST_N(WARNING, 1) LOG_FIRST_N(WARNING, 1)
<< "Please NOTE: device: " << static_cast<int>(place_.device) << "Please NOTE: device: " << static_cast<int>(place_.device)
...@@ -51,7 +54,9 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) { ...@@ -51,7 +54,9 @@ MLUDeviceContext::MLUDeviceContext(MLUPlace place) : place_(place) {
<< ", Runtime API Version: " << runtime_version_ / 10000 << "." << ", Runtime API Version: " << runtime_version_ / 10000 << "."
<< (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100 << (runtime_version_ / 100) % 100 << "." << runtime_version_ % 100
<< ", Cnnl API Version: " << cnnl_version_ / 10000 << "." << ", Cnnl API Version: " << cnnl_version_ / 10000 << "."
<< (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100; << (cnnl_version_ / 100) % 100 << "." << cnnl_version_ % 100
<< ", MluOp API Version: " << mluOp_version_ / 10000 << "."
<< (mluOp_version_ / 100) % 100 << "." << mluOp_version_ % 100;
default_ctx_.reset(new MLUContext(place_)); default_ctx_.reset(new MLUContext(place_));
} }
...@@ -70,6 +75,10 @@ mluCnnlHandle MLUDeviceContext::cnnl_handle() const { ...@@ -70,6 +75,10 @@ mluCnnlHandle MLUDeviceContext::cnnl_handle() const {
return context()->CnnlHandle(); return context()->CnnlHandle();
} }
// Returns the mluOp handle held by the context that context() yields.
mluOpHandle MLUDeviceContext::mluOp_handle() const {
  const auto& ctx = context();
  return ctx->MluOpHandle();
}
mluStream MLUDeviceContext::stream() const { return context()->RawStream(); } mluStream MLUDeviceContext::stream() const { return context()->RawStream(); }
#endif #endif
......
...@@ -53,12 +53,19 @@ class MLUContext { ...@@ -53,12 +53,19 @@ class MLUContext {
const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; } const mluCnnlHandle& CnnlHandle() const { return cnnl_handle_; }
// Read-only access to the mluOp handle stored in this context.
const mluOpHandle& MluOpHandle() const {
  return mluOp_handle_;
}
private: private:
void InitCNNLContext() { void InitCNNLContext() {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_)); PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreate(&cnnl_handle_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream())); PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetQueue(cnnl_handle_, RawStream()));
} }
// Creates the mluOp handle stored in mluOp_handle_ and then attaches the
// queue returned by RawStream() to it. Both calls are wrapped in
// PADDLE_ENFORCE_MLU_SUCCESS; NOTE(review): that macro is presumably what
// reports a non-success mluOp status - confirm against enforce.h.
void InitMLUOPContext() {
// The handle must exist before a queue can be set on it.
PADDLE_ENFORCE_MLU_SUCCESS(mluOpCreate(&mluOp_handle_));
PADDLE_ENFORCE_MLU_SUCCESS(mluOpSetQueue(mluOp_handle_, RawStream()));
}
void DestoryCNNLContext() { void DestoryCNNLContext() {
if (cnnl_handle_) { if (cnnl_handle_) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_)); PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroy(cnnl_handle_));
...@@ -66,10 +73,18 @@ class MLUContext { ...@@ -66,10 +73,18 @@ class MLUContext {
cnnl_handle_ = nullptr; cnnl_handle_ = nullptr;
} }
// Releases the mluOp handle when one is present and leaves the member null.
void DestoryMLUOPContext() {
  if (!mluOp_handle_) {
    // Nothing to destroy; the member is already null.
    return;
  }
  PADDLE_ENFORCE_MLU_SUCCESS(mluOpDestroy(mluOp_handle_));
  mluOp_handle_ = nullptr;
}
MLUPlace place_; MLUPlace place_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
std::unique_ptr<stream::MLUStream> stream_; std::unique_ptr<stream::MLUStream> stream_;
mluCnnlHandle cnnl_handle_; mluCnnlHandle cnnl_handle_;
mluOpHandle mluOp_handle_;
DISABLE_COPY_AND_ASSIGN(MLUContext); DISABLE_COPY_AND_ASSIGN(MLUContext);
}; };
...@@ -89,6 +104,9 @@ class MLUDeviceContext : public DeviceContext { ...@@ -89,6 +104,9 @@ class MLUDeviceContext : public DeviceContext {
/*! \brief Return cnnl handle in the device context. */ /*! \brief Return cnnl handle in the device context. */
mluCnnlHandle cnnl_handle() const; mluCnnlHandle cnnl_handle() const;
/*! \brief Return mluOp handle in the device context. */
mluOpHandle mluOp_handle() const;
/*! \brief Return mlu stream in the device context. */ /*! \brief Return mlu stream in the device context. */
mluStream stream() const; mluStream stream() const;
...@@ -135,6 +153,7 @@ class MLUDeviceContext : public DeviceContext { ...@@ -135,6 +153,7 @@ class MLUDeviceContext : public DeviceContext {
int driver_version_; int driver_version_;
int runtime_version_; int runtime_version_;
int cnnl_version_; int cnnl_version_;
int mluOp_version_;
MLUPlace place_; MLUPlace place_;
std::shared_ptr<MLUContext> default_ctx_; std::shared_ptr<MLUContext> default_ctx_;
......
...@@ -41,6 +41,7 @@ struct MLUStatusType {}; ...@@ -41,6 +41,7 @@ struct MLUStatusType {};
DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT); DEFINE_MLU_STATUS_TYPE(cnrtStatus, cnrtSuccess, CNRT);
DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL); DEFINE_MLU_STATUS_TYPE(cnnlStatus, CNNL_STATUS_SUCCESS, CNNL);
DEFINE_MLU_STATUS_TYPE(mluOpStatus, MLUOP_STATUS_SUCCESS, MLUOP);
DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN); DEFINE_MLU_STATUS_TYPE(cnStatus, CN_SUCCESS, CN);
#ifdef PADDLE_WITH_CNCL #ifdef PADDLE_WITH_CNCL
DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL); DEFINE_MLU_STATUS_TYPE(cnclStatus, CNCL_RET_SUCCESS, CNCL);
...@@ -68,6 +69,15 @@ inline std::string build_mlu_error_msg(cnnlStatus stat) { ...@@ -68,6 +69,15 @@ inline std::string build_mlu_error_msg(cnnlStatus stat) {
return sout.str(); return sout.str();
} }
/*************** MLU OP ERROR ***************/
// True for every mluOp status other than MLUOP_STATUS_SUCCESS.
inline bool is_error(mluOpStatus stat) {
  const bool succeeded = (stat == MLUOP_STATUS_SUCCESS);
  return !succeeded;
}
// Formats an mluOp status as "MLU OP error(<status>), <error string>. ",
// where the error string comes from mluOpGetErrorString.
inline std::string build_mlu_error_msg(mluOpStatus stat) {
  std::ostringstream message;
  message << "MLU OP error(" << stat << "), ";
  message << mluOpGetErrorString(stat) << ". ";
  return message.str();
}
/*************** CN API ERROR ***************/ /*************** CN API ERROR ***************/
inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; } inline bool is_error(cnStatus stat) { return stat != CN_SUCCESS; }
......
...@@ -126,6 +126,13 @@ int GetMLUCnnlVersion(int id) { ...@@ -126,6 +126,13 @@ int GetMLUCnnlVersion(int id) {
return x * 10000 + y * 100 + z; return x * 10000 + y * 100 + z;
} }
// Returns the mluOp library version as a single integer:
// first * 10000 + second * 100 + third, using the three components
// reported by mluOpGetLibVersion (presumably major/minor/patch).
// The device id is validated with CheckDeviceId before the query.
int GetMLUOpVersion(int id) {
  CheckDeviceId(id);
  int major;
  int minor;
  int patch;
  mluOpGetLibVersion(&major, &minor, &patch);
  const int packed = major * 10000 + minor * 100 + patch;
  return packed;
}
int GetMLUCurrentDeviceId() { int GetMLUCurrentDeviceId() {
int device_id; int device_id;
PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id)); PADDLE_ENFORCE_MLU_SUCCESS(cnrtGetDevice(&device_id));
......
...@@ -16,10 +16,11 @@ limitations under the License. */ ...@@ -16,10 +16,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
#include <cn_api.h> #include <cn_api.h>
#include <cndrv_id.h>
#include <cnnl.h> #include <cnnl.h>
#include <cnpapi.h> #include <cnpapi.h>
#include <cnpapi_cndrv_id.h>
#include <cnrt.h> #include <cnrt.h>
#include <mlu_op.h>
#ifdef PADDLE_WITH_CNCL #ifdef PADDLE_WITH_CNCL
#include <cncl.h> #include <cncl.h>
#endif #endif
...@@ -30,11 +31,13 @@ namespace paddle { ...@@ -30,11 +31,13 @@ namespace paddle {
using cnStatus = CNresult; using cnStatus = CNresult;
using cnrtStatus = cnrtRet_t; using cnrtStatus = cnrtRet_t;
using cnnlStatus = cnnlStatus_t; using cnnlStatus = cnnlStatus_t;
using mluOpStatus = mluOpStatus_t;
#ifdef PADDLE_WITH_CNCL #ifdef PADDLE_WITH_CNCL
using cnclStatus = cnclResult_t; using cnclStatus = cnclResult_t;
#endif #endif
using mluStream = cnrtQueue_t; using mluStream = cnrtQueue_t;
using mluCnnlHandle = cnnlHandle_t; using mluCnnlHandle = cnnlHandle_t;
using mluOpHandle = mluOpHandle_t;
using mluEventHandle = cnrtNotifier_t; using mluEventHandle = cnrtNotifier_t;
using mluDeviceHandle = CNdev; using mluDeviceHandle = CNdev;
...@@ -49,6 +52,9 @@ int GetMLURuntimeVersion(int id); ...@@ -49,6 +52,9 @@ int GetMLURuntimeVersion(int id);
//! Get the cnnl version of the ith MLU. //! Get the cnnl version of the ith MLU.
int GetMLUCnnlVersion(int id); int GetMLUCnnlVersion(int id);
//! Get the mluOp version of the ith MLU.
int GetMLUOpVersion(int id);
//! Get the total number of MLU devices in system. //! Get the total number of MLU devices in system.
int GetMLUDeviceCount(); int GetMLUDeviceCount();
......
...@@ -29,7 +29,10 @@ ...@@ -29,7 +29,10 @@
#include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h" #include "paddle/fluid/platform/profiler/custom_device/custom_tracer.h"
#include "paddle/fluid/platform/profiler/extra_info.h" #include "paddle/fluid/platform/profiler/extra_info.h"
#include "paddle/fluid/platform/profiler/host_tracer.h" #include "paddle/fluid/platform/profiler/host_tracer.h"
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h" #include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h"
#endif
#include "paddle/fluid/platform/profiler/trace_event_collector.h" #include "paddle/fluid/platform/profiler/trace_event_collector.h"
#include "paddle/fluid/platform/profiler/utils.h" #include "paddle/fluid/platform/profiler/utils.h"
...@@ -80,9 +83,11 @@ Profiler::Profiler(const ProfilerOptions& options, ...@@ -80,9 +83,11 @@ Profiler::Profiler(const ProfilerOptions& options,
if (trace_switch.test(kProfileGPUOptionBit)) { if (trace_switch.test(kProfileGPUOptionBit)) {
tracers_.emplace_back(&CudaTracer::GetInstance(), false); tracers_.emplace_back(&CudaTracer::GetInstance(), false);
} }
#ifdef PADDLE_WITH_MLU
if (trace_switch.test(kProfileMLUOptionBit)) { if (trace_switch.test(kProfileMLUOptionBit)) {
tracers_.emplace_back(&MluTracer::GetInstance(), false); tracers_.emplace_back(&MluTracer::GetInstance(), false);
} }
#endif
if (trace_switch.test(kProfileCustomDeviceOptionBit)) { if (trace_switch.test(kProfileCustomDeviceOptionBit)) {
for (const auto& dev_type : custom_device_types) { for (const auto& dev_type : custom_device_types) {
tracers_.emplace_back(&CustomTracer::GetInstance(dev_type), false); tracers_.emplace_back(&CustomTracer::GetInstance(dev_type), false);
......
...@@ -34,7 +34,10 @@ import unittest ...@@ -34,7 +34,10 @@ import unittest
from multiprocessing import Process from multiprocessing import Process
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from functools import reduce from functools import reduce
from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main from test_sync_batch_norm_base_mlu import (
TestSyncBatchNormRunnerBase,
runtime_main,
)
from op_test import OpTest, _set_use_system_allocator from op_test import OpTest, _set_use_system_allocator
from test_sync_batch_norm_op import create_or_get_tensor from test_sync_batch_norm_op import create_or_get_tensor
...@@ -44,11 +47,11 @@ paddle.enable_static() ...@@ -44,11 +47,11 @@ paddle.enable_static()
class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
def __init__(self): def __init__(self):
self.global_ring_id = 0 self.global_ring_id = 0
self.dtype = np.float32 self.dtype = np.float32
self.bn_dtype = np.float32
self.N = 8 self.N = 8
self.C = 16 self.C = 16
self.H = 32 self.H = 32
...@@ -56,29 +59,36 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): ...@@ -56,29 +59,36 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
self.dshape = [self.N, self.C, self.H, self.W] self.dshape = [self.N, self.C, self.H, self.W]
self.atol = 1e-3 self.atol = 1e-3
def get_model(self, def get_model(
main, self,
startup, main,
place, startup,
layout, place,
seed, layout,
sync_bn=False, seed,
only_forward=False): sync_bn=False,
only_forward=False,
):
"""Build program.""" """Build program."""
use_cudnn = False use_cudnn = False
with fluid.unique_name.guard(): with fluid.unique_name.guard():
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
data = fluid.layers.data(name='input', data = fluid.layers.data(
shape=self.dshape, name='input',
dtype=self.dtype, shape=self.dshape,
append_batch_size=False) dtype=self.dtype,
append_batch_size=False,
)
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=data, input=data,
num_filters=32, num_filters=32,
filter_size=1, filter_size=1,
param_attr=fluid.ParamAttr(name='conv2d_weight'), param_attr=fluid.ParamAttr(name='conv2d_weight'),
bias_attr=False, bias_attr=False,
use_cudnn=use_cudnn) use_cudnn=use_cudnn,
)
if self.bn_dtype == np.float16:
conv = fluid.layers.cast(conv, 'float16')
bn = fluid.layers.batch_norm( bn = fluid.layers.batch_norm(
conv, conv,
param_attr=fluid.ParamAttr(name='bn_scale'), param_attr=fluid.ParamAttr(name='bn_scale'),
...@@ -86,9 +96,10 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): ...@@ -86,9 +96,10 @@ class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase):
moving_mean_name='bn_moving_mean', moving_mean_name='bn_moving_mean',
moving_variance_name='bn_moving_variance', moving_variance_name='bn_moving_variance',
data_layout=layout, data_layout=layout,
is_test=only_forward) is_test=only_forward,
# if self.dtype == np.float16: )
# bn = fluid.layers.cast(bn, 'float32') if self.bn_dtype == np.float16:
bn = fluid.layers.cast(bn, 'float32')
sigmoid = fluid.layers.sigmoid(bn) sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid) out = fluid.layers.reduce_sum(sigmoid)
# if not sync_bn: # if not sync_bn:
......
...@@ -41,10 +41,10 @@ def DataTypeCast(date_type): ...@@ -41,10 +41,10 @@ def DataTypeCast(date_type):
class TestCollectiveAPIRunnerBase(object): class TestCollectiveAPIRunnerBase(object):
def get_model(self, train_prog, startup_prog, rank, indata=None): def get_model(self, train_prog, startup_prog, rank, indata=None):
raise NotImplementedError( raise NotImplementedError(
"get model should be implemented by child class.") "get model should be implemented by child class."
)
def run_trainer(self, args): def run_trainer(self, args):
train_prog = fluid.Program() train_prog = fluid.Program()
...@@ -66,12 +66,12 @@ class TestCollectiveAPIRunnerBase(object): ...@@ -66,12 +66,12 @@ class TestCollectiveAPIRunnerBase(object):
fetch_list = [] fetch_list = []
for elem in result: for elem in result:
fetch_list.append(elem.name) fetch_list.append(elem.name)
out = exe.run(train_prog, out = exe.run(
feed={'tindata': indata}, train_prog, feed={'tindata': indata}, fetch_list=fetch_list
fetch_list=fetch_list) )
else: else:
out = self.get_model(train_prog, startup_prog, rank, indata) out = self.get_model(train_prog, startup_prog, rank, indata)
#print(out, sys.stderr) # print(out, sys.stderr)
sys.stdout.buffer.write(pickle.dumps(out)) sys.stdout.buffer.write(pickle.dumps(out))
...@@ -96,19 +96,20 @@ from contextlib import closing ...@@ -96,19 +96,20 @@ from contextlib import closing
class TestDistBase(unittest.TestCase): class TestDistBase(unittest.TestCase):
def setUp(self): def setUp(self):
self._port_set = set() self._port_set = set()
self._trainers = 2 self._trainers = 2
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port()) self._find_free_port(),
self._find_free_port(),
)
self._python_interp = sys.executable self._python_interp = sys.executable
def _find_free_port(self): def _find_free_port(self):
def __free_port(): def __free_port():
with closing(socket.socket(socket.AF_INET, with closing(
socket.SOCK_STREAM)) as s: socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as s:
s.bind(('', 0)) s.bind(('', 0))
return s.getsockname()[1] return s.getsockname()[1]
...@@ -121,13 +122,13 @@ class TestDistBase(unittest.TestCase): ...@@ -121,13 +122,13 @@ class TestDistBase(unittest.TestCase):
def _run_cluster(self, model_file, envs): def _run_cluster(self, model_file, envs):
worker_endpoints = self._ps_endpoints.split(",") worker_endpoints = self._ps_endpoints.split(",")
w0_ep, w1_ep = worker_endpoints w0_ep, w1_ep = worker_endpoints
#print("w0_ep:",w0_ep," w1_ep:",w1_ep) # print("w0_ep:",w0_ep," w1_ep:",w1_ep)
env0 = { env0 = {
"FLAGS_selected_mlus": "0", "FLAGS_selected_mlus": "0",
"PADDLE_TRAINER_ID": "0", "PADDLE_TRAINER_ID": "0",
"PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w0_ep "PADDLE_CURRENT_ENDPOINT": w0_ep,
} }
env1 = { env1 = {
...@@ -135,9 +136,9 @@ class TestDistBase(unittest.TestCase): ...@@ -135,9 +136,9 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ID": "1", "PADDLE_TRAINER_ID": "1",
"PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep "PADDLE_CURRENT_ENDPOINT": w1_ep,
} }
#update environment # update environment
env0.update(envs) env0.update(envs)
env1.update(envs) env1.update(envs)
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
...@@ -148,16 +149,20 @@ class TestDistBase(unittest.TestCase): ...@@ -148,16 +149,20 @@ class TestDistBase(unittest.TestCase):
tr1_cmd = tr_cmd % (self._python_interp, model_file) tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w") tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w") tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w")
#print(tr0_cmd) # print(tr0_cmd)
tr0_proc = subprocess.Popen(tr0_cmd.strip().split(), tr0_proc = subprocess.Popen(
stdout=subprocess.PIPE, tr0_cmd.strip().split(),
stderr=tr0_pipe, stdout=subprocess.PIPE,
env=env0) stderr=tr0_pipe,
env=env0,
tr1_proc = subprocess.Popen(tr0_cmd.strip().split(), )
stdout=subprocess.PIPE,
stderr=tr1_pipe, tr1_proc = subprocess.Popen(
env=env1) tr0_cmd.strip().split(),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=env1,
)
tr0_out, tr0_err = tr0_proc.communicate() tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate()
...@@ -170,17 +175,23 @@ class TestDistBase(unittest.TestCase): ...@@ -170,17 +175,23 @@ class TestDistBase(unittest.TestCase):
sys.stderr.write('trainer 0 stderr file: %s\n' % f.read()) sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f: with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f:
sys.stderr.write('trainer 1 stderr file: %s\n' % f.read()) sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
return pickle.loads(tr0_out), pickle.loads( return (
tr1_out), tr0_proc.pid, tr1_proc.pid pickle.loads(tr0_out),
pickle.loads(tr1_out),
def check_with_place(self, tr0_proc.pid,
model_file, tr1_proc.pid,
col_type, )
data_type,
path_id="0", def check_with_place(
static_mode="1", self,
check_error_log=False, model_file,
need_envs={}): col_type,
data_type,
path_id="0",
static_mode="1",
check_error_log=False,
need_envs={},
):
required_envs = { required_envs = {
"FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_eager_delete_tensor_gb": "0.0", "FLAGS_eager_delete_tensor_gb": "0.0",
...@@ -194,7 +205,7 @@ class TestDistBase(unittest.TestCase): ...@@ -194,7 +205,7 @@ class TestDistBase(unittest.TestCase):
"PADDLE_WITH_GLOO": '0', "PADDLE_WITH_GLOO": '0',
"BACKEND": "cncl", "BACKEND": "cncl",
"PATH_ID": path_id, "PATH_ID": path_id,
"DATA_TYPE": data_type "DATA_TYPE": data_type,
} }
required_envs.update(need_envs) required_envs.update(need_envs)
if check_error_log: if check_error_log:
...@@ -202,7 +213,8 @@ class TestDistBase(unittest.TestCase): ...@@ -202,7 +213,8 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
required_envs["GLOO_LOG_LEVEL"] = "TRACE" required_envs["GLOO_LOG_LEVEL"] = "TRACE"
tr0_out, tr1_out, pid0, pid1 = self._run_cluster( tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
model_file, required_envs) model_file, required_envs
)
np_data_type = DataTypeCast(data_type) np_data_type = DataTypeCast(data_type)
np.random.seed(pid0) np.random.seed(pid0)
input1 = np.random.random((10, 1000)).astype(np_data_type) input1 = np.random.random((10, 1000)).astype(np_data_type)
...@@ -210,21 +222,19 @@ class TestDistBase(unittest.TestCase): ...@@ -210,21 +222,19 @@ class TestDistBase(unittest.TestCase):
input2 = np.random.random((10, 1000)).astype(np_data_type) input2 = np.random.random((10, 1000)).astype(np_data_type)
if col_type == "broadcast": if col_type == "broadcast":
need_result = input2 need_result = input2
np.testing.assert_allclose(tr0_out, need_result) np.testing.assert_allclose(tr0_out[0], need_result)
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
elif col_type == "allreduce": elif col_type == "allreduce":
need_result = input1 + input2 need_result = input1 + input2
np.testing.assert_allclose(tr0_out, np.testing.assert_allclose(
need_result, tr0_out[0], need_result, rtol=1e-05, atol=1e-05
rtol=1e-05, )
atol=1e-05) np.testing.assert_allclose(
np.testing.assert_allclose(tr1_out, tr1_out[0], need_result, rtol=1e-05, atol=1e-05
need_result, )
rtol=1e-05,
atol=1e-05)
elif col_type == "reduce": elif col_type == "reduce":
need_result = input1 + input2 need_result = input1 + input2
np.testing.assert_allclose(tr0_out, need_result) np.testing.assert_allclose(tr0_out[0], need_result)
elif col_type == "allgather": elif col_type == "allgather":
need_result = np.vstack((input1, input2)) need_result = np.vstack((input1, input2))
tr_out0 = np.vstack((tr0_out[0], tr0_out[1])) tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
......
...@@ -53,10 +53,10 @@ def DataTypeCast(date_type): ...@@ -53,10 +53,10 @@ def DataTypeCast(date_type):
class TestCollectiveRunnerBase(object): class TestCollectiveRunnerBase(object):
def get_model(self, train_prog, startup_prog, col_type): def get_model(self, train_prog, startup_prog, col_type):
raise NotImplementedError( raise NotImplementedError(
"get model should be implemented by child class.") "get model should be implemented by child class."
)
def wait_server_ready(self, endpoints): def wait_server_ready(self, endpoints):
while True: while True:
...@@ -64,13 +64,15 @@ class TestCollectiveRunnerBase(object): ...@@ -64,13 +64,15 @@ class TestCollectiveRunnerBase(object):
not_ready_endpoints = [] not_ready_endpoints = []
for ep in endpoints: for ep in endpoints:
ip_port = ep.split(":") ip_port = ep.split(":")
with closing(socket.socket(socket.AF_INET, with closing(
socket.SOCK_STREAM)) as sock: socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
sock.settimeout(2) sock.settimeout(2)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'): if hasattr(socket, 'SO_REUSEPORT'):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, sock.setsockopt(
1) socket.SOL_SOCKET, socket.SO_REUSEPORT, 1
)
result = sock.connect_ex((ip_port[0], int(ip_port[1]))) result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0: if result != 0:
...@@ -78,44 +80,51 @@ class TestCollectiveRunnerBase(object): ...@@ -78,44 +80,51 @@ class TestCollectiveRunnerBase(object):
not_ready_endpoints.append(ep) not_ready_endpoints.append(ep)
if not all_ok: if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n") sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + sys.stderr.write(
str(not_ready_endpoints) + "\n") "not ready endpoints:" + str(not_ready_endpoints) + "\n"
)
sys.stderr.flush() sys.stderr.flush()
time.sleep(3) time.sleep(3)
else: else:
break break
# endpoints should be ["ip1:port1","ip2:port2"]
#endpoints should be ["ip1:port1","ip2:port2"] def initCommunicator(
self, program, rank, nranks, wait_port, current_endpoint, endpoints
def initCommunicator(self, program, rank, nranks, wait_port, ):
current_endpoint, endpoints):
other_endpoints = endpoints[:] other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint) other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port: if rank == 0 and wait_port:
self.wait_server_ready(other_endpoints) self.wait_server_ready(other_endpoints)
block = program.global_block() block = program.global_block()
cncl_id_var = block.create_var(name=nameGen.generate('cncl_id'), cncl_id_var = block.create_var(
persistable=True, name=nameGen.generate('cncl_id'),
type=core.VarDesc.VarType.RAW) persistable=True,
type=core.VarDesc.VarType.RAW,
block.append_op(type='c_gen_cncl_id', )
inputs={},
outputs={'Out': cncl_id_var}, block.append_op(
attrs={ type='c_gen_cncl_id',
'rank': rank, inputs={},
'endpoint': current_endpoint, outputs={'Out': cncl_id_var},
'other_endpoints': other_endpoints attrs={
}) 'rank': rank,
'endpoint': current_endpoint,
block.append_op(type='c_comm_init', 'other_endpoints': other_endpoints,
inputs={'X': cncl_id_var}, },
outputs={}, )
attrs={
'nranks': nranks, block.append_op(
'rank': rank, type='c_comm_init',
'ring_id': self.global_ring_id inputs={'X': cncl_id_var},
}) outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': self.global_ring_id,
},
)
def run_trainer(self, args): def run_trainer(self, args):
train_prog = fluid.Program() train_prog = fluid.Program()
...@@ -124,8 +133,9 @@ class TestCollectiveRunnerBase(object): ...@@ -124,8 +133,9 @@ class TestCollectiveRunnerBase(object):
rank = args["trainerid"] rank = args["trainerid"]
current_endpoint = args["currentendpoint"] current_endpoint = args["currentendpoint"]
nranks = 2 nranks = 2
self.initCommunicator(startup_prog, rank, nranks, True, self.initCommunicator(
current_endpoint, endpoints) startup_prog, rank, nranks, True, current_endpoint, endpoints
)
self.rank = rank self.rank = rank
result = self.get_model(train_prog, startup_prog, args["col_type"]) result = self.get_model(train_prog, startup_prog, args["col_type"])
device_id = int(os.getenv("FLAGS_selected_mlus", "0")) device_id = int(os.getenv("FLAGS_selected_mlus", "0"))
...@@ -135,9 +145,9 @@ class TestCollectiveRunnerBase(object): ...@@ -135,9 +145,9 @@ class TestCollectiveRunnerBase(object):
np.random.seed(os.getpid()) np.random.seed(os.getpid())
np_data_type = DataTypeCast(args["data_type"]) np_data_type = DataTypeCast(args["data_type"])
indata = np.random.random((10, 1000)).astype(np_data_type) indata = np.random.random((10, 1000)).astype(np_data_type)
out = exe.run(train_prog, out = exe.run(
feed={'tindata': indata}, train_prog, feed={'tindata': indata}, fetch_list=[result.name]
fetch_list=[result.name]) )
sys.stdout.buffer.write(pickle.dumps(out)) sys.stdout.buffer.write(pickle.dumps(out))
...@@ -160,19 +170,20 @@ from contextlib import closing ...@@ -160,19 +170,20 @@ from contextlib import closing
class TestDistBase(unittest.TestCase): class TestDistBase(unittest.TestCase):
def setUp(self): def setUp(self):
self._port_set = set() self._port_set = set()
self._trainers = 2 self._trainers = 2
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port()) self._find_free_port(),
self._find_free_port(),
)
self._python_interp = sys.executable self._python_interp = sys.executable
def _find_free_port(self): def _find_free_port(self):
def __free_port(): def __free_port():
with closing(socket.socket(socket.AF_INET, with closing(
socket.SOCK_STREAM)) as s: socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as s:
s.bind(('', 0)) s.bind(('', 0))
return s.getsockname()[1] return s.getsockname()[1]
...@@ -191,7 +202,7 @@ class TestDistBase(unittest.TestCase): ...@@ -191,7 +202,7 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ID": "0", "PADDLE_TRAINER_ID": "0",
"PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w0_ep "PADDLE_CURRENT_ENDPOINT": w0_ep,
} }
env1 = { env1 = {
...@@ -199,9 +210,9 @@ class TestDistBase(unittest.TestCase): ...@@ -199,9 +210,9 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ID": "1", "PADDLE_TRAINER_ID": "1",
"PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep "PADDLE_CURRENT_ENDPOINT": w1_ep,
} }
#update environment # update environment
env0.update(envs) env0.update(envs)
env1.update(envs) env1.update(envs)
tr_cmd = "%s %s" tr_cmd = "%s %s"
...@@ -210,15 +221,19 @@ class TestDistBase(unittest.TestCase): ...@@ -210,15 +221,19 @@ class TestDistBase(unittest.TestCase):
tr0_pipe = open("/tmp/tr0_err.log", "wb") tr0_pipe = open("/tmp/tr0_err.log", "wb")
tr1_pipe = open("/tmp/tr1_err.log", "wb") tr1_pipe = open("/tmp/tr1_err.log", "wb")
tr0_proc = subprocess.Popen(tr0_cmd.strip().split(), tr0_proc = subprocess.Popen(
stdout=subprocess.PIPE, tr0_cmd.strip().split(),
stderr=tr0_pipe, stdout=subprocess.PIPE,
env=env0) stderr=tr0_pipe,
env=env0,
)
tr1_proc = subprocess.Popen(tr0_cmd.strip().split(), tr1_proc = subprocess.Popen(
stdout=subprocess.PIPE, tr0_cmd.strip().split(),
stderr=tr1_pipe, stdout=subprocess.PIPE,
env=env1) stderr=tr1_pipe,
env=env1,
)
tr0_out, tr0_err = tr0_proc.communicate() tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate()
...@@ -227,15 +242,21 @@ class TestDistBase(unittest.TestCase): ...@@ -227,15 +242,21 @@ class TestDistBase(unittest.TestCase):
# close trainer file # close trainer file
tr0_pipe.close() tr0_pipe.close()
tr1_pipe.close() tr1_pipe.close()
return pickle.loads(tr0_out), pickle.loads( return (
tr1_out), tr0_proc.pid, tr1_proc.pid pickle.loads(tr0_out),
pickle.loads(tr1_out),
def check_with_place(self, tr0_proc.pid,
model_file, tr1_proc.pid,
col_type, )
data_type,
check_error_log=False, def check_with_place(
need_envs={}): self,
model_file,
col_type,
data_type,
check_error_log=False,
need_envs={},
):
required_envs = { required_envs = {
"FLAGS_eager_delete_tensor_gb": "0.0", "FLAGS_eager_delete_tensor_gb": "0.0",
"PATH": os.getenv("PATH"), "PATH": os.getenv("PATH"),
...@@ -251,7 +272,8 @@ class TestDistBase(unittest.TestCase): ...@@ -251,7 +272,8 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_v"] = "3" required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
tr0_out, tr1_out, pid0, pid1 = self._run_cluster( tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
model_file, required_envs) model_file, required_envs
)
np_data_type = DataTypeCast(data_type) np_data_type = DataTypeCast(data_type)
np.random.seed(pid0) np.random.seed(pid0)
input1 = np.random.random((10, 1000)).astype(np_data_type) input1 = np.random.random((10, 1000)).astype(np_data_type)
...@@ -259,63 +281,55 @@ class TestDistBase(unittest.TestCase): ...@@ -259,63 +281,55 @@ class TestDistBase(unittest.TestCase):
input2 = np.random.random((10, 1000)).astype(np_data_type) input2 = np.random.random((10, 1000)).astype(np_data_type)
if col_type == "broadcast": if col_type == "broadcast":
need_result = input2 need_result = input2
np.testing.assert_allclose(tr0_out, need_result) np.testing.assert_allclose(tr0_out[0], need_result)
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
elif col_type == "allreduce_sum": elif col_type == "allreduce_sum":
need_result = input1 + input2 need_result = input1 + input2
np.testing.assert_allclose(tr0_out, np.testing.assert_allclose(
need_result, tr0_out[0], need_result, rtol=1e-05, atol=1e-05
rtol=1e-05, )
atol=1e-05) np.testing.assert_allclose(
np.testing.assert_allclose(tr1_out, tr1_out[0], need_result, rtol=1e-05, atol=1e-05
need_result, )
rtol=1e-05,
atol=1e-05)
elif col_type == "allreduce_prod": elif col_type == "allreduce_prod":
need_result = input1 * input2 need_result = input1 * input2
np.testing.assert_allclose(tr0_out, np.testing.assert_allclose(
need_result, tr0_out[0], need_result, rtol=1e-05, atol=1e-05
rtol=1e-05, )
atol=1e-05) np.testing.assert_allclose(
np.testing.assert_allclose(tr1_out, tr1_out[0], need_result, rtol=1e-05, atol=1e-05
need_result, )
rtol=1e-05,
atol=1e-05)
elif col_type == "allreduce_max": elif col_type == "allreduce_max":
need_result = np.maximum(input1, input2) need_result = np.maximum(input1, input2)
np.testing.assert_allclose(tr0_out, np.testing.assert_allclose(
need_result, tr0_out[0], need_result, rtol=1e-05, atol=1e-05
rtol=1e-05, )
atol=1e-05) np.testing.assert_allclose(
np.testing.assert_allclose(tr1_out, tr1_out[0], need_result, rtol=1e-05, atol=1e-05
need_result, )
rtol=1e-05,
atol=1e-05)
elif col_type == "allreduce_min": elif col_type == "allreduce_min":
need_result = np.minimum(input1, input2) need_result = np.minimum(input1, input2)
np.testing.assert_allclose(tr0_out, np.testing.assert_allclose(
need_result, tr0_out[0], need_result, rtol=1e-05, atol=1e-05
rtol=1e-05, )
atol=1e-05) np.testing.assert_allclose(
np.testing.assert_allclose(tr1_out, tr1_out[0], need_result, rtol=1e-05, atol=1e-05
need_result, )
rtol=1e-05,
atol=1e-05)
elif col_type == "reduce_sum": elif col_type == "reduce_sum":
need_result = input1 + input2 need_result = input1 + input2
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
elif col_type == "reduce_prod": elif col_type == "reduce_prod":
need_result = input1 * input2 need_result = input1 * input2
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
elif col_type == "reduce_max": elif col_type == "reduce_max":
need_result = np.maximum(input1, input2) need_result = np.maximum(input1, input2)
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
elif col_type == "reduce_min": elif col_type == "reduce_min":
need_result = np.minimum(input1, input2) need_result = np.minimum(input1, input2)
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
elif col_type == "allgather": elif col_type == "allgather":
need_result = np.vstack((input1, input2)) need_result = np.vstack((input1, input2))
np.testing.assert_allclose(tr0_out, need_result) np.testing.assert_allclose(tr0_out[0], need_result)
np.testing.assert_allclose(tr1_out, need_result) np.testing.assert_allclose(tr1_out[0], need_result)
else: else:
pass pass
...@@ -29,26 +29,44 @@ SEED = 2022 ...@@ -29,26 +29,44 @@ SEED = 2022
class TestDropoutOp(OpTest): class TestDropoutOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout"
self.set_mlu() self.set_mlu()
self.init_dtype() self.init_dtype()
self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} self.init_inputs_shape()
self.init_attrs()
self.op_type = 'dropout'
self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)}
self.attrs = { self.attrs = {
'dropout_prob': 0.0, 'dropout_prob': self.dropout_prob,
'fix_seed': True, 'fix_seed': self.fix_seed,
'is_test': False, 'is_test': self.is_test,
'dropout_implementation': 'upscale_in_train' 'dropout_implementation': self.dropout_implementation,
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('uint8')
} }
out = self.inputs['X'] * (1.0 - self.dropout_prob)
if self.is_test == False:
mask = None
if self.dropout_prob == 0.0:
mask = np.ones(self.shape).astype('uint8')
elif self.dropout_prob == 1.0:
mask = np.zeros(self.shape).astype('uint8')
self.outputs = {'Out': out, 'Mask': mask}
else:
self.outputs = {'Out': out}
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def init_inputs_shape(self):
self.shape = [32, 64]
def init_attrs(self):
self.__class__.no_need_check_grad = False
self.dropout_prob = 0.0
self.fix_seed = True
self.is_test = False
self.dropout_implementation = "upscale_in_train"
def set_mlu(self): def set_mlu(self):
self.__class__.use_mlu = True self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0) self.place = paddle.device.MLUPlace(0)
...@@ -57,84 +75,107 @@ class TestDropoutOp(OpTest): ...@@ -57,84 +75,107 @@ class TestDropoutOp(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
if (
hasattr(self.__class__, "no_need_check_grad")
and self.__class__.no_need_check_grad == True
):
return
self.check_grad_with_place(self.place, ['X'], 'Out') self.check_grad_with_place(self.place, ['X'], 'Out')
class TestDropoutOpInput1d(TestDropoutOp): class TestDropoutOpInput1d(TestDropoutOp):
# change input shape def init_inputs_shape(self):
def setUp(self): self.shape = [2000]
self.op_type = "dropout"
self.set_mlu()
self.init_dtype()
self.inputs = {'X': np.random.random((3, 62)).astype(self.dtype)}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((3, 62)).astype('uint8')
}
class TestDropoutOpInput1d_1(TestDropoutOp):
# the input is 1-D
def setUp(self):
self.op_type = "dropout"
self.set_mlu()
self.init_dtype()
self.inputs = {'X': np.random.random((2000)).astype(self.dtype)}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((2000)).astype('uint8')
}
class TestDropoutOp2(TestDropoutOp): class TestDropoutOp2(TestDropoutOp):
# the dropout_prob is 1.0 def init_inputs_shape(self):
def setUp(self): self.shape = [32, 64]
self.op_type = "dropout"
self.set_mlu() def init_attrs(self):
self.init_dtype() self.dropout_prob = 1.0
self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} self.fix_seed = True
self.attrs = { self.is_test = False
'dropout_prob': 1.0, self.dropout_implementation = "upscale_in_train"
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('uint8')
}
class TestDropoutOp3(TestDropoutOp): class TestDropoutOp3(TestDropoutOp):
# the input dim is 3 def init_inputs_shape(self):
self.shape = [32, 64, 2]
class TestDropoutOp4(TestDropoutOp):
def init_attrs(self):
self.__class__.no_need_check_grad = True
self.dropout_prob = 0.35
self.fix_seed = True
self.is_test = True
self.dropout_implementation = "downgrade_in_infer"
class TestDropoutOp5(TestDropoutOp):
def init_inputs_shape(self):
self.shape = [32, 64, 3]
def init_attrs(self):
self.__class__.no_need_check_grad = True
self.dropout_prob = 0.75
self.fix_seed = True
self.is_test = True
self.dropout_implementation = "downgrade_in_infer"
class TestDropoutOp6(TestDropoutOp):
def init_attrs(self):
self.__class__.no_need_check_grad = True
self.dropout_prob = 0.0
self.fix_seed = True
self.is_test = False
self.dropout_implementation = "downgrade_in_infer"
class TestDropoutOpWithSeed(TestDropoutOp):
# the seed is a Tensor
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.set_mlu() self.set_mlu()
self.init_dtype() self.dtype = np.float32
self.inputs = {'X': np.random.random((32, 64, 2)).astype(self.dtype)} self.inputs = {
"X": np.random.random((32, 64)).astype(self.dtype),
"Seed": np.asarray([125], dtype="int32"),
}
self.attrs = { self.attrs = {
'dropout_prob': 0.0, 'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False, 'is_test': False,
'dropout_implementation': 'upscale_in_train' 'dropout_implementation': 'upscale_in_train',
} }
self.outputs = { self.outputs = {
'Out': self.inputs['X'], 'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('uint8') 'Mask': np.ones((32, 64)).astype('uint8'),
} }
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestDropoutOpFp16(TestDropoutOp):
# float16
def init_dtype(self):
self.dtype = np.float16
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
self.__class__.no_need_check_grad = True
@skip_check_grad_ci(reason="For inference, check_grad is not required.") @skip_check_grad_ci(reason="For inference, check_grad is not required.")
class TestDropoutOpInference(OpTest): class TestDropoutOpInference(OpTest):
...@@ -148,7 +189,7 @@ class TestDropoutOpInference(OpTest): ...@@ -148,7 +189,7 @@ class TestDropoutOpInference(OpTest):
'dropout_prob': 0.35, 'dropout_prob': 0.35,
'fix_seed': True, 'fix_seed': True,
'is_test': True, 'is_test': True,
'dropout_implementation': 'upscale_in_train' 'dropout_implementation': 'upscale_in_train',
} }
self.outputs = {'Out': self.inputs['X']} self.outputs = {'Out': self.inputs['X']}
...@@ -165,7 +206,6 @@ class TestDropoutOpInference(OpTest): ...@@ -165,7 +206,6 @@ class TestDropoutOpInference(OpTest):
@skip_check_grad_ci(reason="For inference, check_grad is not required.") @skip_check_grad_ci(reason="For inference, check_grad is not required.")
class TestDropoutOpInference2(TestDropoutOpInference): class TestDropoutOpInference2(TestDropoutOpInference):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.set_mlu() self.set_mlu()
...@@ -174,45 +214,12 @@ class TestDropoutOpInference2(TestDropoutOpInference): ...@@ -174,45 +214,12 @@ class TestDropoutOpInference2(TestDropoutOpInference):
self.attrs = { self.attrs = {
'dropout_prob': 0.75, 'dropout_prob': 0.75,
'is_test': True, 'is_test': True,
'dropout_implementation': 'upscale_in_train' 'dropout_implementation': 'upscale_in_train',
} }
self.outputs = {'Out': self.inputs['X']} self.outputs = {'Out': self.inputs['X']}
class TestDropoutOpWithSeed(TestDropoutOp):
# the seed is a Tensor
def setUp(self):
self.op_type = "dropout"
self.set_mlu()
self.init_dtype()
self.inputs = {
"X": np.random.random((32, 64)).astype(self.dtype),
"Seed": np.asarray([125], dtype="int32")
}
self.attrs = {
'dropout_prob': 0.0,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('uint8')
}
class TestDropoutOpFp16(TestDropoutOp):
# float16
def init_dtype(self):
self.dtype = np.float16
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
self.__class__.no_need_check_grad = True
class TestDropoutAPI(unittest.TestCase): class TestDropoutAPI(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(123) np.random.seed(123)
self.places = [fluid.CPUPlace(), paddle.device.MLUPlace(0)] self.places = [fluid.CPUPlace(), paddle.device.MLUPlace(0)]
...@@ -220,43 +227,44 @@ class TestDropoutAPI(unittest.TestCase): ...@@ -220,43 +227,44 @@ class TestDropoutAPI(unittest.TestCase):
def check_static_result(self, place): def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[40, 40], dtype="float32") input = fluid.data(name="input", shape=[40, 40], dtype="float32")
res1 = paddle.nn.functional.dropout(x=input, res1 = paddle.nn.functional.dropout(
p=0., x=input, p=0.0, training=False, mode='upscale_in_train'
training=False, )
mode='upscale_in_train') res2 = paddle.nn.functional.dropout(
res2 = paddle.nn.functional.dropout(x=input, x=input, p=0.0, axis=0, training=True, mode='upscale_in_train'
p=0., )
axis=0, res3 = paddle.nn.functional.dropout(
training=True, x=input, p=0.0, axis=0, training=False, mode='upscale_in_train'
mode='upscale_in_train') )
res3 = paddle.nn.functional.dropout(x=input, res4 = paddle.nn.functional.dropout(
p=0., x=input,
axis=0, p=0.0,
training=False, axis=[0, 1],
mode='upscale_in_train') training=True,
res4 = paddle.nn.functional.dropout(x=input, mode='upscale_in_train',
p=0., )
axis=[0, 1], res5 = paddle.nn.functional.dropout(
training=True, x=input,
mode='upscale_in_train') p=0.0,
res5 = paddle.nn.functional.dropout(x=input, axis=[0, 1],
p=0., training=False,
axis=[0, 1], mode='upscale_in_train',
training=False, )
mode='upscale_in_train') res6 = paddle.nn.functional.dropout(
res6 = paddle.nn.functional.dropout(x=input, x=input, p=1.0, training=True, mode='upscale_in_train'
p=1., )
training=True,
mode='upscale_in_train')
res7 = paddle.fluid.layers.dropout( res7 = paddle.fluid.layers.dropout(
x=input, x=input,
dropout_prob=0., dropout_prob=0.0,
dropout_implementation='upscale_in_train') dropout_implementation='upscale_in_train',
res8 = paddle.nn.functional.dropout(x=input, )
p=0., res8 = paddle.nn.functional.dropout(
axis=(0, 1), x=input,
training=False, p=0.0,
mode='upscale_in_train') axis=(0, 1),
training=False,
mode='upscale_in_train',
)
in_np = np.random.random([40, 40]).astype("float32") in_np = np.random.random([40, 40]).astype("float32")
res_np = in_np res_np = in_np
...@@ -265,13 +273,17 @@ class TestDropoutAPI(unittest.TestCase): ...@@ -265,13 +273,17 @@ class TestDropoutAPI(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
res_list = [res1, res2, res3, res4, res5, res7, res8] res_list = [res1, res2, res3, res4, res5, res7, res8]
for res in res_list: for res in res_list:
fetches = exe.run(fluid.default_main_program(), fetches = exe.run(
feed={"input": in_np}, fluid.default_main_program(),
fetch_list=[res]) feed={"input": in_np},
fetch_list=[res],
)
np.testing.assert_allclose(fetches[0], res_np) np.testing.assert_allclose(fetches[0], res_np)
fetches2 = exe.run(fluid.default_main_program(), fetches2 = exe.run(
feed={"input": in_np}, fluid.default_main_program(),
fetch_list=[res6]) feed={"input": in_np},
fetch_list=[res6],
)
np.testing.assert_allclose(fetches2[0], res_np2) np.testing.assert_allclose(fetches2[0], res_np2)
def test_static(self): def test_static(self):
......
...@@ -28,12 +28,15 @@ def AffineGrid(theta, grid_shape): ...@@ -28,12 +28,15 @@ def AffineGrid(theta, grid_shape):
n = grid_shape[0] n = grid_shape[0]
h = grid_shape[1] h = grid_shape[1]
w = grid_shape[2] w = grid_shape[2]
h_idx = np.repeat(np.linspace(-1, 1, h)[np.newaxis, :], w, h_idx = np.repeat(np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[
axis=0).T[:, :, np.newaxis] :, :, np.newaxis
w_idx = np.repeat(np.linspace(-1, 1, w)[np.newaxis, :], h, ]
axis=0)[:, :, np.newaxis] w_idx = np.repeat(np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[
grid = np.concatenate([w_idx, h_idx, np.ones([h, w, 1])], :, :, np.newaxis
axis=2) # h * w * 3 ]
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2
) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3 grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3
ret = np.zeros([n, h * w, 2]) ret = np.zeros([n, h * w, 2])
...@@ -53,13 +56,17 @@ def getGridPointValue(data, x, y): ...@@ -53,13 +56,17 @@ def getGridPointValue(data, x, y):
out_H = x.shape[1] out_H = x.shape[1]
out_W = x.shape[2] out_W = x.shape[2]
#out = np.zeros(data_shape, dtype='float32') # out = np.zeros(data_shape, dtype='float32')
out = np.zeros([N, C, out_H, out_W], dtype='float32') out = np.zeros([N, C, out_H, out_W], dtype='float32')
for i in range(N): for i in range(N):
for j in range(out_H): for j in range(out_H):
for k in range(out_W): for k in range(out_W):
if y[i, j, k] < 0 or y[i, j, k] > in_H - 1 or x[ if (
i, j, k] < 0 or x[i, j, k] > in_W - 1: y[i, j, k] < 0
or y[i, j, k] > in_H - 1
or x[i, j, k] < 0
or x[i, j, k] > in_W - 1
):
out[i, :, j, k] = 0 out[i, :, j, k] = 0
else: else:
out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]] out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
...@@ -75,27 +82,28 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): ...@@ -75,27 +82,28 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
if align_corners: if align_corners:
grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * max_val) grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * max_val)
else: else:
grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * grid_slice = (
(max_val + 1)) - 0.5 0.5 * ((grid_slice.astype('float32') + 1.0) * (max_val + 1)) - 0.5
)
if padding_mode == "border": if padding_mode == "border":
grid_slice = clip(grid_slice, 0, max_val) grid_slice = clip(grid_slice, 0, max_val)
elif padding_mode == "reflection": elif padding_mode == "reflection":
double_range = 2 * max_val if align_corners else (max_val + 1) * 2 double_range = 2 * max_val if align_corners else (max_val + 1) * 2
grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + grid_abs = (
0.5) np.abs(grid_slice) if align_corners else np.abs(grid_slice + 0.5)
)
extra = grid_abs - np.floor(grid_abs / double_range) * double_range extra = grid_abs - np.floor(grid_abs / double_range) * double_range
grid_slice = np.minimum(extra, double_range - extra) grid_slice = np.minimum(extra, double_range - extra)
grid_slice = grid_slice if align_corners else clip( grid_slice = (
grid_slice - 0.5, 0, max_val) grid_slice if align_corners else clip(grid_slice - 0.5, 0, max_val)
)
return grid_slice return grid_slice
def GridSampler(data, def GridSampler(
grid, data, grid, align_corners=True, mode="bilinear", padding_mode="zeros"
align_corners=True, ):
mode="bilinear",
padding_mode="zeros"):
dims = data.shape dims = data.shape
N = dims[0] N = dims[0]
in_C = dims[1] in_C = dims[1]
...@@ -119,14 +127,18 @@ def GridSampler(data, ...@@ -119,14 +127,18 @@ def GridSampler(data,
y0 = np.floor(y).astype('int32') y0 = np.floor(y).astype('int32')
y1 = y0 + 1 y1 = y0 + 1
wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), wa = np.tile(
(1, in_C, 1, 1)) ((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1)
wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), )
(1, in_C, 1, 1)) wb = np.tile(
wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), ((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1)
(1, in_C, 1, 1)) )
wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), wc = np.tile(
(1, in_C, 1, 1)) ((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1)
)
wd = np.tile(
((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), (1, in_C, 1, 1)
)
va = getGridPointValue(data, x0, y0) va = getGridPointValue(data, x0, y0)
vb = getGridPointValue(data, x0, y1) vb = getGridPointValue(data, x0, y1)
...@@ -142,7 +154,6 @@ def GridSampler(data, ...@@ -142,7 +154,6 @@ def GridSampler(data,
class TestGridSamplerOp(OpTest): class TestGridSamplerOp(OpTest):
def setUp(self): def setUp(self):
self.place = paddle.device.MLUPlace(0) self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -166,12 +177,12 @@ class TestGridSamplerOp(OpTest): ...@@ -166,12 +177,12 @@ class TestGridSamplerOp(OpTest):
'use_cudnn': False, 'use_cudnn': False,
"align_corners": self.align_corners, "align_corners": self.align_corners,
"padding_mode": self.padding_mode, "padding_mode": self.padding_mode,
"mode": self.mode "mode": self.mode,
} }
self.outputs = { self.outputs = {
'Output': 'Output': GridSampler(
GridSampler(x, grid, self.align_corners, self.mode, x, grid, self.align_corners, self.mode, self.padding_mode
self.padding_mode) )
} }
def test_check_output(self): def test_check_output(self):
...@@ -186,20 +197,17 @@ class TestGridSamplerOp(OpTest): ...@@ -186,20 +197,17 @@ class TestGridSamplerOp(OpTest):
self.mode = "bilinear" self.mode = "bilinear"
# TODO(fwg): Test this case when cnnl support align_corners = True. class Case1(TestGridSamplerOp):
# class Case1(TestGridSamplerOp): def initTestCase(self):
# self.x_shape = (2, 3, 5, 6)
# def initTestCase(self): self.grid_shape = (2, 8, 9, 2)
# self.x_shape = (2, 3, 5, 6) self.theta_shape = (2, 2, 3)
# self.grid_shape = (2, 8, 9, 2) self.align_corners = True
# self.theta_shape = (2, 2, 3) self.padding_mode = "zeros"
# self.align_corners = True self.mode = "bilinear"
# self.padding_mode = "zeros"
# self.mode = "bilinear"
class LargeInputCase(TestGridSamplerOp): class LargeInputCase(TestGridSamplerOp):
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 3, 128, 128) self.x_shape = (2, 3, 128, 128)
self.grid_shape = (2, 130, 130, 2) self.grid_shape = (2, 130, 130, 2)
...@@ -209,16 +217,15 @@ class LargeInputCase(TestGridSamplerOp): ...@@ -209,16 +217,15 @@ class LargeInputCase(TestGridSamplerOp):
self.mode = "bilinear" self.mode = "bilinear"
# TODO(fwg): Test this case when cnnl support align_corners = True. class Case2(LargeInputCase):
# class Case2(LargeInputCase): def initTestCase(self):
# self.x_shape = (2, 3, 128, 128)
# def initTestCase(self): self.grid_shape = (2, 130, 130, 2)
# self.x_shape = (2, 3, 128, 128) self.theta_shape = (2, 2, 3)
# self.grid_shape = (2, 130, 130, 2) self.align_corners = True
# self.theta_shape = (2, 2, 3) self.padding_mode = "zeros"
# self.align_corners = True self.mode = "bilinear"
# self.padding_mode = "zeros"
# self.mode = "bilinear"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
paddle.enable_static()
def huber_loss_forward(val, delta):
    """Scalar Huber loss used as the NumPy reference for the op under test.

    Returns ``0.5 * val**2`` while ``|val|`` is within ``delta`` and the
    linear branch ``delta * (|val| - 0.5 * delta)`` otherwise.
    """
    magnitude = abs(val)
    quadratic = magnitude <= delta
    return 0.5 * val * val if quadratic else delta * (magnitude - 0.5 * delta)
class TestHuberLossOp(OpTest):
    """Checks the `huber_loss` op on an MLU place against a NumPy reference.

    Subclasses override `set_shape` to exercise other input shapes.
    """

    def setUp(self):
        """Build inputs, attributes and the expected outputs for the op."""
        self.op_type = 'huber_loss'
        self.set_mlu()
        self.python_api = paddle.fluid.layers.huber_loss
        self.python_out_sig = ["Out"]
        self.delta = 1.0
        self.init_input()
        shape = self.set_shape()
        # Reference result: residual is Y - X, loss is the element-wise
        # Huber function of that residual.
        residual = self.inputs['Y'] - self.inputs['X']
        loss = np.vectorize(huber_loss_forward)(residual, self.delta).astype(
            'float32'
        )
        self.attrs = {'delta': self.delta}
        self.outputs = {'Residual': residual, 'Out': loss.reshape(shape)}

    def init_input(self):
        """Draw float32 `X` and `Y` uniformly from [0, 1) with the test shape."""
        shape = self.set_shape()
        self.inputs = {
            'X': np.random.uniform(0, 1.0, shape).astype('float32'),
            'Y': np.random.uniform(0, 1.0, shape).astype('float32'),
        }

    def set_mlu(self):
        """Set `use_mlu` on the test class and select `MLUPlace(0)`."""
        self.__class__.use_mlu = True
        self.place = paddle.MLUPlace(0)

    def set_shape(self):
        """Return the shape used for both inputs (overridden by subclasses)."""
        return (100, 1)

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-3)

    def test_check_grad_normal(self):
        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')

    def test_check_grad_ingore_x(self):
        """Check the gradient with respect to `Y` only."""
        # `set("residual")` would split the string into single characters;
        # a one-element set is what the call spells out.
        # NOTE(review): the output slot above is named 'Residual' (capital
        # R) - confirm which name the harness expects here.
        self.check_grad_with_place(
            self.place,
            ['Y'],
            'Out',
            max_relative_error=0.008,
            no_grad_set={"residual"},
        )

    def test_check_grad_ingore_y(self):
        """Check the gradient with respect to `X` only."""
        # Same one-element set as in the `Y`-only check (see NOTE there).
        self.check_grad_with_place(
            self.place,
            ['X'],
            'Out',
            max_relative_error=0.008,
            no_grad_set={"residual"},
        )
class TestHuberLossOp1(TestHuberLossOp):
    """Shape variant of the huber_loss checks: a 1-D input of 64 elements."""

    # Declared as a class (not a function) so that unittest collects it.
    def set_shape(self):
        return 64
class TestHuberLossOp2(TestHuberLossOp):
    """Shape variant of the huber_loss checks: a 6 x 6 input."""

    # Declared as a class (not a function) so that unittest collects it.
    def set_shape(self):
        return (6, 6)
class TestHuberLossOp3(TestHuberLossOp):
    """Shape variant of the huber_loss checks: a 6 x 6 x 1 input."""

    # Declared as a class (not a function) so that unittest collects it.
    def set_shape(self):
        return (6, 6, 1)
class TestHuberLossOpError(unittest.TestCase):
    """Argument validation of `fluid.layers.huber_loss`."""

    def test_errors(self):
        """Both arguments must be Variables with a floating-point dtype."""
        huber_loss = fluid.layers.huber_loss
        with program_guard(Program(), Program()):
            # NumPy arrays (xw / lw) are not Variables and must be rejected.
            xw = np.random.random((6, 6)).astype("float32")
            xr = fluid.data(name='xr', shape=[None, 6], dtype="float32")
            lw = np.random.random((6, 6)).astype("float32")
            lr = fluid.data(name='lr', shape=[None, 6], dtype="float32")
            delta = 1.0
            for bad_input, bad_label in ((xr, lw), (xw, lr)):
                self.assertRaises(
                    TypeError, huber_loss, bad_input, bad_label, delta
                )

            # int32 Variables must be rejected as well: only float32 and
            # float64 are accepted for input and label.
            xw2 = fluid.data(name='xw2', shape=[None, 6], dtype="int32")
            lw2 = fluid.data(name='lw2', shape=[None, 6], dtype="int32")
            for bad_input, bad_label in ((xw2, lr), (xr, lw2)):
                self.assertRaises(
                    TypeError, huber_loss, bad_input, bad_label, delta
                )
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle
import numpy as np
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.framework import in_dygraph_mode
def run_adam_op(
    params,
    grads,
    lrs,
    moment1s,
    moment2s,
    beta1_pows,
    beta2_pows,
    master_params,
    epsilon,
    beta1,
    beta2,
    place,
    multi_precision=False,
    use_merged=False,
):
    """Run Adam over a list of parameters in dygraph mode.

    Each per-parameter argument is a list with one entry per parameter; all
    of them must have the same length as `params`.

    Args:
        params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows,
            master_params: per-parameter inputs, converted to dygraph
            variables before the op is called.
        epsilon, beta1, beta2: Adam hyper-parameters forwarded to the op.
        place: accepted but not used by this function (the device
            selection call is commented out).
        multi_precision: forwarded to the op.
        use_merged: when True a single merged_adam call handles every
            parameter; otherwise `adam` is called once per parameter.

    Returns:
        dict mapping the op output names to the lists of variables that
        were passed to the op as both inputs and outputs.

    Raises:
        AssertionError: if any per-parameter list differs in length from
            `params`.
    """
    per_param_lists = (
        ("grads", grads),
        ("lrs", lrs),
        ("moment1s", moment1s),
        ("moment2s", moment2s),
        ("beta1_pows", beta1_pows),
        # beta2_pows was previously never checked (beta1_pows was checked
        # twice instead).
        ("beta2_pows", beta2_pows),
        ("master_params", master_params),
    )
    for list_name, values in per_param_lists:
        assert len(params) == len(values), (
            "len(%s) must equal len(params)" % list_name
        )

    paddle.disable_static()
    # paddle.set_device(place)

    param_vars = [paddle.fluid.dygraph.to_variable(p) for p in params]
    grad_vars = [paddle.fluid.dygraph.to_variable(g) for g in grads]
    lr_vars = [paddle.fluid.dygraph.to_variable(l) for l in lrs]
    moment1_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment1s]
    moment2_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment2s]
    beta1_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta1_pows]
    beta2_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta2_pows]
    master_param_vars = [
        paddle.fluid.dygraph.to_variable(m_p) for m_p in master_params
    ]

    if not use_merged:
        # One op call per parameter; the same variables are passed as the
        # inputs and as the outputs.
        for i in range(len(param_vars)):
            _, _, _, _, _, _ = _legacy_C_ops.adam(
                param_vars[i],
                grad_vars[i],
                lr_vars[i],
                moment1_vars[i],
                moment2_vars[i],
                beta1_pow_vars[i],
                beta2_pow_vars[i],
                master_param_vars[i],
                param_vars[i],
                moment1_vars[i],
                moment2_vars[i],
                beta1_pow_vars[i],
                beta2_pow_vars[i],
                master_param_vars[i],
                'epsilon',
                epsilon,
                'beta1',
                beta1,
                'beta2',
                beta2,
                'multi_precision',
                multi_precision,
            )
    else:
        if in_dygraph_mode():
            _, _, _, _, _, _ = _C_ops.merged_adam_(
                param_vars,
                grad_vars,
                lr_vars,
                moment1_vars,
                moment2_vars,
                beta1_pow_vars,
                beta2_pow_vars,
                master_param_vars,
                beta1,
                beta2,
                epsilon,
                multi_precision,
                False,
            )
        else:
            _, _, _, _, _, _ = _legacy_C_ops.merged_adam(
                param_vars,
                grad_vars,
                lr_vars,
                moment1_vars,
                moment2_vars,
                beta1_pow_vars,
                beta2_pow_vars,
                master_param_vars,
                param_vars,
                moment1_vars,
                moment2_vars,
                beta1_pow_vars,
                beta2_pow_vars,
                master_param_vars,
                'epsilon',
                epsilon,
                'beta1',
                beta1,
                'beta2',
                beta2,
                'multi_precision',
                multi_precision,
            )

    outputs = {
        'ParamOut': param_vars,
        'Moment1Out': moment1_vars,
        'Moment2Out': moment2_vars,
        'Beta1PowOut': beta1_pow_vars,
        'Beta2PowOut': beta2_pow_vars,
        'MasterParamOut': master_param_vars,
    }

    return outputs
class TestMergedAdam(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
# dtype = np.float16 if multi_precision and place == 'mlu' else np.float32
dtype = np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
lrs = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
moment1s = self.gen_rand_data(shapes, mp_dtype)
moment2s = self.gen_rand_data(shapes, mp_dtype)
beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
master_params = [p.astype(mp_dtype) for p in params]
return (
params,
grads,
lrs,
moment1s,
moment2s,
beta1_pows,
beta2_pows,
master_params,
)
def check_with_place(self, place, multi_precision):
(
params,
grads,
lrs,
moment1s,
moment2s,
beta1_pows,
beta2_pows,
master_params,
) = self.prepare_data(self.shapes, multi_precision, self.seed, place)
def run_op(use_merged):
return run_adam_op(
params=params,
grads=grads,
lrs=lrs,
moment1s=moment1s,
moment2s=moment2s,
beta1_pows=beta1_pows,
beta2_pows=beta2_pows,
master_params=master_params,
epsilon=0.9,
beta1=0.9,
beta2=0.99,
place=place,
multi_precision=multi_precision,
use_merged=use_merged,
)
outs1 = run_op(True)
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for key in outs1.keys():
value1 = outs1[key]
value2 = outs2[key]
for i in range(len(value1)):
if place == 'mlu':
np.testing.assert_array_equal(value1[i], value2[i])
else:
np.testing.assert_allclose(
value1[i], value2[i], rtol=1e-05, atol=1e-07
)
def test_main(self):
for multi_precision in [False, True]:
self.check_with_place(self.place, multi_precision)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append('..')
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle
import math
paddle.enable_static()
class TestMLUPriorBox(OpTest):
    """Compare the MLU ``prior_box`` operator with a NumPy reference.

    ``setUp`` builds random ``Input``/``Image`` tensors, the operator
    attributes and the expected ``Boxes``/``Variances`` outputs. Subclasses
    override ``set_max_sizes`` or ``set_min_max_aspect_ratios_order`` to
    exercise other configurations.
    """

    def setUp(self):
        self.op_type = "prior_box"
        self.set_mlu()
        self.init_dtype()
        self.set_data()

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def set_mlu(self):
        """Flag the test class as an MLU test and select MLU device 0."""
        self.__class__.use_mlu = True
        self.place = paddle.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32

    def set_data(self):
        """Populate ``self.inputs``, ``self.attrs`` and ``self.outputs``."""
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()

        self.inputs = {'Input': self.input, 'Image': self.image}
        self.attrs = {
            'min_sizes': self.min_sizes,
            'aspect_ratios': self.aspect_ratios,
            'variances': self.variances,
            'flip': self.flip,
            'clip': self.clip,
            'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
            'step_w': self.step_w,
            'step_h': self.step_h,
            'offset': self.offset,
        }
        # 'max_sizes' is only passed to the operator when it is non-empty.
        if len(self.max_sizes) > 0:
            self.attrs['max_sizes'] = self.max_sizes

        self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}

    def set_max_sizes(self):
        """Hook: choose ``max_sizes`` (stored as a list of Python floats)."""
        max_sizes = [5, 10]
        self.max_sizes = np.array(max_sizes).astype('float32').tolist()

    def set_min_max_aspect_ratios_order(self):
        """Hook: choose the ordering of the generated priors."""
        self.min_max_aspect_ratios_order = True

    def init_test_params(self):
        """Set feature-map/image sizes and every prior_box attribute."""
        self.layer_w = 32
        self.layer_h = 32

        self.image_w = 40
        self.image_h = 40

        self.step_w = float(self.image_w) / float(self.layer_w)
        self.step_h = float(self.image_h) / float(self.layer_h)

        self.input_channels = 2
        self.image_channels = 3
        self.batch_size = 10

        self.min_sizes = np.array([2, 4]).astype('float32').tolist()
        self.set_max_sizes()
        self.flip = True
        self.set_min_max_aspect_ratios_order()
        # Ratio 1 plus each requested ratio and its reciprocal (flip=True),
        # written out by hand for the reference computation.
        self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
        self.aspect_ratios = np.array([2.0, 3.0], dtype=np.float64).flatten()
        self.variances = np.array(
            [0.1, 0.1, 0.2, 0.2], dtype=np.float64
        ).flatten()
        self.clip = True

        # One prior per (min_size, aspect ratio) pair, plus one per max_size.
        self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
        if len(self.max_sizes) > 0:
            self.num_priors += len(self.max_sizes)
        self.offset = 0.5

    def init_test_input(self):
        """Create random float32 ``Image`` and ``Input`` tensors.

        Both use (batch, channels, height, width) order so they agree with
        ``init_test_output``, whose first output axis is ``layer_h``. The
        original code listed width before height, which only worked because
        the configured sizes are square.
        """
        self.image = np.random.random(
            (self.batch_size, self.image_channels, self.image_h, self.image_w)
        ).astype('float32')

        self.input = np.random.random(
            (self.batch_size, self.input_channels, self.layer_h, self.layer_w)
        ).astype('float32')

    def _corner_box(self, c_x, c_y, half_w, half_h):
        """Return ``[xmin, ymin, xmax, ymax]`` divided by the image size."""
        return [
            (c_x - half_w) / self.image_w,
            (c_y - half_h) / self.image_h,
            (c_x + half_w) / self.image_w,
            (c_y + half_h) / self.image_h,
        ]

    def _prior_half_sizes(self):
        """List the (half_width, half_height) of every prior, in output order.

        The values do not depend on the feature-map cell, so they are
        computed once instead of once per cell.
        """
        has_max_sizes = len(self.max_sizes) > 0
        half_sizes = []
        for s, min_size in enumerate(self.min_sizes):
            # Square prior derived from the min/max size pair at index s.
            max_prior = []
            if has_max_sizes:
                side = math.sqrt(min_size * self.max_sizes[s]) / 2
                max_prior = [(side, side)]

            if not self.min_max_aspect_ratios_order:
                # All aspect-ratio priors first, then the max-size prior.
                for ar in self.real_aspect_ratios:
                    half_sizes.append(
                        (
                            min_size * math.sqrt(ar) / 2,
                            (min_size / math.sqrt(ar)) / 2,
                        )
                    )
                half_sizes.extend(max_prior)
            else:
                # min-size square, max-size square, then the remaining
                # ratios; ratio 1 is skipped because the min-size square
                # already covers it.
                half_sizes.append((min_size / 2.0, min_size / 2.0))
                half_sizes.extend(max_prior)
                for ar in self.real_aspect_ratios:
                    if abs(ar - 1.0) < 1e-6:
                        continue
                    half_sizes.append(
                        (
                            min_size * math.sqrt(ar) / 2,
                            (min_size / math.sqrt(ar)) / 2,
                        )
                    )
        return half_sizes

    def init_test_output(self):
        """Compute the reference ``out_boxes`` and ``out_var`` arrays.

        Both have shape (layer_h, layer_w, num_priors, 4) and dtype float32.
        """
        out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
        out_boxes = np.zeros(out_dim).astype('float32')

        half_sizes = self._prior_half_sizes()
        for h in range(self.layer_h):
            for w in range(self.layer_w):
                # Centre of cell (h, w) in image coordinates.
                c_x = (w + self.offset) * self.step_w
                c_y = (h + self.offset) * self.step_h
                for idx, (half_w, half_h) in enumerate(half_sizes):
                    out_boxes[h, w, idx, :] = self._corner_box(
                        c_x, c_y, half_w, half_h
                    )

        # Clip the priors' coordinates so that they lie within [0, 1].
        if self.clip:
            out_boxes = np.clip(out_boxes, 0.0, 1.0)

        # Every prior gets the same four variance values.
        out_var = np.tile(
            self.variances, (self.layer_h, self.layer_w, self.num_priors, 1)
        )
        self.out_boxes = out_boxes.astype('float32')
        self.out_var = out_var.astype('float32')
class TestMLUPriorBoxWithoutMaxSize(TestMLUPriorBox):
    """Variant of the prior_box test that configures no ``max_sizes``."""

    def set_max_sizes(self):
        # Override the base hook with an empty list of max sizes.
        self.max_sizes = list()
class TestMLUPriorBoxWithoutSpecifiedOutOrder(TestMLUPriorBox):
    """Variant of the prior_box test with min/max/aspect-ratio ordering off."""

    def set_min_max_aspect_ratios_order(self):
        # Override the base hook to disable the min/max/aspect-ratio order.
        self.min_max_aspect_ratios_order = bool(0)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -26,7 +26,6 @@ paddle.enable_static() ...@@ -26,7 +26,6 @@ paddle.enable_static()
class TestMLUReduceSumOp(OpTest): class TestMLUReduceSumOp(OpTest):
def setUp(self): def setUp(self):
self.init_op_type() self.init_op_type()
self.initTestCase() self.initTestCase()
...@@ -34,16 +33,16 @@ class TestMLUReduceSumOp(OpTest): ...@@ -34,16 +33,16 @@ class TestMLUReduceSumOp(OpTest):
self.attrs = { self.attrs = {
'dim': self.axis, 'dim': self.axis,
'keep_dim': self.keep_dim, 'keep_dim': self.keep_dim,
'reduce_all': self.reduce_all 'reduce_all': self.reduce_all,
} }
self.inputs = {'X': np.random.random(self.shape).astype("float32")} self.inputs = {'X': np.random.random(self.shape).astype("float32")}
if self.attrs['reduce_all']: if self.attrs['reduce_all']:
self.outputs = {'Out': self.inputs['X'].sum()} self.outputs = {'Out': self.inputs['X'].sum()}
else: else:
self.outputs = { self.outputs = {
'Out': 'Out': self.inputs['X'].sum(
self.inputs['X'].sum(axis=self.axis, axis=self.axis, keepdims=self.attrs['keep_dim']
keepdims=self.attrs['keep_dim']) )
} }
def set_mlu(self): def set_mlu(self):
...@@ -64,100 +63,92 @@ class TestMLUReduceSumOp(OpTest): ...@@ -64,100 +63,92 @@ class TestMLUReduceSumOp(OpTest):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 10) self.shape = (5, 6, 10)
self.axis = (0, ) self.axis = (0,)
class TestSumOp5D(TestMLUReduceSumOp): class TestSumOp5D(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (1, 2, 5, 6, 10) self.shape = (1, 2, 5, 6, 10)
self.axis = (0, ) self.axis = (0,)
class TestSumOp6D(TestMLUReduceSumOp): class TestSumOp6D(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (1, 1, 2, 5, 6, 10) self.shape = (1, 1, 2, 5, 6, 10)
self.axis = (0, ) self.axis = (0,)
class TestSumOp8D(TestMLUReduceSumOp): class TestSumOp8D(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (1, 3, 1, 2, 1, 4, 3, 10) self.shape = (1, 3, 1, 2, 1, 4, 3, 10)
self.axis = (0, 3) self.axis = (0, 3)
class Test1DReduce(TestMLUReduceSumOp): class Test1DReduce(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = 120 self.shape = 120
self.axis = (0, ) self.axis = (0,)
class Test2DReduce0(TestMLUReduceSumOp): class Test2DReduce0(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (20, 10) self.shape = (20, 10)
self.axis = (0, ) self.axis = (0,)
class Test2DReduce1(TestMLUReduceSumOp): class Test2DReduce1(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (20, 10) self.shape = (20, 10)
self.axis = (1, ) self.axis = (1,)
class Test3DReduce0(TestMLUReduceSumOp): class Test3DReduce0(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 7) self.shape = (5, 6, 7)
self.axis = (1, ) self.axis = (1,)
class Test3DReduce1(TestMLUReduceSumOp): class Test3DReduce1(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 7) self.shape = (5, 6, 7)
self.axis = (2, ) self.axis = (2,)
class Test3DReduce2(TestMLUReduceSumOp): class Test3DReduce2(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 7) self.shape = (5, 6, 7)
self.axis = (-2, ) self.axis = (-2,)
class Test3DReduce3(TestMLUReduceSumOp): class Test3DReduce3(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 7) self.shape = (5, 6, 7)
self.axis = (1, 2) self.axis = (1, 2)
class TestKeepDimReduce(TestMLUReduceSumOp): class TestKeepDimReduce(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 10) self.shape = (5, 6, 10)
self.axis = (1, ) self.axis = (1,)
self.keep_dim = True self.keep_dim = True
class TestKeepDim8DReduce(TestMLUReduceSumOp): class TestKeepDim8DReduce(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (2, 5, 3, 2, 2, 3, 4, 2) self.shape = (2, 5, 3, 2, 2, 3, 4, 2)
self.axis = (3, 4, 5) self.axis = (3, 4, 5)
self.keep_dim = True self.keep_dim = True
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.03
)
class TestReduceAll(TestMLUReduceSumOp): class TestReduceAll(TestMLUReduceSumOp):
def initTestCase(self): def initTestCase(self):
self.shape = (5, 6, 2, 10) self.shape = (5, 6, 2, 10)
self.axis = (0, ) self.axis = (0,)
self.reduce_all = True self.reduce_all = True
......
...@@ -31,7 +31,6 @@ paddle.enable_static() ...@@ -31,7 +31,6 @@ paddle.enable_static()
# Situation 1: starts(list, no tensor), ends(list, no tensor) # Situation 1: starts(list, no tensor), ends(list, no tensor)
# 1.1 without attr(decrease) # 1.1 without attr(decrease)
class TestSliceOp(OpTest): class TestSliceOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.set_mlu() self.set_mlu()
...@@ -42,7 +41,7 @@ class TestSliceOp(OpTest): ...@@ -42,7 +41,7 @@ class TestSliceOp(OpTest):
'axes': self.axes, 'axes': self.axes,
'starts': self.starts, 'starts': self.starts,
'ends': self.ends, 'ends': self.ends,
'infer_flags': self.infer_flags 'infer_flags': self.infer_flags,
} }
def config(self): def config(self):
...@@ -57,9 +56,9 @@ class TestSliceOp(OpTest): ...@@ -57,9 +56,9 @@ class TestSliceOp(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
def set_mlu(self): def set_mlu(self):
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -67,7 +66,6 @@ class TestSliceOp(OpTest): ...@@ -67,7 +66,6 @@ class TestSliceOp(OpTest):
class TestCase1(TestSliceOp): class TestCase1(TestSliceOp):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2] self.starts = [-3, 0, 2]
...@@ -78,7 +76,6 @@ class TestCase1(TestSliceOp): ...@@ -78,7 +76,6 @@ class TestCase1(TestSliceOp):
class TestCase2(TestSliceOp): class TestCase2(TestSliceOp):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-3, 0, 2] self.starts = [-3, 0, 2]
...@@ -90,7 +87,6 @@ class TestCase2(TestSliceOp): ...@@ -90,7 +87,6 @@ class TestCase2(TestSliceOp):
# 1.2 with attr(decrease) # 1.2 with attr(decrease)
class TestSliceOp_decs_dim(OpTest): class TestSliceOp_decs_dim(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.set_mlu() self.set_mlu()
...@@ -118,9 +114,9 @@ class TestSliceOp_decs_dim(OpTest): ...@@ -118,9 +114,9 @@ class TestSliceOp_decs_dim(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
def set_mlu(self): def set_mlu(self):
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -128,7 +124,6 @@ class TestSliceOp_decs_dim(OpTest): ...@@ -128,7 +124,6 @@ class TestSliceOp_decs_dim(OpTest):
class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [1, 0, 2] self.starts = [1, 0, 2]
...@@ -140,7 +135,6 @@ class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): ...@@ -140,7 +135,6 @@ class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim):
class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1, 0, 2] self.starts = [-1, 0, 2]
...@@ -152,7 +146,6 @@ class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim): ...@@ -152,7 +146,6 @@ class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim):
class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 7]).astype("float32") self.input = np.random.random([3, 4, 5, 7]).astype("float32")
self.starts = [0, 1, 2, 3] self.starts = [0, 1, 2, 3]
...@@ -164,7 +157,6 @@ class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim): ...@@ -164,7 +157,6 @@ class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim):
class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1] self.starts = [-1]
...@@ -176,7 +168,6 @@ class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): ...@@ -176,7 +168,6 @@ class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim):
class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [0, 1, 2, 3] self.starts = [0, 1, 2, 3]
...@@ -190,7 +181,6 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): ...@@ -190,7 +181,6 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
# Situation 2: starts(list, have tensor), ends(list, no tensor) # Situation 2: starts(list, have tensor), ends(list, no tensor)
# without attr(decrease) # without attr(decrease)
class TestSliceOp_starts_ListTensor(OpTest): class TestSliceOp_starts_ListTensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.set_mlu() self.set_mlu()
...@@ -198,8 +188,9 @@ class TestSliceOp_starts_ListTensor(OpTest): ...@@ -198,8 +188,9 @@ class TestSliceOp_starts_ListTensor(OpTest):
starts_tensor = [] starts_tensor = []
for index, ele in enumerate(self.starts): for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones( starts_tensor.append(
(1)).astype('int64') * ele)) ("x" + str(index), np.ones((1)).astype('int64') * ele)
)
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
...@@ -207,7 +198,7 @@ class TestSliceOp_starts_ListTensor(OpTest): ...@@ -207,7 +198,7 @@ class TestSliceOp_starts_ListTensor(OpTest):
'axes': self.axes, 'axes': self.axes,
'starts': self.starts_infer, 'starts': self.starts_infer,
'ends': self.ends, 'ends': self.ends,
'infer_flags': self.infer_flags 'infer_flags': self.infer_flags,
} }
def config(self): def config(self):
...@@ -224,9 +215,9 @@ class TestSliceOp_starts_ListTensor(OpTest): ...@@ -224,9 +215,9 @@ class TestSliceOp_starts_ListTensor(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
def set_mlu(self): def set_mlu(self):
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -236,7 +227,6 @@ class TestSliceOp_starts_ListTensor(OpTest): ...@@ -236,7 +227,6 @@ class TestSliceOp_starts_ListTensor(OpTest):
# Situation 2: starts(list, have tensor), ends(list, no tensor) # Situation 2: starts(list, have tensor), ends(list, no tensor)
# with attr(decrease) # with attr(decrease)
class TestSliceOp_decs_dim_starts_ListTensor(OpTest): class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.set_mlu() self.set_mlu()
...@@ -244,8 +234,9 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): ...@@ -244,8 +234,9 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
starts_tensor = [] starts_tensor = []
for index, ele in enumerate(self.starts): for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones( starts_tensor.append(
(1)).astype('int32') * ele)) ("x" + str(index), np.ones((1)).astype('int32') * ele)
)
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
...@@ -273,9 +264,9 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): ...@@ -273,9 +264,9 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
def set_mlu(self): def set_mlu(self):
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -283,8 +274,8 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): ...@@ -283,8 +274,8 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest):
class TestSliceOp_decs_dim_5_starts_ListTensor( class TestSliceOp_decs_dim_5_starts_ListTensor(
TestSliceOp_decs_dim_starts_ListTensor): TestSliceOp_decs_dim_starts_ListTensor
):
def config(self): def config(self):
self.input = np.random.random([3, 4, 5, 6]).astype("float32") self.input = np.random.random([3, 4, 5, 6]).astype("float32")
self.starts = [-1] self.starts = [-1]
...@@ -300,7 +291,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor( ...@@ -300,7 +291,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor(
# Situation 3: starts(tensor), ends(list, no tensor) # Situation 3: starts(tensor), ends(list, no tensor)
# with attr(decrease) # with attr(decrease)
class TestSliceOp_decs_dim_starts_OneTensor(OpTest): class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -308,7 +298,7 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest): ...@@ -308,7 +298,7 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
self.config() self.config()
self.inputs = { self.inputs = {
'Input': self.input, 'Input': self.input,
"StartsTensor": np.array(self.starts, dtype="int32") "StartsTensor": np.array(self.starts, dtype="int32"),
} }
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
self.attrs = { self.attrs = {
...@@ -332,15 +322,14 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest): ...@@ -332,15 +322,14 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
# Situation 4: starts(tensor), ends(tensor) # Situation 4: starts(tensor), ends(tensor)
# without attr(decrease) # without attr(decrease)
class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -350,14 +339,14 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): ...@@ -350,14 +339,14 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
self.inputs = { self.inputs = {
'Input': self.input, 'Input': self.input,
"StartsTensor": np.array(self.starts, dtype="int64"), "StartsTensor": np.array(self.starts, dtype="int64"),
"EndsTensor": np.array(self.ends, dtype="int32") "EndsTensor": np.array(self.ends, dtype="int32"),
} }
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
self.attrs = { self.attrs = {
'axes': self.axes, 'axes': self.axes,
#'starts': self.starts, #'starts': self.starts,
#'ends': self.ends_infer, #'ends': self.ends_infer,
'infer_flags': self.infer_flags 'infer_flags': self.infer_flags,
} }
def config(self): def config(self):
...@@ -372,15 +361,14 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): ...@@ -372,15 +361,14 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
# Situation 5: starts(tensor), ends(tensor) # Situation 5: starts(tensor), ends(tensor)
# with attr(decrease) # with attr(decrease)
class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -389,7 +377,7 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): ...@@ -389,7 +377,7 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
self.inputs = { self.inputs = {
'Input': self.input, 'Input': self.input,
"StartsTensor": np.array(self.starts, dtype="int32"), "StartsTensor": np.array(self.starts, dtype="int32"),
"EndsTensor": np.array(self.ends, dtype="int32") "EndsTensor": np.array(self.ends, dtype="int32"),
} }
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
self.attrs = { self.attrs = {
...@@ -413,15 +401,14 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): ...@@ -413,15 +401,14 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
# Situation 6: starts(tensor), ends(list, have tensor) # Situation 6: starts(tensor), ends(list, have tensor)
# without attr(decrease) # without attr(decrease)
class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -430,20 +417,21 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): ...@@ -430,20 +417,21 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
ends_tensor = [] ends_tensor = []
for index, ele in enumerate(self.ends): for index, ele in enumerate(self.ends):
ends_tensor.append(("y" + str(index), np.ones( ends_tensor.append(
(1)).astype('int32') * ele)) ("y" + str(index), np.ones((1)).astype('int32') * ele)
)
self.inputs = { self.inputs = {
'Input': self.input, 'Input': self.input,
"StartsTensor": np.array(self.starts, dtype="int32"), "StartsTensor": np.array(self.starts, dtype="int32"),
'EndsTensorList': ends_tensor 'EndsTensorList': ends_tensor,
} }
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
self.attrs = { self.attrs = {
'axes': self.axes, 'axes': self.axes,
#'starts': self.starts, #'starts': self.starts,
'ends': self.ends_infer, 'ends': self.ends_infer,
'infer_flags': self.infer_flags 'infer_flags': self.infer_flags,
} }
def config(self): def config(self):
...@@ -460,14 +448,13 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): ...@@ -460,14 +448,13 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
# Test float16 # Test float16
class TestFP16(OpTest): class TestFP16(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -479,7 +466,7 @@ class TestFP16(OpTest): ...@@ -479,7 +466,7 @@ class TestFP16(OpTest):
'axes': self.axes, 'axes': self.axes,
'starts': self.starts, 'starts': self.starts,
'ends': self.ends, 'ends': self.ends,
'infer_flags': self.infer_flags 'infer_flags': self.infer_flags,
} }
def config(self): def config(self):
...@@ -495,13 +482,12 @@ class TestFP16(OpTest): ...@@ -495,13 +482,12 @@ class TestFP16(OpTest):
self.check_output_with_place(self.place, atol=1e-5) self.check_output_with_place(self.place, atol=1e-5)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place, ['Input'], 'Out', max_relative_error=0.006
max_relative_error=0.006) )
class TestFP16_2(OpTest): class TestFP16_2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "slice" self.op_type = "slice"
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -513,7 +499,7 @@ class TestFP16_2(OpTest): ...@@ -513,7 +499,7 @@ class TestFP16_2(OpTest):
'axes': self.axes, 'axes': self.axes,
'starts': self.starts, 'starts': self.starts,
'ends': self.ends, 'ends': self.ends,
'infer_flags': self.infer_flags 'infer_flags': self.infer_flags,
} }
def config(self): def config(self):
...@@ -529,24 +515,28 @@ class TestFP16_2(OpTest): ...@@ -529,24 +515,28 @@ class TestFP16_2(OpTest):
self.check_output_with_place(self.place, atol=1e-5) self.check_output_with_place(self.place, atol=1e-5)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['Input'], self.check_grad_with_place(
'Out', self.place,
max_relative_error=0.006, ['Input'],
numeric_grad_delta=0.5) 'Out',
max_relative_error=0.006,
numeric_grad_delta=0.5,
)
class TestSliceApiWithTensor(unittest.TestCase): class TestSliceApiWithTensor(unittest.TestCase):
def test_starts_ends_is_tensor(self): def test_starts_ends_is_tensor(self):
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
a = paddle.rand(shape=[4, 5, 6], dtype='float32') a = paddle.rand(shape=[4, 5, 6], dtype='float32')
axes = [0, 1, 2] axes = [0, 1, 2]
starts = [-3, 0, 2] starts = [-3, 0, 2]
ends = [3, 2, 4] ends = [3, 2, 4]
a_1 = paddle.slice(a, a_1 = paddle.slice(
axes=axes, a,
starts=paddle.to_tensor(starts, dtype='int32'), axes=axes,
ends=paddle.to_tensor(ends, dtype='int32')) starts=paddle.to_tensor(starts, dtype='int32'),
ends=paddle.to_tensor(ends, dtype='int32'),
)
a_2 = paddle.slice(a, axes=axes, starts=starts, ends=ends) a_2 = paddle.slice(a, axes=axes, starts=starts, ends=ends)
np.testing.assert_allclose(a_1.numpy(), a_2.numpy()) np.testing.assert_allclose(a_1.numpy(), a_2.numpy())
...@@ -569,24 +559,22 @@ class TestSliceApiWithTensor(unittest.TestCase): ...@@ -569,24 +559,22 @@ class TestSliceApiWithTensor(unittest.TestCase):
class TestImperativeVarBaseGetItem(unittest.TestCase): class TestImperativeVarBaseGetItem(unittest.TestCase):
def test_getitem_with_long(self): def test_getitem_with_long(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
data = np.random.random((2, 80, 16128)).astype('float32') data = np.random.random((2, 80, 16128)).astype('float32')
var = fluid.dygraph.to_variable(data) var = fluid.dygraph.to_variable(data)
sliced = var[:, 10:, :var.shape[1]] # var.shape[1] is 80L here sliced = var[:, 10:, : var.shape[1]] # var.shape[1] is 80L here
self.assertEqual(sliced.shape, [2, 70, 80]) self.assertEqual(sliced.shape, [2, 70, 80])
sliced = var[:, var.shape[0]:, var.shape[0]:var.shape[1]] sliced = var[:, var.shape[0] :, var.shape[0] : var.shape[1]]
self.assertEqual(sliced.shape, [2, 78, 78]) self.assertEqual(sliced.shape, [2, 78, 78])
def test_getitem_with_float(self): def test_getitem_with_float(self):
def test_float_in_slice_item(): def test_float_in_slice_item():
with fluid.dygraph.guard(): with fluid.dygraph.guard():
data = np.random.random((2, 80, 16128)).astype('float32') data = np.random.random((2, 80, 16128)).astype('float32')
var = fluid.dygraph.to_variable(data) var = fluid.dygraph.to_variable(data)
sliced = var[:, 1.1:, :var.shape[1]] sliced = var[:, 1.1:, : var.shape[1]]
self.assertRaises(Exception, test_float_in_slice_item) self.assertRaises(Exception, test_float_in_slice_item)
...@@ -600,15 +588,6 @@ class TestImperativeVarBaseGetItem(unittest.TestCase): ...@@ -600,15 +588,6 @@ class TestImperativeVarBaseGetItem(unittest.TestCase):
class TestInferShape(unittest.TestCase): class TestInferShape(unittest.TestCase):
def test(self):
x = paddle.ones(shape=[3, 4, 5])
x.desc.set_shape([3, -1, 5])
self.assertEqual(x.shape, (3, -1, 5))
out0 = paddle.slice(x, axes=[1], starts=[0], ends=[3])
self.assertEqual(out0.shape, (3, 3, 5))
def test_axis_less_than_zero(self): def test_axis_less_than_zero(self):
# Using paddle.disable_static will make other unittests fail. # Using paddle.disable_static will make other unittests fail.
...@@ -616,13 +595,18 @@ class TestInferShape(unittest.TestCase): ...@@ -616,13 +595,18 @@ class TestInferShape(unittest.TestCase):
x_arr = np.arange(0, 24, dtype=np.float32).reshape([2, 3, 4]) x_arr = np.arange(0, 24, dtype=np.float32).reshape([2, 3, 4])
x = paddle.to_tensor(x_arr) x = paddle.to_tensor(x_arr)
pp_slice = paddle.slice(x, [ pp_slice = paddle.slice(
100, x,
], [0], [1]) [
100,
],
[0],
[1],
)
np_slice = x_arr[:, :, 0:1] np_slice = x_arr[:, :, 0:1]
np.testing.assert_allclose(pp_slice, np_slice) np.testing.assert_allclose(pp_slice, np_slice)
pp_slice = paddle.slice(x, (-100, ), [0], [1]) pp_slice = paddle.slice(x, (-100,), [0], [1])
np_slice = x_arr[0:1] np_slice = x_arr[0:1]
np.testing.assert_allclose(pp_slice, np_slice) np.testing.assert_allclose(pp_slice, np_slice)
...@@ -630,9 +614,11 @@ class TestInferShape(unittest.TestCase): ...@@ -630,9 +614,11 @@ class TestInferShape(unittest.TestCase):
x = paddle.to_tensor(np.reshape(x_arr, (0, 0, 0))) x = paddle.to_tensor(np.reshape(x_arr, (0, 0, 0)))
starts = paddle.to_tensor( starts = paddle.to_tensor(
np.reshape(np.array([], dtype=np.int32), (0, ))) np.reshape(np.array([], dtype=np.int32), (0,))
)
ends = paddle.to_tensor( ends = paddle.to_tensor(
np.reshape(np.array([], dtype=np.int32), (0, ))) np.reshape(np.array([], dtype=np.int32), (0,))
)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
paddle.slice(x, [-1000000], starts, ends) paddle.slice(x, [-1000000], starts, ends)
......
...@@ -30,7 +30,6 @@ SEED = 2021 ...@@ -30,7 +30,6 @@ SEED = 2021
class TestSoftmaxWithCrossEntropyOp(OpTest): class TestSoftmaxWithCrossEntropyOp(OpTest):
def set_mlu(self): def set_mlu(self):
self.__class__.use_mlu = True self.__class__.use_mlu = True
...@@ -53,8 +52,10 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -53,8 +52,10 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.initParams() self.initParams()
logits = getattr( logits = getattr(
self, "logits", self,
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype),
)
softmax = np.apply_along_axis(stable_softmax, self.axis, logits) softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
if self.soft_label: if self.soft_label:
...@@ -65,8 +66,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -65,8 +66,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.shape[self.axis] = 1 self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64") labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
loss = cross_entropy(softmax, labels, self.soft_label, self.axis, loss = cross_entropy(
self.ignore_index) softmax, labels, self.soft_label, self.axis, self.ignore_index
)
one_hot_label = np.eye(axis_dim)[labels.reshape(-1)] one_hot_label = np.eye(axis_dim)[labels.reshape(-1)]
...@@ -74,7 +76,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -74,7 +76,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.outputs = { self.outputs = {
"Backprop": (softmax - one_hot_label).astype(self.dtype), "Backprop": (softmax - one_hot_label).astype(self.dtype),
"Softmax": softmax.astype(self.dtype), "Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype) "Loss": loss.astype(self.dtype),
} }
self.attrs = { self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode, "numeric_stable_mode": self.numeric_stable_mode,
...@@ -92,14 +94,16 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -92,14 +94,16 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
# fp32 has low precision, cpu and mlu both need to relax the max_relative_error if using fp32 # fp32 has low precision, cpu and mlu both need to relax the max_relative_error if using fp32
self.check_grad_with_place(self.place, ['Logits'], self.check_grad_with_place(
'Loss', self.place,
numeric_grad_delta=0.001, ['Logits'],
max_relative_error=0.5) 'Loss',
numeric_grad_delta=0.001,
max_relative_error=0.5,
)
class TestPowNet(unittest.TestCase): class TestPowNet(unittest.TestCase):
def _test(self, run_mlu=True): def _test(self, run_mlu=True):
main_prog = paddle.static.Program() main_prog = paddle.static.Program()
startup_prog = paddle.static.Program() startup_prog = paddle.static.Program()
...@@ -114,9 +118,9 @@ class TestPowNet(unittest.TestCase): ...@@ -114,9 +118,9 @@ class TestPowNet(unittest.TestCase):
with paddle.static.program_guard(main_prog, startup_prog): with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(name="label", label = paddle.static.data(
shape=[32, 1], name="label", shape=[32, 1], dtype='int64'
dtype='int64') )
sum = paddle.add(a, b) sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0) z = paddle.pow(sum, 2.0)
...@@ -140,16 +144,17 @@ class TestPowNet(unittest.TestCase): ...@@ -140,16 +144,17 @@ class TestPowNet(unittest.TestCase):
print("Start run on {}".format(place)) print("Start run on {}".format(place))
for epoch in range(100): for epoch in range(100):
pred_res, loss_res = exe.run(main_prog, pred_res, loss_res = exe.run(
feed={ main_prog,
"a": a_np, feed={"a": a_np, "b": b_np, "label": label_np},
"b": b_np, fetch_list=[prediction, loss],
"label": label_np )
},
fetch_list=[prediction, loss])
if epoch % 10 == 0: if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format( print(
epoch, pred_res[0], loss_res)) "Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res
)
)
return pred_res, loss_res return pred_res, loss_res
...@@ -157,7 +162,7 @@ class TestPowNet(unittest.TestCase): ...@@ -157,7 +162,7 @@ class TestPowNet(unittest.TestCase):
cpu_pred, cpu_loss = self._test(False) cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True) mlu_pred, mlu_loss = self._test(True)
np.testing.assert_allclose(mlu_pred, cpu_pred, rtol=1e-5) np.testing.assert_allclose(mlu_pred, cpu_pred, rtol=2e-5)
np.testing.assert_allclose(mlu_loss, cpu_loss) np.testing.assert_allclose(mlu_loss, cpu_loss)
......
...@@ -44,17 +44,19 @@ SEED = 10 ...@@ -44,17 +44,19 @@ SEED = 10
class TestSyncBatchNormRunnerBase(object): class TestSyncBatchNormRunnerBase(object):
def get_model(
def get_model(self, self,
main, main,
startup, startup,
place, place,
layout, layout,
seed, seed,
sync_bn=False, sync_bn=False,
only_forward=False): only_forward=False,
):
raise NotImplementedError( raise NotImplementedError(
"get model should be implemented by child class.") "get model should be implemented by child class."
)
def wait_server_ready(self, endpoints): def wait_server_ready(self, endpoints):
assert not isinstance(endpoints, string_types) assert not isinstance(endpoints, string_types)
...@@ -63,13 +65,15 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -63,13 +65,15 @@ class TestSyncBatchNormRunnerBase(object):
not_ready_endpoints = [] not_ready_endpoints = []
for ep in endpoints: for ep in endpoints:
ip_port = ep.split(":") ip_port = ep.split(":")
with closing(socket.socket(socket.AF_INET, with closing(
socket.SOCK_STREAM)) as sock: socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
sock.settimeout(2) sock.settimeout(2)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'): if hasattr(socket, 'SO_REUSEPORT'):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, sock.setsockopt(
1) socket.SOL_SOCKET, socket.SO_REUSEPORT, 1
)
result = sock.connect_ex((ip_port[0], int(ip_port[1]))) result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0: if result != 0:
...@@ -77,39 +81,47 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -77,39 +81,47 @@ class TestSyncBatchNormRunnerBase(object):
not_ready_endpoints.append(ep) not_ready_endpoints.append(ep)
if not all_ok: if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n") sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + sys.stderr.write(
str(not_ready_endpoints) + "\n") "not ready endpoints:" + str(not_ready_endpoints) + "\n"
)
sys.stderr.flush() sys.stderr.flush()
time.sleep(3) time.sleep(3)
else: else:
break break
def initCommunicator(self, program, rank, nranks, wait_port, def initCommunicator(
current_endpoint, endpoints): self, program, rank, nranks, wait_port, current_endpoint, endpoints
):
other_endpoints = endpoints[:] other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint) other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port: if rank == 0 and wait_port:
self.wait_server_ready(other_endpoints) self.wait_server_ready(other_endpoints)
block = program.global_block() block = program.global_block()
cncl_id_var = block.create_var(name=nameGen.generate('cncl_id'), cncl_id_var = block.create_var(
persistable=True, name=nameGen.generate('cncl_id'),
type=core.VarDesc.VarType.RAW) persistable=True,
block.append_op(type='c_gen_cncl_id', type=core.VarDesc.VarType.RAW,
inputs={}, )
outputs={'Out': cncl_id_var}, block.append_op(
attrs={ type='c_gen_cncl_id',
'rank': rank, inputs={},
'endpoint': current_endpoint, outputs={'Out': cncl_id_var},
'other_endpoints': other_endpoints attrs={
}) 'rank': rank,
block.append_op(type='c_comm_init', 'endpoint': current_endpoint,
inputs={'X': cncl_id_var}, 'other_endpoints': other_endpoints,
outputs={}, },
attrs={ )
'nranks': nranks, block.append_op(
'rank': rank, type='c_comm_init',
'ring_id': self.global_ring_id inputs={'X': cncl_id_var},
}) outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': self.global_ring_id,
},
)
def run_trainer(self, args): def run_trainer(self, args):
device_id = int(os.getenv("FLAGS_selected_mlus", "0")) device_id = int(os.getenv("FLAGS_selected_mlus", "0"))
...@@ -127,8 +139,8 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -127,8 +139,8 @@ class TestSyncBatchNormRunnerBase(object):
self._compare(args, place, layout, True) self._compare(args, place, layout, True)
# Test FP16 - @TODO # Test FP16 - @TODO
self.dtype = np.float16 self.bn_dtype = np.float16
self.atol = 1e-2 self.atol = 3e-3
# Test training # Test training
for place in places: for place in places:
...@@ -142,24 +154,30 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -142,24 +154,30 @@ class TestSyncBatchNormRunnerBase(object):
sys.stdout.buffer.write( sys.stdout.buffer.write(
pickle.dumps( pickle.dumps(
'training, inference, fp32, fp16, NCHW, NHWC all passed')) 'training, inference, fp32, fp16, NCHW, NHWC all passed'
)
)
def _compare(self, args, place, layout, only_forward): def _compare(self, args, place, layout, only_forward):
scope = core.Scope() scope = core.Scope()
np.random.seed(SEED) np.random.seed(SEED)
data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2 data = np.random.random(size=self.dshape).astype(self.dtype) * 4.0 - 2
sys.stderr.write("data: " + str(data) + "\n") sys.stderr.write("data: " + str(data) + "\n")
data = create_or_get_tensor(scope, "input", data = create_or_get_tensor(
OpTest.np_dtype_to_fluid_dtype(data), place) scope, "input", OpTest.np_dtype_to_fluid_dtype(data), place
)
bn_fetches = self._cal_single_card(args, data, place, layout, bn_fetches = self._cal_single_card(
only_forward) args, data, place, layout, only_forward
)
fetch_names, sync_bn_fetches = self._cal_multiple_cards( fetch_names, sync_bn_fetches = self._cal_multiple_cards(
args, data, place, layout, only_forward) args, data, place, layout, only_forward
)
sys.stderr.write("len(sync_bn_fetches): " + str(len(sync_bn_fetches)) + sys.stderr.write(
"\n") "len(sync_bn_fetches): " + str(len(sync_bn_fetches)) + "\n"
)
for i in six.moves.xrange(0, len(sync_bn_fetches)): for i in six.moves.xrange(0, len(sync_bn_fetches)):
sys.stderr.write("i: " + str(i) + "\n") sys.stderr.write("i: " + str(i) + "\n")
sys.stderr.write("fetch_names[i]): " + fetch_names[i] + "\n") sys.stderr.write("fetch_names[i]): " + fetch_names[i] + "\n")
...@@ -167,13 +185,14 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -167,13 +185,14 @@ class TestSyncBatchNormRunnerBase(object):
bn_val = bn_fetches[i] bn_val = bn_fetches[i]
sync_bn_val = sync_bn_fetches[i] sync_bn_val = sync_bn_fetches[i]
if sync_bn_val.shape != bn_val.shape: if sync_bn_val.shape != bn_val.shape:
sync_bn_val = sync_bn_val[:bn_val.shape[0]] sync_bn_val = sync_bn_val[: bn_val.shape[0]]
# i = 0 # i = 0
if fetch_names[i] == 'reduce_sum_0.tmp_0': if fetch_names[i] == 'reduce_sum_0.tmp_0':
# sys.stderr.write("skip reduce_sum_0.tmp_0 (Out of reduce_sum op)" + "\n") # sys.stderr.write("skip reduce_sum_0.tmp_0 (Out of reduce_sum op)" + "\n")
sys.stderr.write("reduce_sum_0.tmp_0 (Out of reduce_sum op)" + sys.stderr.write(
"\n") "reduce_sum_0.tmp_0 (Out of reduce_sum op)" + "\n"
)
sys.stderr.write("bn_val: " + str(bn_val) + "\n") sys.stderr.write("bn_val: " + str(bn_val) + "\n")
sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n")
...@@ -201,7 +220,8 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -201,7 +220,8 @@ class TestSyncBatchNormRunnerBase(object):
if fetch_names[i] == 'batch_norm_0.tmp_2': if fetch_names[i] == 'batch_norm_0.tmp_2':
# sys.stderr.write("skip batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n") # sys.stderr.write("skip batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n")
sys.stderr.write( sys.stderr.write(
"batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n") "batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n"
)
sys.stderr.write("bn_val: " + str(bn_val) + "\n") sys.stderr.write("bn_val: " + str(bn_val) + "\n")
sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n")
...@@ -234,8 +254,9 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -234,8 +254,9 @@ class TestSyncBatchNormRunnerBase(object):
# i = 8 # i = 8
if fetch_names[i] == 'batch_norm_0.tmp_1': if fetch_names[i] == 'batch_norm_0.tmp_1':
sys.stderr.write("skip batch_norm_0.tmp_1 (SavedVariance)" + sys.stderr.write(
"\n") "skip batch_norm_0.tmp_1 (SavedVariance)" + "\n"
)
sys.stderr.write("bn_val: " + str(bn_val) + "\n") sys.stderr.write("bn_val: " + str(bn_val) + "\n")
sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n")
...@@ -281,10 +302,16 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -281,10 +302,16 @@ class TestSyncBatchNormRunnerBase(object):
if fetch_names[i] == 'conv2d_0.tmp_0@GRAD': if fetch_names[i] == 'conv2d_0.tmp_0@GRAD':
atol = 1e-2 atol = 1e-2
assert np.allclose( assert np.allclose(bn_val, sync_bn_val, atol=atol), (
bn_val, sync_bn_val, atol=atol), "Output (" + fetch_names[ "Output ("
i] + ") has diff. \n" + "\nBN " + str( + fetch_names[i]
bn_val) + "\n" + "Sync BN " + str(sync_bn_val) + ") has diff. \n"
+ "\nBN "
+ str(bn_val)
+ "\n"
+ "Sync BN "
+ str(sync_bn_val)
)
def _cal_single_card(self, args, data, place, layout, only_forward): def _cal_single_card(self, args, data, place, layout, only_forward):
# Single-MLU, N = 32 per MLU # Single-MLU, N = 32 per MLU
...@@ -294,23 +321,31 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -294,23 +321,31 @@ class TestSyncBatchNormRunnerBase(object):
startup_prog.global_seed(SEED) startup_prog.global_seed(SEED)
paddle.seed(SEED) paddle.seed(SEED)
outs = self.get_model(train_prog, startup_prog, place, layout, SEED, outs = self.get_model(
False, only_forward) train_prog, startup_prog, place, layout, SEED, False, only_forward
)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
fetch_names = [v.name for v in outs] + [ fetch_names = [v.name for v in outs] + [
'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' 'bn_moving_mean',
'bn_moving_variance',
'bn_scale',
'bn_bias',
] ]
if not only_forward: if not only_forward:
others = [ others = [
'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', 'batch_norm_0.tmp_0',
'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD' 'batch_norm_0.tmp_1',
'bn_scale@GRAD',
'bn_bias@GRAD',
'batch_norm_0.tmp_3@GRAD',
'conv2d_0.tmp_0@GRAD',
] ]
fetch_names += others fetch_names += others
bn_fetches = exe.run(program=train_prog, bn_fetches = exe.run(
feed={'input': data}, program=train_prog, feed={'input': data}, fetch_list=fetch_names
fetch_list=fetch_names) )
return bn_fetches return bn_fetches
...@@ -331,8 +366,9 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -331,8 +366,9 @@ class TestSyncBatchNormRunnerBase(object):
current_endpoint = args["currentendpoint"] current_endpoint = args["currentendpoint"]
nranks = 2 nranks = 2
self.initCommunicator(startup_prog, rank, nranks, True, self.initCommunicator(
current_endpoint, endpoints) startup_prog, rank, nranks, True, current_endpoint, endpoints
)
# sys.stderr.write("after init, startup_prog: " + # sys.stderr.write("after init, startup_prog: " +
# startup_prog.to_string(True) + "\n") # startup_prog.to_string(True) + "\n")
train_prog.global_seed(SEED) train_prog.global_seed(SEED)
...@@ -342,8 +378,9 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -342,8 +378,9 @@ class TestSyncBatchNormRunnerBase(object):
paddle.seed(SEED) paddle.seed(SEED)
self.rank = rank self.rank = rank
outs = self.get_model(train_prog, startup_prog, place, layout, SEED, outs = self.get_model(
True, only_forward) train_prog, startup_prog, place, layout, SEED, True, only_forward
)
# sys.stderr.write("after get_model, train_prog: " + # sys.stderr.write("after get_model, train_prog: " +
# train_prog.to_string(True) + "\n") # train_prog.to_string(True) + "\n")
# sys.stderr.write("after get_model, startup_prog: " + # sys.stderr.write("after get_model, startup_prog: " +
...@@ -366,17 +403,24 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -366,17 +403,24 @@ class TestSyncBatchNormRunnerBase(object):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
fetch_names = [v.name for v in outs] + [ fetch_names = [v.name for v in outs] + [
'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' 'bn_moving_mean',
'bn_moving_variance',
'bn_scale',
'bn_bias',
] ]
if not only_forward: if not only_forward:
others = [ others = [
'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', 'batch_norm_0.tmp_0',
'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD' 'batch_norm_0.tmp_1',
'bn_scale@GRAD',
'bn_bias@GRAD',
'batch_norm_0.tmp_3@GRAD',
'conv2d_0.tmp_0@GRAD',
] ]
fetch_names += others fetch_names += others
sync_bn_fetches = exe.run(program=train_prog, sync_bn_fetches = exe.run(
feed={'input': data}, program=train_prog, feed={'input': data}, fetch_list=fetch_names
fetch_list=fetch_names) )
return fetch_names, sync_bn_fetches return fetch_names, sync_bn_fetches
...@@ -399,19 +443,20 @@ from contextlib import closing ...@@ -399,19 +443,20 @@ from contextlib import closing
class TestDistBase(unittest.TestCase): class TestDistBase(unittest.TestCase):
def setUp(self): def setUp(self):
self._port_set = set() self._port_set = set()
self._trainers = 2 self._trainers = 2
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port()) self._find_free_port(),
self._find_free_port(),
)
self._python_interp = sys.executable self._python_interp = sys.executable
def _find_free_port(self): def _find_free_port(self):
def __free_port(): def __free_port():
with closing(socket.socket(socket.AF_INET, with closing(
socket.SOCK_STREAM)) as s: socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as s:
s.bind(('', 0)) s.bind(('', 0))
return s.getsockname()[1] return s.getsockname()[1]
...@@ -440,7 +485,7 @@ class TestDistBase(unittest.TestCase): ...@@ -440,7 +485,7 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep, "PADDLE_CURRENT_ENDPOINT": w1_ep,
} }
#update environment # update environment
env0.update(envs) env0.update(envs)
env1.update(envs) env1.update(envs)
...@@ -451,15 +496,19 @@ class TestDistBase(unittest.TestCase): ...@@ -451,15 +496,19 @@ class TestDistBase(unittest.TestCase):
tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w") tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w")
print("tr0_cmd: {}, env: {}\n".format(tr0_cmd, env0)) print("tr0_cmd: {}, env: {}\n".format(tr0_cmd, env0))
print("tr1_cmd: {}, env: {}\n".format(tr1_cmd, env1)) print("tr1_cmd: {}, env: {}\n".format(tr1_cmd, env1))
tr0_proc = subprocess.Popen(tr0_cmd.strip().split(), tr0_proc = subprocess.Popen(
stdout=subprocess.PIPE, tr0_cmd.strip().split(),
stderr=tr0_pipe, stdout=subprocess.PIPE,
env=env0) stderr=tr0_pipe,
env=env0,
tr1_proc = subprocess.Popen(tr0_cmd.strip().split(), )
stdout=subprocess.PIPE,
stderr=tr1_pipe, tr1_proc = subprocess.Popen(
env=env1) tr0_cmd.strip().split(),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=env1,
)
tr0_out, tr0_err = tr0_proc.communicate() tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate()
...@@ -473,14 +522,16 @@ class TestDistBase(unittest.TestCase): ...@@ -473,14 +522,16 @@ class TestDistBase(unittest.TestCase):
sys.stderr.write('trainer 0 stderr file: %s\n' % f.read()) sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f: with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f:
sys.stderr.write('trainer 1 stderr file: %s\n' % f.read()) sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
return pickle.loads(tr0_out), pickle.loads( return (
tr1_out), tr0_proc.pid, tr1_proc.pid pickle.loads(tr0_out),
pickle.loads(tr1_out),
def check_with_place(self, tr0_proc.pid,
model_file, tr1_proc.pid,
col_type, )
check_error_log=False,
need_envs={}): def check_with_place(
self, model_file, col_type, check_error_log=False, need_envs={}
):
required_envs = { required_envs = {
"FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_eager_delete_tensor_gb": "0.0", "FLAGS_eager_delete_tensor_gb": "0.0",
...@@ -491,7 +542,7 @@ class TestDistBase(unittest.TestCase): ...@@ -491,7 +542,7 @@ class TestDistBase(unittest.TestCase):
"FLAGS_call_stack_level": "2", "FLAGS_call_stack_level": "2",
"GLOG_v": "3", "GLOG_v": "3",
"PADDLE_WITH_GLOO": '0', "PADDLE_WITH_GLOO": '0',
"BACKEND": "cncl" "BACKEND": "cncl",
} }
required_envs.update(need_envs) required_envs.update(need_envs)
if check_error_log: if check_error_log:
...@@ -499,8 +550,11 @@ class TestDistBase(unittest.TestCase): ...@@ -499,8 +550,11 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
required_envs["GLOO_LOG_LEVEL"] = "TRACE" required_envs["GLOO_LOG_LEVEL"] = "TRACE"
tr0_out, tr1_out, pid0, pid1 = self._run_cluster( tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
model_file, required_envs) model_file, required_envs
)
self.assertEqual( self.assertEqual(
tr0_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed') tr0_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed'
)
self.assertEqual( self.assertEqual(
tr1_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed') tr1_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed'
)
...@@ -29,14 +29,17 @@ paddle.enable_static() ...@@ -29,14 +29,17 @@ paddle.enable_static()
class TestSyncBatchNormOp(TestDistBase): class TestSyncBatchNormOp(TestDistBase):
def _setup_config(self): def _setup_config(self):
pass pass
def test_identity(self, col_type="identity"): def test_identity(self, col_type="identity"):
self.check_with_place("sync_batch_norm_op_mlu.py", envs = {"CNCL_MEM_POOL_MULTI_CLIQUE_ENABLE": "1"}
col_type, self.check_with_place(
check_error_log=True) "sync_batch_norm_op_mlu.py",
col_type,
check_error_log=True,
need_envs=envs,
)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid import core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from paddle.fluid.framework import _test_eager_guard
paddle.enable_static()
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Args:
        x: A scalar or NumPy array; arrays are processed element-wise.

    Returns:
        The logistic of ``x``, shaped like ``x``.
    """
    exp_neg_x = np.exp(-1.0 * x)
    return 1.0 / (1.0 + exp_neg_x)
def YoloBox(x, img_size, attrs):
    """NumPy reference used to build the expected outputs of ``yolo_box``.

    Args:
        x: 4-D array unpacked as ``(n, c, h, w)``. Its channel axis is
            reshaped into ``an_num`` groups of ``5 + class_num`` values; when
            ``attrs['iou_aware']`` is true the first ``an_num`` channels are
            split off beforehand and used as IoU predictions.
        img_size: 2-D array; column 1 scales box coordinates 0 and 2 (x) and
            column 0 scales coordinates 1 and 3 (y).
        attrs: Mapping with the keys 'anchors', 'class_num', 'conf_thresh',
            'downsample_ratio', 'clip_bbox', 'scale_x_y', 'iou_aware' and
            'iou_aware_factor'.

    Returns:
        A tuple ``(boxes, scores)`` reshaped to ``(n, -1, 4)`` and
        ``(n, -1, class_num)``.
    """
    batch, _, grid_h, grid_w = x.shape
    anchors = attrs['anchors']
    num_anchors = int(len(anchors) // 2)
    class_num = attrs['class_num']
    conf_thresh = attrs['conf_thresh']
    downsample = attrs['downsample_ratio']
    clip_bbox = attrs['clip_bbox']
    scale_x_y = attrs['scale_x_y']
    iou_aware = attrs['iou_aware']
    iou_aware_factor = attrs['iou_aware_factor']

    bias_x_y = -0.5 * (scale_x_y - 1.0)
    input_h = downsample * grid_h
    input_w = downsample * grid_w

    if iou_aware:
        # The leading ``num_anchors`` channels carry the IoU predictions.
        iou_pred = np.expand_dims(x[:, :num_anchors, :, :], axis=-1)
        x = x[:, num_anchors:, :, :]

    # Move the per-anchor values to the last axis: (n, anchor, h, w, value).
    x = x.reshape((batch, num_anchors, 5 + class_num, grid_h, grid_w))
    x = x.transpose((0, 1, 3, 4, 2))
    pred_box = x[..., :4].copy()

    # Cell indices; broadcasting spreads them over the (h, w) grid.
    col_offsets = np.arange(grid_w).reshape((1, grid_w))
    row_offsets = np.arange(grid_h).reshape((grid_h, 1))
    pred_box[..., 0] = (
        col_offsets + sigmoid(pred_box[..., 0]) * scale_x_y + bias_x_y
    ) / grid_w
    pred_box[..., 1] = (
        row_offsets + sigmoid(pred_box[..., 1]) * scale_x_y + bias_x_y
    ) / grid_h

    # Anchors arrive as a flat [w0, h0, w1, h1, ...] list; normalise each
    # pair by the input resolution.
    anchor_pairs = [
        (anchors[k], anchors[k + 1]) for k in range(0, len(anchors), 2)
    ]
    anchors_s = np.array(
        [(pair_w / input_w, pair_h / input_h) for pair_w, pair_h in anchor_pairs]
    )
    anchor_w = anchors_s[:, 0:1].reshape((1, num_anchors, 1, 1))
    anchor_h = anchors_s[:, 1:2].reshape((1, num_anchors, 1, 1))
    pred_box[..., 2] = np.exp(pred_box[..., 2]) * anchor_w
    pred_box[..., 3] = np.exp(pred_box[..., 3]) * anchor_h

    objectness = sigmoid(x[..., 4:5])
    if iou_aware:
        pred_conf = (
            objectness ** (1 - iou_aware_factor)
            * sigmoid(iou_pred) ** iou_aware_factor
        )
    else:
        pred_conf = objectness
    # Confidences below the threshold are zeroed, which also zeroes the
    # matching scores and boxes below.
    pred_conf[pred_conf < conf_thresh] = 0.0
    pred_score = sigmoid(x[..., 5:]) * pred_conf
    pred_box = pred_box * (pred_conf > 0.0).astype('float32')
    pred_box = pred_box.reshape((batch, -1, 4))

    # Convert (center_x, center_y, w, h) into (x1, y1, x2, y2). Both corners
    # are computed before either slice is overwritten.
    half_extent = pred_box[:, :, 2:4] / 2.0
    top_left = pred_box[:, :, :2] - half_extent
    bottom_right = pred_box[:, :, :2] + half_extent
    pred_box[:, :, :2] = top_left
    pred_box[:, :, 2:4] = bottom_right

    # Scale the normalised corners by the per-sample image size.
    x_scale = img_size[:, 1][:, np.newaxis]
    y_scale = img_size[:, 0][:, np.newaxis]
    pred_box[:, :, 0] = pred_box[:, :, 0] * x_scale
    pred_box[:, :, 1] = pred_box[:, :, 1] * y_scale
    pred_box[:, :, 2] = pred_box[:, :, 2] * x_scale
    pred_box[:, :, 3] = pred_box[:, :, 3] * y_scale

    if clip_bbox:
        for idx in range(len(pred_box)):
            sample_boxes = pred_box[idx]  # view: writes land in pred_box
            sample_boxes[:, 0] = np.clip(sample_boxes[:, 0], 0, np.inf)
            sample_boxes[:, 1] = np.clip(sample_boxes[:, 1], 0, np.inf)
            sample_boxes[:, 2] = np.clip(
                sample_boxes[:, 2], -np.inf, img_size[idx, 1] - 1
            )
            sample_boxes[:, 3] = np.clip(
                sample_boxes[:, 3], -np.inf, img_size[idx, 0] - 1
            )

    return (pred_box, pred_score.reshape((batch, -1, class_num)))
class TestYoloBoxOp(OpTest):
    """Compares the ``yolo_box`` op on MLU device 0 with the ``YoloBox`` reference.

    Subclasses vary the configuration by overriding ``initTestCase``.
    """

    def initTestCase(self):
        """Set the default configuration: clipping on, no scaling, no IoU."""
        self.anchors = [10, 13, 16, 30, 33, 23]
        self.batch_size = 32
        self.class_num = 2
        self.conf_thresh = 0.5
        self.downsample = 32
        self.clip_bbox = True
        # The anchor list is flat, so two entries describe one anchor.
        anchor_count = int(len(self.anchors) // 2)
        channels = anchor_count * (5 + self.class_num)
        self.x_shape = (self.batch_size, channels, 13, 13)
        self.imgsize_shape = (self.batch_size, 2)
        self.scale_x_y = 1.0
        self.iou_aware = False
        self.iou_aware_factor = 0.5

    def setUp(self):
        """Build random inputs and the reference outputs for the op test."""
        self.initTestCase()
        self.op_type = 'yolo_box'
        self.place = paddle.device.MLUPlace(0)
        test_cls = self.__class__
        test_cls.use_mlu = True
        test_cls.no_need_check_grad = True
        self.python_api = paddle.vision.ops.yolo_box

        # Draw order is kept (features first, then image sizes) so the random
        # stream is consumed exactly as before.
        features = np.random.random(self.x_shape).astype('float32')
        image_sizes = np.random.randint(10, 20, self.imgsize_shape).astype(
            'int32'
        )

        self.attrs = {
            'anchors': self.anchors,
            'class_num': self.class_num,
            'conf_thresh': self.conf_thresh,
            'downsample_ratio': self.downsample,
            'clip_bbox': self.clip_bbox,
            'scale_x_y': self.scale_x_y,
            'iou_aware': self.iou_aware,
            'iou_aware_factor': self.iou_aware_factor,
        }
        self.inputs = {'X': features, 'ImgSize': image_sizes}
        expected_boxes, expected_scores = YoloBox(
            features, image_sizes, self.attrs
        )
        self.outputs = {'Boxes': expected_boxes, 'Scores': expected_scores}

    def test_check_output(self):
        """Check the op output on ``self.place`` with an absolute tolerance of 1e-5."""
        self.check_output_with_place(self.place, check_eager=False, atol=1e-5)
class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
    """Same as the base configuration, but with box clipping turned off."""

    def initTestCase(self):
        # Start from the base configuration and change only the clip flag;
        # every other value is the same as in TestYoloBoxOp.
        super(TestYoloBoxOpNoClipBbox, self).initTestCase()
        self.clip_bbox = False
class TestYoloBoxOpScaleXY(TestYoloBoxOp):
    """Same as the base configuration, but with ``scale_x_y`` set to 1.2."""

    def initTestCase(self):
        # Start from the base configuration and change only the scale factor;
        # every other value is the same as in TestYoloBoxOp.
        super(TestYoloBoxOpScaleXY, self).initTestCase()
        self.scale_x_y = 1.2
class TestYoloBoxOpIoUAware(TestYoloBoxOp):
    """Same as the base configuration, but with IoU-aware confidence enabled."""

    def initTestCase(self):
        super(TestYoloBoxOpIoUAware, self).initTestCase()
        self.iou_aware = True
        # IoU-aware inputs carry one extra channel per anchor, so each anchor
        # takes 6 + class_num channels instead of 5 + class_num.
        anchor_count = int(len(self.anchors) // 2)
        self.x_shape = (
            self.batch_size,
            anchor_count * (6 + self.class_num),
            13,
            13,
        )
class TestYoloBoxDygraph(unittest.TestCase):
    """Smoke test for ``paddle.vision.ops.yolo_box`` in dynamic-graph mode."""

    def test_dygraph(self):
        """Run yolo_box with and without ``iou_aware`` and check both outputs."""
        paddle.disable_static()
        try:
            img_size = np.ones((2, 2)).astype('int32')
            img_size = paddle.to_tensor(img_size)

            # Plain variant: 2 anchors * (5 + 2 classes) = 14 channels.
            x1 = np.random.random([2, 14, 8, 8]).astype('float32')
            x1 = paddle.to_tensor(x1)
            (boxes, scores) = paddle.vision.ops.yolo_box(
                x1,
                img_size=img_size,
                anchors=[10, 13, 16, 30],
                class_num=2,
                conf_thresh=0.01,
                downsample_ratio=8,
                clip_bbox=True,
                scale_x_y=1.0,
            )
            assert (boxes is not None) and (scores is not None)

            # IoU-aware variant: 2 anchors * (6 + 2 classes) = 16 channels.
            x2 = np.random.random([2, 16, 8, 8]).astype('float32')
            x2 = paddle.to_tensor(x2)
            (boxes, scores) = paddle.vision.ops.yolo_box(
                x2,
                img_size=img_size,
                anchors=[10, 13, 16, 30],
                class_num=2,
                conf_thresh=0.01,
                downsample_ratio=8,
                clip_bbox=True,
                scale_x_y=1.0,
                iou_aware=True,
                iou_aware_factor=0.5,
            )
            # The IoU-aware outputs are checked too, like the static test.
            assert (boxes is not None) and (scores is not None)
        finally:
            # Restore static mode even when the body raises: this module
            # enables static mode at import time and its other tests rely on it.
            paddle.enable_static()
class TestYoloBoxStatic(unittest.TestCase):
    """Smoke test for ``paddle.vision.ops.yolo_box`` in static-graph mode."""

    def test_static(self):
        """Build yolo_box with and without ``iou_aware`` and check both outputs."""
        yolo_box = paddle.vision.ops.yolo_box

        plain_input = paddle.static.data('x1', [2, 14, 8, 8], 'float32')
        img_size = paddle.static.data('img_size', [2, 2], 'int32')
        # Keyword arguments shared by both calls below.
        shared_kwargs = {
            'img_size': img_size,
            'class_num': 2,
            'conf_thresh': 0.01,
            'downsample_ratio': 8,
            'clip_bbox': True,
            'scale_x_y': 1.0,
        }

        boxes, scores = yolo_box(
            plain_input, anchors=[10, 13, 16, 30], **shared_kwargs
        )
        assert (boxes is not None) and (scores is not None)

        # The IoU-aware input has one extra channel per anchor (16 vs 14).
        iou_input = paddle.static.data('x2', [2, 16, 8, 8], 'float32')
        boxes, scores = yolo_box(
            iou_input,
            anchors=[10, 13, 16, 30],
            iou_aware=True,
            iou_aware_factor=0.5,
            **shared_kwargs
        )
        assert (boxes is not None) and (scores is not None)
class TestYoloBoxOpHW(TestYoloBoxOp):
    """Base configuration on a non-square 13 x 9 grid, with clipping off."""

    def initTestCase(self):
        super(TestYoloBoxOpHW, self).initTestCase()
        self.clip_bbox = False
        # Height and width differ so that a swapped h/w would show up.
        anchor_count = int(len(self.anchors) // 2)
        self.x_shape = (
            self.batch_size,
            anchor_count * (5 + self.class_num),
            13,
            9,
        )
# Script entry point: switch Paddle to static-graph mode, then hand control
# to the unittest runner.
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
# An image for building paddle binaries # An image for building paddle binaries
# Update CNTOOLKIT_VERSION, CNNL_VERSION and CNCL_VERSION if using other versions # Update CNTOOLKIT_VERSION, CNNL_VERSION, CNCL_VERSION and MLUOPS_VERSION if using other versions
# #
# Build: # Build:
# - CNTOOLKIT_VERSION 2.8.5 # - CNTOOLKIT_VERSION 3.0.2-1
# - CNNL_VERSION 1.10.5 # - CNNL_VERSION 1.13.0-1
# - CNCL_VERSION 1.1.2 # - CNCL_VERSION 1.2.1-1
# - MLUOPS_VERSION 0.2.0-1
# #
# Download three packages from FTP (need to contact Cambricon AE to get FTP url) # Download four packages from FTP (need to contact Cambricon AE to get FTP url)
# - cntoolkit_2.8.5.ubuntu18.04_amd64.deb # - cntoolkit_3.0.2-1.ubuntu18.04_amd64.deb
# - cnnl_1.10.5.ubuntu18.04_amd64.deb # - cnnl_1.13.0-1.ubuntu18.04_amd64.deb
# - cncl_1.1.2.ubuntu18.04_amd64.deb # - cncl_1.2.1-1.ubuntu18.04_amd64.deb
# - mluops_0.2.0-1.ubuntu18.04_amd64.deb
# copy them to current directory first, then run build commands # copy them to current directory first, then run build commands
# #
# For example: # For example:
...@@ -19,11 +21,13 @@ ...@@ -19,11 +21,13 @@
# (get cntoolkit pkg) # (get cntoolkit pkg)
# (get cnnl pkg) # (get cnnl pkg)
# (get cncl pkg) # (get cncl pkg)
# (get mluops pkg)
# #
# docker build -f Dockerfile.mlu \ # docker build -f Dockerfile.mlu \
# --build-arg CNTOOLKIT_VERSION=2.8.5 \ # --build-arg CNTOOLKIT_VERSION=3.0.2-1 \
# --build-arg CNNL_VERSION=1.10.5 \ # --build-arg CNNL_VERSION=1.13.0-1 \
# --build-arg CNCL_VERSION=1.1.2 \ # --build-arg CNCL_VERSION=1.2.1-1 \
# --build-arg MLUOPS_VERSION=0.2.0-1 \
# -t paddlepaddle/paddle:latest-dev-mlu . # -t paddlepaddle/paddle:latest-dev-mlu .
# #
# without mlu device: # without mlu device:
...@@ -40,12 +44,14 @@ MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com> ...@@ -40,12 +44,14 @@ MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
ENV WITH_GPU=OFF ENV WITH_GPU=OFF
ARG CNTOOLKIT_VERSION=2.8.5 ARG CNTOOLKIT_VERSION=3.0.2-1
ARG CNNL_VERSION=1.10.5 ARG CNNL_VERSION=1.13.0-1
ARG CNCL_VERSION=1.1.2 ARG CNCL_VERSION=1.2.1-1
ARG MLUOPS_VERSION=0.2.0-1
ARG CNTOOLKIT_PKG=cntoolkit_$CNTOOLKIT_VERSION.ubuntu18.04_amd64.deb ARG CNTOOLKIT_PKG=cntoolkit_$CNTOOLKIT_VERSION.ubuntu18.04_amd64.deb
ARG CNNL_PKG=cnnl_$CNNL_VERSION.ubuntu18.04_amd64.deb ARG CNNL_PKG=cnnl_$CNNL_VERSION.ubuntu18.04_amd64.deb
ARG CNCL_PKG=cncl_$CNCL_VERSION.ubuntu18.04_amd64.deb ARG CNCL_PKG=cncl_$CNCL_VERSION.ubuntu18.04_amd64.deb
ARG MLUOPS_PKG=mluops_$MLUOPS_VERSION.ubuntu18.04_amd64.deb
# install cntoolkit # install cntoolkit
COPY $CNTOOLKIT_PKG ./ COPY $CNTOOLKIT_PKG ./
...@@ -67,6 +73,11 @@ COPY $CNCL_PKG ./ ...@@ -67,6 +73,11 @@ COPY $CNCL_PKG ./
RUN dpkg -i $CNCL_PKG && \ RUN dpkg -i $CNCL_PKG && \
rm -f $CNCL_PKG rm -f $CNCL_PKG
# install mluops
COPY $MLUOPS_PKG ./
RUN dpkg -i $MLUOPS_PKG && \
rm -f $MLUOPS_PKG
# Clean # Clean
RUN apt-get clean -y RUN apt-get clean -y
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册