Unverified commit 25124d7f, authored by fwenguang, committed by GitHub

[cherry-pick][MLU] support add callback to stream and profiler (#42115)

* [MLU] add mlu new profiler (#41138)

* [MLU] add mlu new profiler

* fix format

* [MLU] support add callback to stream (#41831)

* [MLU] add gather mlu kernel (#41969)

* [MLU] add mlu activation kernels (#41751)
Parent 6c935e1d
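The heart of the stream-callback support is the pattern used in the MLU pooling kernel further down: issue an asynchronous copy, then keep the host buffer alive by capturing it by value in a callback enqueued on the same stream. Below is a standalone C++ sketch of that lifetime idiom, with a std::thread standing in for the MLU stream; all names are illustrative and not part of the patch.

#include <iostream>
#include <memory>
#include <thread>
#include <vector>

int main() {
  // Host buffer that an asynchronous "copy" will read from.
  auto host_buf = std::make_shared<std::vector<int>>(1024, 42);
  // Stand-in for work enqueued on a device stream.
  std::thread stream([host_buf] {  // by-value capture bumps the refcount
    long sum = 0;
    for (int v : *host_buf) sum += v;
    std::cout << "async copy consumed buffer, sum=" << sum << "\n";
  });
  host_buf.reset();  // the caller may drop its reference immediately
  stream.join();     // the buffer is freed only after the callback has run
  return 0;
}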
......@@ -17,13 +17,16 @@ INCLUDE_DIRECTORIES(${NEUWARE_INCLUDE_DIR})
set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so)
set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so)
+set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so)
generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake")
+set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB})
if(WITH_CNCL)
MESSAGE(STATUS "Compile with CNCL!")
ADD_DEFINITIONS(-DPADDLE_WITH_CNCL)
set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so)
-TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
-else()
-TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
+list(APPEND NEUWARE_LIB_DEPS ${CNCL_LIB})
endif()
+TARGET_LINK_LIBRARIES(neuware_lib ${NEUWARE_LIB_DEPS})
......@@ -34,14 +34,6 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
return;
}
-// NOTE(hqp): Special case for CPU->MLU, avoid stream sync.
-if (platform::is_cpu_place(in.place()) && platform::is_mlu_place(dst_place)) {
-paddle::framework::TensorCopy(
-in, dst_place, *platform::DeviceContextPool::Instance().Get(dst_place),
-out);
-return;
-}
// NOTE(yy): TransDataDevice should wait for computation of input.
if (!platform::is_cuda_pinned_place(in.place())) {
platform::DeviceContextPool::Instance().Get(in.place())->Wait();
......
......@@ -15,12 +15,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
......@@ -38,20 +34,39 @@ class ActivationMLUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
MLUCnnlActivationDesc act_desc(act_mode, alpha);
-MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
-ToCnnlDataType(input->dtype()));
-MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
-ToCnnlDataType(output->dtype()));
-MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
-reinterpret_cast<const void*>(input->data<T>()),
-output_desc.get(),
-reinterpret_cast<void*>(output->data<T>()));
+MLUCnnlTensorDesc input_desc(*input);
+MLUCnnlTensorDesc output_desc(*output);
+MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), GetBasePtr(input),
+output_desc.get(), GetBasePtr(output));
}
};
// For gelu, leaky_relu: the gradient depends on the forward input X.
template <cnnlActivationMode_t act_mode, typename T>
-class ActivationGradMLUKernel : public framework::OpKernel<T> {
+class ActivationGradMLUKernelV1 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), GetBasePtr(dout), x_desc.get(),
GetBasePtr(x), dx_desc.get(), GetBasePtr(dx));
}
};
// For tanh, sigmoid: the gradient can be derived from the forward output Out alone.
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernelV2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
......@@ -61,18 +76,35 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
dx->mutable_data<T>(ctx.GetPlace());
-MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY,
-ToCnnlDataType(dout->dtype()));
-MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY,
-ToCnnlDataType(out->dtype()));
-MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
-ToCnnlDataType(dx->dtype()));
+MLUCnnlTensorDesc out_desc(*out);
+MLUCnnlTensorDesc dout_desc(*dout);
+MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
-MLUCnnl::ActiveGrad(
-ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
-dout_desc.get(), reinterpret_cast<const void*>(dout->data<T>()),
-out_desc.get(), reinterpret_cast<const void*>(out->data<T>()),
-dx_desc.get(), reinterpret_cast<void*>(dx->data<T>()));
+MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, out_desc.get(),
+GetBasePtr(out), dout_desc.get(), GetBasePtr(dout),
+nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
}
};
// For relu, relu6: the forward output Out acts as the gradient mask.
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernelV3 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), GetBasePtr(dout), out_desc.get(),
GetBasePtr(out), dx_desc.get(), GetBasePtr(dx));
}
};
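The V1/V2/V3 split follows what data the backward pass needs: gelu and leaky_relu gradients are functions of the forward input X, while tanh and sigmoid gradients can be reconstructed from the forward output alone (tanh'(x) = 1 - tanh(x)^2, sigmoid'(x) = out * (1 - out)); relu and relu6 likewise only need Out as a mask. A quick standalone check of the identities behind the V2 kernels:

#include <cmath>
#include <cstdio>

int main() {
  const double x = 0.7;
  // tanh: gradient reconstructed from the forward output alone.
  const double t = std::tanh(x);
  std::printf("tanh   : %f vs %f\n", 1.0 - t * t,
              1.0 / (std::cosh(x) * std::cosh(x)));
  // sigmoid: same story, out * (1 - out).
  const double s = 1.0 / (1.0 + std::exp(-x));
  const double direct =
      std::exp(-x) / ((1.0 + std::exp(-x)) * (1.0 + std::exp(-x)));
  std::printf("sigmoid: %f vs %f\n", s * (1.0 - s), direct);
  return 0;
}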
......@@ -81,10 +113,60 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
// relu
REGISTER_OP_MLU_KERNEL(
relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
-relu_grad, ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU, float>,
-ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU,
-paddle::platform::float16>);
+relu_grad, ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU, float>,
+ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU,
+paddle::platform::float16>);
// relu6
REGISTER_OP_MLU_KERNEL(
relu6, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU6, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU6, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
relu6_grad, ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU6, float>,
ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU6,
paddle::platform::float16>);
// sigmoid
REGISTER_OP_MLU_KERNEL(sigmoid,
ops::ActivationMLUKernel<CNNL_ACTIVATION_SIGMOID, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_SIGMOID,
paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
sigmoid_grad,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_SIGMOID, float>,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_SIGMOID,
paddle::platform::float16>);
// tanh
REGISTER_OP_MLU_KERNEL(
tanh, ops::ActivationMLUKernel<CNNL_ACTIVATION_TANH, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_TANH, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
tanh_grad, ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_TANH, float>,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_TANH,
paddle::platform::float16>);
// gelu
REGISTER_OP_MLU_KERNEL(
gelu, ops::ActivationMLUKernel<CNNL_ACTIVATION_GELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_GELU, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
gelu_grad, ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_GELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_GELU,
paddle::platform::float16>);
// leaky_relu
REGISTER_OP_MLU_KERNEL(
leaky_relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
leaky_relu_grad,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>);
......@@ -51,6 +51,8 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
}
}
}
+const T *value_data = &value;
+cnnlPointerMode_t pointer_mode = CNNL_POINTER_MODE_HOST;
if (ctx.HasInput("ValueTensor")) {
auto *value_tensor = ctx.Input<framework::Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ(
......@@ -59,22 +61,18 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
"When use Tensor as value to set Tensor value in fill_cosntant, "
"value input(ValueTensor) size must be 1, but get %d",
value_tensor->numel()));
-const T *tensor_data = value_tensor->data<T>();
-framework::Tensor mlu_tensor;
+value_data = value_tensor->data<T>();
auto tmp_place = value_tensor->place();
if (platform::is_mlu_place(tmp_place)) {
-framework::TensorCopySync(*value_tensor, platform::CPUPlace(),
-&mlu_tensor);
-tensor_data = mlu_tensor.data<T>();
+pointer_mode = CNNL_POINTER_MODE_DEVICE;
}
-value = tensor_data[0];
}
auto shape = GetShape(ctx);
out_var->mutable_data<T>(shape, ctx.GetPlace());
-MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY,
-ToCnnlDataType(out_var->dtype()));
-MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var));
+MLUCnnlTensorDesc output_desc(*out_var);
+MLUCnnl::Fill(ctx, pointer_mode, value_data, output_desc.get(),
+GetBasePtr(out_var));
}
};
} // namespace operators
......
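The fill_constant change above swaps cnnlFill (host float only) for cnnlFill_v3, whose pointer mode says where the fill value lives; when ValueTensor is already on the MLU, its device pointer is passed straight through, removing the old TensorCopySync back to the CPU. The two call shapes, as used in this patch (fragment only; ctx, output_desc, value_tensor and out_var come from the surrounding kernel):

// Host-side scalar: the value is read from CPU memory.
float value = 1.0f;
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, output_desc.get(),
              GetBasePtr(out_var));
// Device-side scalar: value_tensor already lives on the MLU, so its
// device pointer is handed to cnnlFill_v3 directly, with no D2H sync.
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_DEVICE, value_tensor->data<float>(),
              output_desc.get(), GetBasePtr(out_var));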
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class GatherOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto axis = ctx.Attr<int>("axis");
auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc index_desc(*index);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(),
GetBasePtr(x), index_desc.get(), GetBasePtr(index),
out_desc.get(), GetBasePtr(out));
}
};
template <typename T>
class GatherGradOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *index = ctx.Input<Tensor>("Index");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc dx_desc(*dx);
auto value = static_cast<T>(0);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(),
GetBasePtr(dx));
MLUCnnlTensorDesc index_desc(*index);
MLUCnnlTensorDesc dout_desc(*dout);
const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(),
GetBasePtr(dout), index_desc.get(),
GetBasePtr(index), mode);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(gather, ops::GatherOpMLUKernel<float>,
ops::GatherOpMLUKernel<paddle::platform::float16>,
ops::GatherOpMLUKernel<int>);
REGISTER_OP_MLU_KERNEL(gather_grad, ops::GatherGradOpMLUKernel<float>,
ops::GatherGradOpMLUKernel<paddle::platform::float16>,
ops::GatherGradOpMLUKernel<int>);
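For reference, the semantics the gather kernels above implement: out[i, :] = x[index[i], :], and GatherGradOpMLUKernel zero-fills dx and scatters dout back by index. A standalone numeric check; the loop below accumulates with +=, which matches the scatter result whenever the indices do not repeat:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int cols = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6};  // 3x2 input
  std::vector<int> index = {2, 0};
  // Forward: out[i, :] = x[index[i], :]
  std::vector<float> out(index.size() * cols);
  for (std::size_t i = 0; i < index.size(); ++i)
    for (int j = 0; j < cols; ++j)
      out[i * cols + j] = x[index[i] * cols + j];
  // Backward: dx starts at zero, then rows of dout scatter back.
  std::vector<float> dout(out.size(), 1.0f);
  std::vector<float> dx(x.size(), 0.0f);
  for (std::size_t i = 0; i < index.size(); ++i)
    for (int j = 0; j < cols; ++j)
      dx[index[i] * cols + j] += dout[i * cols + j];
  for (float v : out) std::printf("%g ", v);  // 5 6 1 2
  std::printf("\n");
  for (float v : dx) std::printf("%g ", v);   // 1 1 0 0 1 1
  std::printf("\n");
  return 0;
}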
......@@ -95,7 +95,8 @@ class MeanMLUGradKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(mean_var.dtype()));
auto value = static_cast<T>(1.0 / static_cast<float>(input_grad->numel()));
-MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var));
+MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, mean_var_desc.get(),
+GetBasePtr(&mean_var));
// means mul output_grad
MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY,
......
......@@ -136,15 +136,17 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// [total]
total->mutable_data<int>(ctx.GetPlace());
MLUCnnlTensorDesc total_desc(*total);
-MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
+MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples, total_desc.get(),
+GetBasePtr(total));
// use `total` of type `float32` for calculating accuracy
Tensor total_fp32(framework::TransToPhiDataType(VT::FP32));
total_fp32.Resize(total->dims());
total_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc total_fp32_desc(total_fp32);
-MLUCnnl::Fill(ctx, static_cast<float>(num_samples), total_fp32_desc.get(),
-GetBasePtr(&total_fp32));
+float num_samples_fp32 = static_cast<float>(num_samples);
+MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples_fp32,
+total_fp32_desc.get(), GetBasePtr(&total_fp32));
// [accuracy]
accuracy->mutable_data<float>(ctx.GetPlace());
......
......@@ -208,8 +208,20 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() {
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
-PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor(
-active_desc_, act_mode, CNNL_NOT_PROPAGATE_NAN, ceof));
+PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4(
+active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION,
+CNNL_NOT_PROPAGATE_NAN, ceof, 1.0f /*sliced_dim*/,
+1.67326319217681884765625 /*selu_alpha*/,
+1.05070102214813232421875 /*selu_lambda*/));
}
MLUCnnlActivationDesc::MLUCnnlActivationDesc(
const cnnlActivationMode_t act_mode, const float ceof,
const float sliced_dim, const float selu_alpha, const float selu_lambda) {
PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4(
active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION,
CNNL_NOT_PROPAGATE_NAN, ceof, sliced_dim, selu_alpha, selu_lambda));
}
const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const {
......@@ -541,12 +553,15 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
output_desc, output));
}
-/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx, float value,
+/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx,
+const cnnlPointerMode_t pointer_mode,
+const void* value_ptr,
const cnnlTensorDescriptor_t output_desc,
void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
-PADDLE_ENFORCE_MLU_SUCCESS(cnnlFill(handle, value, output_desc, output));
+PADDLE_ENFORCE_MLU_SUCCESS(
+cnnlFill_v3(handle, pointer_mode, value_ptr, output_desc, output));
}
/* static */ void MLUCnnl::QuantifyOffline(
......@@ -919,9 +934,8 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
beta_ptr = static_cast<const void*>(&beta_int);
}
-PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize_v2(
-handle, op_tensor_desc, alpha1_ptr, a_desc, a, alpha2_ptr, b_desc, b,
-beta_ptr, output_desc, output, &workspace_size));
+PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize(
+handle, a_desc, b_desc, output_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
......
......@@ -218,6 +218,9 @@ class MLUCnnlActivationDesc {
MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete;
MLUCnnlActivationDesc& operator=(const MLUCnnlActivationDesc& desc) = delete;
MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof);
MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof,
const float sliced_dim, const float selu_alpha,
const float selu_lambda);
const cnnlActivationDescriptor_t get() const;
~MLUCnnlActivationDesc();
......@@ -418,7 +421,8 @@ class MLUCnnl {
const cnnlTensorDescriptor_t in1_desc, const void* in1,
const cnnlTensorDescriptor_t output_desc, void* output);
-static void Fill(const ExecutionContext& ctx, float value,
+static void Fill(const ExecutionContext& ctx,
+const cnnlPointerMode_t pointer_mode, const void* value_ptr,
const cnnlTensorDescriptor_t output_desc, void* output);
static void LRN(const ExecutionContext& ctx, const int local_size,
......
......@@ -69,7 +69,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
"the same Tensors."));
}
auto mu = ctx.Attr<float>("mu");
auto mu = static_cast<T>(ctx.Attr<float>("mu"));
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
......@@ -114,14 +114,15 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
Tensor mu_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc mu_tensor_desc(mu_tensor);
-MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor));
+MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(),
+GetBasePtr(&mu_tensor));
for (size_t idx = 0; idx < n; ++idx) {
-RegularizationType regularization_flag =
+phi::RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
-? RegularizationType::kL2DECAY
-: RegularizationType::kNONE;
+? phi::RegularizationType::kL2DECAY
+: phi::RegularizationType::kNONE;
T regularization_coeff = static_cast<T>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff = static_cast<T>(regularization_coeffs[idx]);
......@@ -134,7 +135,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel<T> {
auto grad = grads[idx];
Tensor regularized_grad;
MLUCnnlTensorDesc param_desc(*param_out);
-if (regularization_flag == RegularizationType::kL2DECAY) {
+if (regularization_flag == phi::RegularizationType::kL2DECAY) {
regularized_grad = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
param_out->dims(), dev_ctx);
MLUCnnlOpTensorDesc op_tensor_desc(
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
namespace paddle {
namespace operators {
......@@ -27,10 +28,10 @@ class MLUMomentumOpKernel : public framework::OpKernel<T> {
std::string regularization_method =
ctx.Attr<std::string>("regularization_method");
auto regularization_coeff = ctx.Attr<float>("regularization_coeff");
-RegularizationType regularization_flag{
-RegularizationType::kNONE}; // disable regularization
+phi::RegularizationType regularization_flag{
+phi::RegularizationType::kNONE}; // disable regularization
if (regularization_method == "l2_decay") {
-regularization_flag = RegularizationType::kL2DECAY;
+regularization_flag = phi::RegularizationType::kL2DECAY;
}
T mu = static_cast<T>(ctx.Attr<float>("mu"));
......@@ -52,11 +53,12 @@ class MLUMomentumOpKernel : public framework::OpKernel<T> {
Tensor mu_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc mu_tensor_desc(mu_tensor);
-MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor));
+MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(),
+GetBasePtr(&mu_tensor));
Tensor regularized_grad;
MLUCnnlTensorDesc param_desc(*param);
-if (regularization_flag == RegularizationType::kL2DECAY) {
+if (regularization_flag == phi::RegularizationType::kL2DECAY) {
regularized_grad =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(param->dims(), dev_ctx);
MLUCnnlOpTensorDesc op_tensor_desc(
......
......@@ -116,11 +116,16 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
framework::Tensor extra_device_tensor =
ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(extra_input_size)}, dev_ctx);
-// TODO(fwg): use Async copy, and add a callback to stream that free
-// host memory.
-framework::TensorCopySync(extra_host_tensor, ctx.GetPlace(),
-&extra_device_tensor);
+framework::TensorCopy(extra_host_tensor, ctx.GetPlace(),
+&extra_device_tensor);
+// Increase extra_host_tensor holder_ reference count until the copy
+// completes.
+auto increase_ref_count = [extra_host_tensor]() {
+VLOG(4) << "Finished copying extra_host_tensor["
+<< GetBasePtr(&extra_host_tensor)
+<< "] in mlu pooling kernel.";
+};
+dev_ctx.AddStreamCallback(increase_ref_count);
MLUCnnl::PoolingForward(
ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
......
......@@ -103,8 +103,8 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel<T> {
ToCnnlDataType(input_grad->dtype()));
auto value = static_cast<T>(1.0 / static_cast<float>(reduce_numel));
-MLUCnnl::Fill(context, value, input_grad_desc.get(),
-GetBasePtr(input_grad));
+MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value,
+input_grad_desc.get(), GetBasePtr(input_grad));
MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN);
......
......@@ -27,7 +27,7 @@ class ScaleMLUKernel : public framework::OpKernel<T> {
auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
// cnnl requires input, scale and bias to have the same type, all on the device side.
auto& scale = ctx.Attr<float>("scale");
auto scale = static_cast<T>(ctx.Attr<float>("scale"));
framework::Tensor scale_tensor;
if (ctx.HasInput("ScaleTensor")) {
framework::Tensor float_scale_tensor =
......@@ -49,14 +49,16 @@ class ScaleMLUKernel : public framework::OpKernel<T> {
} else {
scale_tensor = ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc scale_desc(scale_tensor);
-MLUCnnl::Fill(ctx, scale, scale_desc.get(), GetBasePtr(&scale_tensor));
+MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &scale, scale_desc.get(),
+GetBasePtr(&scale_tensor));
}
auto& bias = ctx.Attr<float>("bias");
auto bias = static_cast<T>(ctx.Attr<float>("bias"));
framework::Tensor bias_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1}, dev_ctx);
MLUCnnlTensorDesc bias_desc(bias_tensor);
-MLUCnnl::Fill(ctx, bias, bias_desc.get(), GetBasePtr(&bias_tensor));
+MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &bias, bias_desc.get(),
+GetBasePtr(&bias_tensor));
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<phi::SelectedRows>() && in_var != out_var) {
......
......@@ -16,7 +16,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_MLU
#include <cn_api.h>
#include <cndrv_id.h>
#include <cnnl.h>
#include <cnpapi.h>
#include <cnrt.h>
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
......@@ -33,7 +35,7 @@ using cnclStatus = cnclResult_t;
#endif
using mluStream = cnrtQueue_t;
using mluCnnlHandle = cnnlHandle_t;
-using mluEventHandle = CNnotifier;
+using mluEventHandle = cnrtNotifier_t;
using mluDeviceHandle = CNdev;
namespace platform {
......
......@@ -40,7 +40,6 @@ class MLUStream final {
template <typename Callback>
void AddCallback(Callback&& callback) const {
-// TODO(mlu): mlu not support AddCallback
callback_manager_->AddCallback(callback);
}
......
cc_library(host_tracer SRCS host_tracer.cc DEPS enforce)
cc_library(cuda_tracer SRCS cuda_tracer.cc cupti_data_process.cc DEPS workqueue_utils enforce glog)
add_subdirectory(mlu)
cc_library(event_node SRCS event_node.cc DEPS enforce)
cc_library(profiler_utils SRCS utils.cc DEPS enforce glog)
add_subdirectory(dump)
cc_library(profiler_logger SRCS chrometracing_logger.cc dump/serialization_logger.cc dump/deserialization_reader.cc DEPS nodetreeproto event_node profiler_utils)
cc_library(event_bind SRCS event_python.cc DEPS profiler_logger)
cc_library(cpu_utilization SRCS cpu_utilization.cc DEPS cpu_info os_info enforce glog)
-cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer profiler_utils cpu_utilization event_bind)
+cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer profiler_utils cpu_utilization event_bind mlu_tracer)
cc_test(test_event_node SRCS test_event_node.cc DEPS event_node profiler_logger)
cc_test(test_extra_info SRCS test_extra_info.cc DEPS profiler_utils)
cc_test(test_serialization_logger SRCS dump/test_serialization_logger.cc DEPS event_bind)
......
......@@ -38,10 +38,12 @@ static std::string DefaultFileName() {
}
const char* ChromeTracingLogger::categary_name_[] = {
"Operator", "Dataloader", "ProfileStep", "CudaRuntime",
"Kernel", "Memcpy", "Memset", "UserDefined",
"OperatorInner", "Forward", "Backward", "Optimization",
"Communication", "PythonOp", "PythonUserDefined"};
"Operator", "Dataloader", "ProfileStep",
"CudaRuntime", "Kernel", "Memcpy",
"Memset", "UserDefined", "OperatorInner",
"Forward", "Backward", "Optimization",
"Communication", "PythonOp", "PythonUserDefined",
"MluRuntime"};
void ChromeTracingLogger::OpenFile() {
output_file_stream_.open(filename_,
......@@ -598,6 +600,12 @@ void ChromeTracingLogger::RefineDisplayName(
(*it).second * 2, (*it).first, (*it).second, (*it).second * 2 + 1);
}
#ifdef PADDLE_WITH_MLU
static std::string device_type("MLU");
#else
static std::string device_type("GPU");
#endif
for (auto it = deviceid_streamid_set_.begin();
it != deviceid_streamid_set_.end(); ++it) {
output_file_stream_ << string_format(
......@@ -607,7 +615,7 @@ void ChromeTracingLogger::RefineDisplayName(
"name": "process_name", "pid": %lld, "tid": %lld,
"ph": "M",
"args": {
"name": "Deivce %lld (GPU)"
"name": "Deivce %lld (%s)"
}
},
{
......@@ -632,9 +640,9 @@ void ChromeTracingLogger::RefineDisplayName(
}
},
)JSON"),
-(*it).first, (*it).second, (*it).first, (*it).first, (*it).second,
-(*it).second, (*it).first, (*it).second, (*it).first + 0x10000000,
-(*it).first, (*it).second, (*it).second);
+(*it).first, (*it).second, (*it).first, device_type.c_str(),
+(*it).first, (*it).second, (*it).second, (*it).first, (*it).second,
+(*it).first + 0x10000000, (*it).first, (*it).second, (*it).second);
}
}
......
if(WITH_MLU)
set(MLU_INFO mlu_info)
endif()
cc_library(mlu_tracer SRCS mlu_tracer.cc cnpapi_data_process.cc DEPS workqueue_utils enforce glog ${MLU_INFO})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h"
#include <cstdio>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/os_info.h"
#ifdef PADDLE_WITH_MLU
namespace paddle {
namespace platform {
namespace {
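// cnpapiGetTimestamp() and PosixInNsec() tick on different clocks, so their
// offset is sampled once and added to every cnpapi timestamp below to place
// device events on the host timeline.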
inline uint64_t GetTimeGap() {
static uint64_t time_gap = []() -> uint64_t {
uint64_t cpu_time = PosixInNsec();
uint64_t mlu_time = cnpapiGetTimestamp();
return (cpu_time - mlu_time);
}();
return time_gap;
}
void AddKernelRecord(const cnpapiActivityKernel* kernel, uint64_t start_ns,
TraceEventCollector* collector) {
static uint64_t time_gap = GetTimeGap();
if (kernel->start + time_gap < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = demangle(kernel->name);
event.type = TracerEventType::Kernel;
event.start_ns = kernel->start + time_gap;
event.end_ns = kernel->end + time_gap;
event.device_id = kernel->device_id;
event.context_id = kernel->context_id;
event.stream_id = kernel->queue_id;
event.correlation_id = kernel->correlation_id;
event.kernel_info.block_x = kernel->dimx;
event.kernel_info.block_y = kernel->dimy;
event.kernel_info.block_z = kernel->dimz;
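// cnpapi kernel records carry no grid dimensions; the kernel_type value is
// recorded in the otherwise-unused grid_x slot, with grid_y/grid_z zeroed.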
event.kernel_info.grid_x = kernel->kernel_type;
event.kernel_info.grid_y = 0;
event.kernel_info.grid_z = 0;
event.kernel_info.queued = kernel->queued;
event.kernel_info.submitted = kernel->submitted;
event.kernel_info.completed = kernel->received;
collector->AddDeviceEvent(std::move(event));
}
const char* MemcpyKind(cnpapiActivityMemcpyType kind) {
switch (kind) {
case CNPAPI_ACTIVITY_MEMCPY_TYPE_HTOD:
return "MEMCPY_HtoD";
case CNPAPI_ACTIVITY_MEMCPY_TYPE_DTOH:
return "MEMCPY_DtoH";
case CNPAPI_ACTIVITY_MEMCPY_TYPE_DTOD:
return "MEMCPY_DtoD";
case CNPAPI_ACTIVITY_MEMCPY_TYPE_HTOH:
return "MEMCPY_HtoH";
case CNPAPI_ACTIVITY_MEMCPY_TYPE_PTOP:
return "MEMCPY_PtoP";
default:
break;
}
return "MEMCPY";
}
void AddMemcpyRecord(const cnpapiActivityMemcpy* memcpy, uint64_t start_ns,
TraceEventCollector* collector) {
static uint64_t time_gap = GetTimeGap();
if (memcpy->start + time_gap < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = MemcpyKind(memcpy->copy_type);
event.type = TracerEventType::Memcpy;
event.start_ns = memcpy->start + time_gap;
event.end_ns = memcpy->end + time_gap;
event.device_id = memcpy->device_id;
event.context_id = memcpy->context_id;
event.stream_id = memcpy->queue_id;
event.correlation_id = memcpy->correlation_id;
event.memcpy_info.num_bytes = memcpy->bytes;
snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s",
MemcpyKind(memcpy->copy_type));
collector->AddDeviceEvent(std::move(event));
}
void AddMemcpy2Record(const cnpapiActivityMemcpyPtoP* memcpy2,
uint64_t start_ns, TraceEventCollector* collector) {
static uint64_t time_gap = GetTimeGap();
if (memcpy2->start + time_gap < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = MemcpyKind(memcpy2->copy_type);
event.type = TracerEventType::Memcpy;
event.start_ns = memcpy2->start + time_gap;
event.end_ns = memcpy2->end + time_gap;
event.device_id = memcpy2->device_id;
event.context_id = memcpy2->context_id;
event.stream_id = memcpy2->queue_id;
event.correlation_id = memcpy2->correlation_id;
event.memcpy_info.num_bytes = memcpy2->bytes;
snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s",
MemcpyKind(memcpy2->copy_type));
collector->AddDeviceEvent(std::move(event));
}
void AddMemsetRecord(const cnpapiActivityMemset* memset, uint64_t start_ns,
TraceEventCollector* collector) {
static uint64_t time_gap = GetTimeGap();
if (memset->start + time_gap < start_ns) {
return;
}
DeviceTraceEvent event;
event.name = "MEMSET";
event.type = TracerEventType::Memset;
event.start_ns = memset->start + time_gap;
event.end_ns = memset->end + time_gap;
event.device_id = memset->device_id;
event.context_id = memset->context_id;
event.stream_id = memset->queue_id;
event.correlation_id = memset->correlation_id;
event.memset_info.num_bytes = memset->bytes;
event.memset_info.value = memset->value;
collector->AddDeviceEvent(std::move(event));
}
class CnpapiRuntimeCbidStr {
public:
static const CnpapiRuntimeCbidStr& GetInstance() {
static CnpapiRuntimeCbidStr inst;
return inst;
}
std::string RuntimeKind(cnpapi_CallbackId cbid) const {
auto iter = cbid_str_.find(cbid);
if (iter == cbid_str_.end()) {
return "MLU Runtime API " + std::to_string(cbid);
}
return iter->second;
}
private:
CnpapiRuntimeCbidStr();
std::unordered_map<cnpapi_CallbackId, std::string> cbid_str_;
};
CnpapiRuntimeCbidStr::CnpapiRuntimeCbidStr() {
#define REGISTER_RUNTIME_CBID_STR(cbid) \
cbid_str_[CNPAPI_CNDRV_TRACE_CBID_##cbid] = #cbid
REGISTER_RUNTIME_CBID_STR(cnMalloc);
REGISTER_RUNTIME_CBID_STR(cnMallocHost);
REGISTER_RUNTIME_CBID_STR(cnFree);
REGISTER_RUNTIME_CBID_STR(cnFreeHost);
REGISTER_RUNTIME_CBID_STR(cnMemcpy);
REGISTER_RUNTIME_CBID_STR(cnMemcpyPeer);
REGISTER_RUNTIME_CBID_STR(cnMemcpyHtoD);
REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoH);
REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoD);
REGISTER_RUNTIME_CBID_STR(cnMemcpyAsync);
REGISTER_RUNTIME_CBID_STR(cnMemcpyHtoDAsync);
REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoHAsync);
REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoDAsync);
REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoD2D);
REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoD3D);
REGISTER_RUNTIME_CBID_STR(cnMemcpy2D);
REGISTER_RUNTIME_CBID_STR(cnMemcpy3D);
REGISTER_RUNTIME_CBID_STR(cnMemsetD8);
REGISTER_RUNTIME_CBID_STR(cnMemsetD16);
REGISTER_RUNTIME_CBID_STR(cnMemsetD32);
REGISTER_RUNTIME_CBID_STR(cnMemsetD8Async);
REGISTER_RUNTIME_CBID_STR(cnMemsetD16Async);
REGISTER_RUNTIME_CBID_STR(cnMemsetD32Async);
REGISTER_RUNTIME_CBID_STR(cnInvokeKernel);
REGISTER_RUNTIME_CBID_STR(cnCreateQueue);
REGISTER_RUNTIME_CBID_STR(cnDestroyQueue);
REGISTER_RUNTIME_CBID_STR(cnQueueSync);
REGISTER_RUNTIME_CBID_STR(cnQueueWaitNotifier);
REGISTER_RUNTIME_CBID_STR(cnWaitNotifier);
REGISTER_RUNTIME_CBID_STR(cnCreateNotifier);
REGISTER_RUNTIME_CBID_STR(cnDestroyNotifier);
REGISTER_RUNTIME_CBID_STR(cnPlaceNotifier);
REGISTER_RUNTIME_CBID_STR(cnCtxCreate);
REGISTER_RUNTIME_CBID_STR(cnCtxDestroy);
REGISTER_RUNTIME_CBID_STR(cnCtxGetCurrent);
REGISTER_RUNTIME_CBID_STR(cnCtxSetCurrent);
REGISTER_RUNTIME_CBID_STR(cnCtxGetDevice);
REGISTER_RUNTIME_CBID_STR(cnCtxSync);
REGISTER_RUNTIME_CBID_STR(cnInvokeHostFunc);
#undef REGISTER_RUNTIME_CBID_STR
}
void AddApiRecord(const cnpapiActivityAPI* api, uint64_t start_ns,
TraceEventCollector* collector) {
static uint64_t time_gap = GetTimeGap();
if (api->start + time_gap < start_ns) {
return;
}
RuntimeTraceEvent event;
event.name = CnpapiRuntimeCbidStr::GetInstance().RuntimeKind(api->cbid);
event.start_ns = api->start + time_gap;
event.end_ns = api->end + time_gap;
event.process_id = api->process_id;
event.thread_id = api->thread_id;
event.correlation_id = api->correlation_id;
event.callback_id = api->cbid;
event.type = TracerEventType::MluRuntime;
collector->AddRuntimeEvent(std::move(event));
}
} // namespace
namespace details {
void ProcessCnpapiActivityRecord(const cnpapiActivity* record,
uint64_t start_ns,
TraceEventCollector* collector) {
switch (record->type) {
case CNPAPI_ACTIVITY_TYPE_KERNEL:
AddKernelRecord(reinterpret_cast<const cnpapiActivityKernel*>(record),
start_ns, collector);
break;
case CNPAPI_ACTIVITY_TYPE_MEMCPY:
AddMemcpyRecord(reinterpret_cast<const cnpapiActivityMemcpy*>(record),
start_ns, collector);
break;
case CNPAPI_ACTIVITY_TYPE_MEMCPY_PTOP:
AddMemcpy2Record(
reinterpret_cast<const cnpapiActivityMemcpyPtoP*>(record), start_ns,
collector);
break;
case CNPAPI_ACTIVITY_TYPE_MEMSET:
AddMemsetRecord(reinterpret_cast<const cnpapiActivityMemset*>(record),
start_ns, collector);
break;
case CNPAPI_ACTIVITY_TYPE_CNDRV_API:
AddApiRecord(reinterpret_cast<const cnpapiActivityAPI*>(record), start_ns,
collector);
break;
default:
break;
}
}
} // namespace details
} // namespace platform
} // namespace paddle
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif
#include "paddle/fluid/platform/profiler/trace_event_collector.h"
namespace paddle {
namespace platform {
namespace details {
#ifdef PADDLE_WITH_MLU
void ProcessCnpapiActivityRecord(const cnpapiActivity* record,
uint64_t start_ns,
TraceEventCollector* collector);
#endif
} // namespace details
} // namespace platform
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h"
#include <string>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h"
#define CNPAPI_CALL(call) \
do { \
cnpapiResult _status = call; \
if (_status != CNPAPI_SUCCESS) { \
const char* errstr; \
cnpapiGetResultString(_status, &errstr); \
LOG(ERROR) << "Function " << #call << " failed with error " << errstr; \
} \
} while (0)
namespace paddle {
namespace platform {
namespace {
void BufferRequestedCallback(uint64_t** buffer, size_t* size,
size_t* max_num_records) {
constexpr size_t kBufferSize = 1 << 23; // 8 MB
constexpr size_t kBufferAlignSize = 8;
*buffer = reinterpret_cast<uint64_t*>(
paddle::framework::AlignedMalloc(kBufferSize, kBufferAlignSize));
*size = kBufferSize;
*max_num_records = 0;
}
void BufferCompletedCallback(uint64_t* buffer, size_t size, size_t valid_size) {
if (buffer == nullptr || valid_size == 0) {
return;
}
auto mlu_tracer = &MluTracer::GetInstance();
mlu_tracer->ProcessCnpapiActivity(buffer, valid_size);
paddle::framework::AlignedFree(buffer);
}
} // namespace
MluTracer::MluTracer() {
#ifdef PADDLE_WITH_MLU
CNPAPI_CALL(cnpapiInit());
CNPAPI_CALL(cnpapiActivityRegisterCallbacks(BufferRequestedCallback,
BufferCompletedCallback));
#endif
}
void MluTracer::PrepareTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::UNINITED || state_ == TracerState::STOPED, true,
platform::errors::PreconditionNotMet("MluTracer must be UNINITED"));
EnableCnpapiActivity();
state_ = TracerState::READY;
}
void MluTracer::StartTracing() {
PADDLE_ENFORCE_EQ(state_ == TracerState::READY, true,
platform::errors::PreconditionNotMet(
"MluTracer must be READY or STOPPED"));
tracing_start_ns_ = PosixInNsec();
state_ = TracerState::STARTED;
}
void MluTracer::StopTracing() {
PADDLE_ENFORCE_EQ(
state_, TracerState::STARTED,
platform::errors::PreconditionNotMet("MluTracer must be STARTED"));
DisableCnpapiActivity();
state_ = TracerState::STOPED;
}
void MluTracer::CollectTraceData(TraceEventCollector* collector) {
PADDLE_ENFORCE_EQ(
state_, TracerState::STOPED,
platform::errors::PreconditionNotMet("MluTracer must be STOPED"));
for (auto he : collector_.HostEvents()) {
collector->AddHostEvent(std::move(he));
}
for (auto rte : collector_.RuntimeEvents()) {
collector->AddRuntimeEvent(std::move(rte));
}
for (auto de : collector_.DeviceEvents()) {
collector->AddDeviceEvent(std::move(de));
}
for (auto tn : collector_.ThreadNames()) {
collector->AddThreadName(tn.first, tn.second);
}
collector_.ClearAll();
}
void MluTracer::ProcessCnpapiActivity(uint64_t* buffer, size_t valid_size) {
#ifdef PADDLE_WITH_MLU
cnpapiActivity* record = nullptr;
while (true) {
cnpapiResult status =
cnpapiActivityGetNextRecord(buffer, valid_size, &record);
if (status == CNPAPI_SUCCESS) {
details::ProcessCnpapiActivityRecord(record, tracing_start_ns_,
&collector_);
} else if (status == CNPAPI_ERROR_INSUFFICIENT_MEMORY ||
status == CNPAPI_ERROR_MAX_LIMIT_REACHED) {
break;
} else {
CNPAPI_CALL(status);
}
}
#endif
}
void MluTracer::EnableCnpapiActivity() {
#ifdef PADDLE_WITH_MLU
CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_KERNEL));
CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_MEMCPY));
CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_MEMCPY_PTOP));
CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_MEMSET));
CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_CNDRV_API));
VLOG(3) << "enable cnpapi activity";
#endif
}
void MluTracer::DisableCnpapiActivity() {
#ifdef PADDLE_WITH_MLU
CNPAPI_CALL(cnpapiActivityFlushAll());
CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_KERNEL));
CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_MEMCPY));
CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_MEMCPY_PTOP));
CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_MEMSET));
CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_CNDRV_API));
VLOG(3) << "disable cnpapi activity";
#endif
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/profiler/tracer_base.h"
namespace paddle {
namespace platform {
class MluTracer : public TracerBase {
public:
static MluTracer& GetInstance() {
static MluTracer instance;
return instance;
}
void PrepareTracing() override;
void StartTracing() override;
void StopTracing() override;
void CollectTraceData(TraceEventCollector* collector) override;
void ProcessCnpapiActivity(uint64_t* buffer, size_t valid_size);
private:
MluTracer();
DISABLE_COPY_AND_ASSIGN(MluTracer);
void EnableCnpapiActivity();
void DisableCnpapiActivity();
uint64_t tracing_start_ns_ = UINT64_MAX;
TraceEventCollector collector_;
};
} // namespace platform
} // namespace paddle
......@@ -27,6 +27,7 @@
#include "paddle/fluid/platform/profiler/cuda_tracer.h"
#include "paddle/fluid/platform/profiler/extra_info.h"
#include "paddle/fluid/platform/profiler/host_tracer.h"
#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h"
#include "paddle/fluid/platform/profiler/trace_event_collector.h"
#include "paddle/fluid/platform/profiler/utils.h"
......@@ -52,6 +53,14 @@ bool Profiler::IsCuptiSupported() {
return supported;
}
bool Profiler::IsCnpapiSupported() {
bool supported = false;
#ifdef PADDLE_WITH_MLU
supported = true;
#endif
return supported;
}
Profiler::Profiler(const ProfilerOptions& options) {
options_ = options;
std::bitset<32> trace_switch(options_.trace_switch);
......@@ -63,6 +72,9 @@ Profiler::Profiler(const ProfilerOptions& options) {
if (trace_switch.test(kProfileGPUOptionBit)) {
tracers_.emplace_back(&CudaTracer::GetInstance(), false);
}
if (trace_switch.test(kProfileMLUOptionBit)) {
tracers_.emplace_back(&MluTracer::GetInstance(), false);
}
}
Profiler::~Profiler() { alive_.store(false); }
......
......@@ -33,9 +33,10 @@ namespace platform {
static constexpr uint32_t kProfileCPUOptionBit = 0;
static constexpr uint32_t kProfileGPUOptionBit = 1;
static constexpr uint32_t kProfileMLUOptionBit = 2;
struct ProfilerOptions {
-uint32_t trace_switch = 0; // bit 0: cpu, bit 1: gpu
+uint32_t trace_switch = 0; // bit 0: cpu, bit 1: gpu, bit 2: mlu
uint32_t trace_level = FLAGS_host_trace_level;
};
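A caller opts into MLU tracing through bit 2 of trace_switch, e.g. (illustrative fragment, not part of the patch):

ProfilerOptions options;
options.trace_switch = (1u << kProfileCPUOptionBit) |
                       (1u << kProfileGPUOptionBit) |
                       (1u << kProfileMLUOptionBit);  // cpu + gpu + mlu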
......@@ -45,6 +46,8 @@ class Profiler {
static bool IsCuptiSupported();
static bool IsCnpapiSupported();
void Prepare();
void Start();
......
......@@ -50,6 +50,8 @@ enum class TracerEventType {
PythonOp = 13,
// Used to mark python level userdefined
PythonUserDefined = 14,
// Used to mark mlu runtime record returned by cnpapi
MluRuntime = 15,
// A flag to denote the number of current types
NumTypes
};
......
......@@ -52,6 +52,13 @@ class TraceEventCollector {
return thread_names_;
}
void ClearAll() {
thread_names_.clear();
host_events_.clear();
runtime_events_.clear();
device_events_.clear();
}
private:
std::unordered_map<uint64_t, std::string> thread_names_;
std::list<HostTraceEvent> host_events_;
......
......@@ -34,6 +34,10 @@ limitations under the License. */
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif
namespace paddle {
namespace platform {
......@@ -132,6 +136,13 @@ void SynchronizeAllDevice() {
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
}
#endif
#ifdef PADDLE_WITH_MLU
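// Drain every MLU so device-side events are complete before results are
// collected.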
int count = GetMLUDeviceCount();
for (int i = 0; i < count; i++) {
SetMLUDeviceId(i);
PADDLE_ENFORCE_MLU_SUCCESS(cnrtSyncDevice());
}
#endif
}
// Print results
......
......@@ -80,10 +80,8 @@ void StreamCallbackManager<Stream>::AddCallback(
#endif
#if PADDLE_WITH_MLU
VLOG(3) << "MLULaunchCallback at stream: " << stream_
<< " Failed to call MLULaunchCallback, "
<< "because mlu not support StreamAddCallback yet. "
<< "function: " << func;
VLOG(3) << "MLULaunchCallback at stream: " << stream_;
cnrtInvokeHostFunc(stream_, StreamCallbackFunc, func);
#endif
}
......
......@@ -3342,6 +3342,8 @@ All parameter, weight, gradient are variables in Paddle.
.def("create", &paddle::platform::Profiler::Create,
py::return_value_policy::take_ownership)
.def("is_cupti_supported", &paddle::platform::Profiler::IsCuptiSupported)
.def("is_cnpapi_supported",
&paddle::platform::Profiler::IsCnpapiSupported)
.def("prepare",
[](paddle::platform::Profiler *profiler) {
platform::EnableHostEventRecorder();
......
......@@ -47,6 +47,7 @@ enum class Backend : uint8_t {
GPU,
XPU, // XPU currently does not exist at the same time as CUDA
NPU, // NPU currently does not exist at the same time as CUDA
MLU, // MLU currently does not exist at the same time as CUDA
// the third library backend
MKLDNN,
......@@ -114,6 +115,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
case Backend::NPU:
os << "NPU";
break;
case Backend::MLU:
os << "MLU";
break;
case Backend::MKLDNN:
os << "MKLDNN";
break;
......@@ -154,6 +158,8 @@ inline Backend StringToBackend(const char* backend_cstr) {
return Backend::XPU;
} else if (s == std::string("NPU")) {
return Backend::NPU;
} else if (s == std::string("MLU")) {
return Backend::MLU;
} else if (s == std::string("MKLDNN")) {
return Backend::MKLDNN;
} else if (s == std::string("GPUDNN")) {
......
......@@ -41,6 +41,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
return Backend::NPU;
case AllocationType::IPU:
return Backend::IPU;
case AllocationType::MLU:
return Backend::MLU;
case AllocationType::CUSTOM:
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
......
......@@ -273,7 +273,8 @@ def monkey_patch_varbase():
if _grad_scalar:
# When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
self = _grad_scalar.scale(self)
-if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
+if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu(
+) or paddle.is_compiled_with_mlu():
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self)
if framework._in_eager_mode_:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append('..')
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
from paddle.framework import core
from paddle.fluid.dygraph.base import switch_to_static_graph
paddle.enable_static()
def gather_numpy(x, index, axis):
x_transpose = np.swapaxes(x, 0, axis)
tmp_gather = x_transpose[index, ...]
gather = np.swapaxes(tmp_gather, 0, axis)
return gather
class TestGatherOp(OpTest):
def setUp(self):
self.op_type = "gather"
self.place = paddle.MLUPlace(0)
self.__class__.use_mlu = True
self.python_api = paddle.gather
self.config()
xnp = np.random.random(self.x_shape).astype(self.x_type)
self.inputs = {
'X': xnp,
'Index': np.array(self.index).astype(self.index_type)
}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def config(self):
"""
For multi-dimension input
"""
self.x_shape = (10, 20)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int32"
class TestCase1(TestGatherOp):
def config(self):
"""
For one dimension input
"""
self.x_shape = (100)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int32"
class TestCase2(TestGatherOp):
def config(self):
"""
For int64_t index type
"""
self.x_shape = (100)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int64"
class API_TestDygraphGather(unittest.TestCase):
def test_out1(self):
paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32')
index_1 = np.array([1, 2])
input = paddle.to_tensor(input_1)
index = paddle.to_tensor(index_1)
output = paddle.fluid.layers.gather(input, index)
output_np = output.numpy()
expected_output = np.array([[3, 4], [5, 6]]).astype('int32')
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
def test_out12(self):
paddle.disable_static()
input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32')
index_1 = np.array([1, 2])
x = paddle.to_tensor(input_1)
index = paddle.to_tensor(index_1)
output = paddle.gather(x, index, axis=0)
output_np = output.numpy()
expected_output = gather_numpy(input_1, index_1, axis=0)
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
def test_zero_index(self):
paddle.disable_static()
x = paddle.to_tensor([[1, 2], [3, 4]]).astype('int32')
index = paddle.to_tensor(np.array([]).astype('int64'))
for axis in range(len(x.shape)):
out = paddle.gather(x, index, axis)
expected_shape = list(x.shape)
expected_shape[axis] = 0
self.assertEqual(list(out.shape), expected_shape)
paddle.enable_static()
class TestGathertError(unittest.TestCase):
def test_error1(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='int8', name='x')
axis = paddle.fluid.data(shape=[1], dtype='float32', name='axis')
index = paddle.fluid.data(shape=shape, dtype='int32', name='index')
index_float = paddle.fluid.data(
shape=shape, dtype='float32', name='index_float')
def test_x_type():
paddle.gather(x, index)
self.assertRaises(TypeError, test_x_type)
def test_index_type():
paddle.gather(x, index_float)
self.assertRaises(TypeError, test_index_type)
def test_axis_dtype():
paddle.gather(x, index, axis=1.11)
self.assertRaises(TypeError, test_axis_dtype)
def test_axis_dtype1():
paddle.gather(x, index, axis=axis)
self.assertRaises(TypeError, test_axis_dtype1)
def test_error2(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
shape = [8, 9, 6]
x = fluid.data(shape=shape, dtype='int8', name='x')
index = fluid.data(shape=shape, dtype='int32', name='mask')
index_float = fluid.data(
shape=shape, dtype='float32', name='index_float')
def test_x_type():
paddle.fluid.layers.gather(x, index)
self.assertRaises(TypeError, test_x_type)
def test_index_type():
paddle.fluid.layers.gather(x, index_float)
self.assertRaises(TypeError, test_index_type)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
from scipy import special
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
def np_gelu(x):
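# Exact (erf-based) GELU reference: 0.5 * x * (1 + erf(x / sqrt(2)))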
y = 0.5 * x * (1 + special.erf(x / np.sqrt(2)))
return y
class TestGelu(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "gelu"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np_gelu(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.007)
class TestGeluFp16(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "gelu"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np_gelu(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestGeluNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
c = paddle.multiply(a, b)
fc_1 = fluid.layers.fc(input=c, size=128)
fc_1_gelu = fluid.layers.gelu(fc_1)
prediction = fluid.layers.fc(input=fc_1_gelu, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred, atol=1e-3))
self.assertTrue(np.allclose(mlu_loss, cpu_loss, atol=1e-3))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from test_activation_op import ref_leaky_relu
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
class TestLeadyRelu(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "leaky_relu"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
self.set_inputs()
self.set_attrs()
self.set_outputs()
def set_inputs(self):
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
def set_attrs(self):
self.attrs = {}
def set_outputs(self):
alpha = 0.02 if 'alpha' not in self.attrs else self.attrs['alpha']
out = ref_leaky_relu(self.inputs['X'], alpha)
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.006)
else:
self.check_grad_with_place(self.place, ['X'], 'Out')
class TestLeadyReluFP16(TestLeadyRelu):
def init_dtype(self):
self.dtype = np.float16
class TestLeadyRelu2(TestLeadyRelu):
def set_attrs(self):
self.attrs = {'alpha': 0.5}
class TestLeadyRelu3(TestLeadyRelu):
def set_attrs(self):
self.attrs = {'alpha': -0.5}
class TestLeakyReluNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
x_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
x = paddle.static.data(name="x", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
y = paddle.nn.functional.leaky_relu(x)
fc_1 = fluid.layers.fc(input=y, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(main_prog,
feed={"x": x_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle
from op_test import OpTest
import numpy as np
import unittest
import sys
sys.path.append("..")
paddle.enable_static()
SEED = 2021
def ref_relu6(x, threshold=6.0):
# Reference implementation of relu6: clip x into [0, threshold].
out = np.minimum(np.maximum(x, 0), threshold)
return out
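# A quick sanity check of the reference above (illustrative values, not part
# of the original test data):
#     ref_relu6(np.array([-1.0, 3.0, 7.0]))  # -> array([0., 3., 6.])
#     i.e. negatives clip to 0 and values above the threshold clip to 6.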
class TestRelu6(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "relu6"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype)
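# Nudge entries away from 0, where the numeric gradient check of relu6 would be unstable (assumed intent).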
x[np.abs(x) < 0.005] = 0.02
out = ref_relu6(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': 6.0}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_dtype(self):
self.dtype = np.float32
class TestRelu6Float16(TestRelu6):
def set_mlu(self):
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestReluNeg(TestRelu6):
def setUp(self):
self.set_mlu()
self.op_type = "relu6"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(-10, -1, [10, 12]).astype(self.dtype)
x[np.abs(x) < 0.005] = 0.02
out = ref_relu6(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'threshold': 6.0}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
class TestRelu6Net(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.nn.functional.relu6(sum)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
from paddle.fluid.tests.unittests.op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
class TestMLUSigmoid(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.set_mlu()
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = 1 / (1 + np.exp(-x))
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.01)
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
class TestMLUSigmoidFp16(TestMLUSigmoid):
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
class TestTanh(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "tanh"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.009)
class TestTanhFp16(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "tanh"
self.place = paddle.MLUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-3)
class TestTanhNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
c = paddle.multiply(a, b)
d = paddle.tanh(c)
fc_1 = fluid.layers.fc(input=d, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
......@@ -53,16 +53,19 @@ class ProfilerState(Enum):
class ProfilerTarget(Enum):
r"""
ProfilerTarget is used to specify target device for :ref:`profiling <api_paddle_profiler_Profiler>` . Only CPU and GPU are supported currently.
ProfilerTarget is used to specify target device for :ref:`profiling <api_paddle_profiler_Profiler>` . Only CPU, GPU and MLU are supported currently.
The meaning of each ProfilerTarget is as follows
- **ProfilerTarget.CPU** : Profile events on CPU.
- **ProfilerTarget.GPU** : Profile events on GPU.
- **ProfilerTarget.MLU** : Profile events on MLU.
"""
CPU = 0
GPU = 1
MLU = 2
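# A minimal usage sketch (assumes an MLU build of Paddle; `train_one_step`
# is a hypothetical user function):
#
#     import paddle.profiler as profiler
#     with profiler.Profiler(targets=[profiler.ProfilerTarget.CPU,
#                                     profiler.ProfilerTarget.MLU]) as p:
#         for step in range(10):
#             train_one_step()
#             p.step()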
def make_scheduler(*,
......@@ -259,6 +262,8 @@ def _get_supported_targets() -> Iterable[ProfilerTarget]:
"""
if _Profiler.is_cupti_supported():
return [ProfilerTarget.CPU, ProfilerTarget.GPU]
if _Profiler.is_cnpapi_supported():
return [ProfilerTarget.CPU, ProfilerTarget.MLU]
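# note: the CUPTI branch above takes precedence, so MLU targets are only
# reported when CUPTI (GPU) support is unavailable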
return [ProfilerTarget.CPU]
......@@ -267,7 +272,7 @@ class Profiler:
Profiler context manager, the user interface for managing the profiling process: starting, stopping, exporting profiling data, and printing a summary table.
Args:
targets (list, optional): specify target devices to profile, and all existing and supported devices will be chosen by default. Currently supported values, :ref:`ProfilerTarget.CPU <api_paddle_profiler_ProfilerTarget>` and :ref:`ProfilerTarget.GPU <api_paddle_profiler_ProfilerTarget>` .
targets (list, optional): specify target devices to profile; by default, all existing and supported devices are chosen. The currently supported values are :ref:`ProfilerTarget.CPU <api_paddle_profiler_ProfilerTarget>` , :ref:`ProfilerTarget.GPU <api_paddle_profiler_ProfilerTarget>` and :ref:`ProfilerTarget.MLU <api_paddle_profiler_ProfilerTarget>` .
scheduler (Callable|tuple, optional): If it is a callable object, it takes a step number as its parameter and returns the corresponding :ref:`ProfilerState <api_paddle_profiler_ProfilerState>`. This callable object can be generated by the :ref:`make_scheduler <api_paddle_profiler_make_scheduler>` function.
If not provided (None), the default scheduler keeps tracing until the profiler exits. If it is a tuple, its two values start_batch and end_batch
define the profiling range [start_batch, end_batch).
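A tuple scheduler therefore profiles a fixed batch window; a minimal sketch (the batch bounds below are illustrative):

import paddle.profiler as profiler

prof = profiler.Profiler(
    targets=[profiler.ProfilerTarget.CPU],
    scheduler=(2, 5))  # profile batches [2, 5)
prof.start()
for batch in range(10):
    # one training iteration would run here
    prof.step()
prof.stop()
prof.summary()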
......@@ -408,6 +413,8 @@ class Profiler:
profileoption.trace_switch |= 1
if ProfilerTarget.GPU in self.targets:
profileoption.trace_switch |= (1 << 1)
if ProfilerTarget.MLU in self.targets:
profileoption.trace_switch |= (1 << 2)
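# trace_switch bit layout, inferred from the flags above: bit 0 = CPU, bit 1 = GPU, bit 2 = MLU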
wrap_optimizers()
self.profiler = _Profiler.create(profileoption)
if callable(scheduler):
......
......@@ -105,9 +105,9 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
place = _current_expected_place()
elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace,
core.CUDAPlace, core.NPUPlace, core.XPUPlace,
core.CustomPlace)):
core.MLUPlace, core.CustomPlace)):
raise ValueError(
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace"
"'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace"
)
if not isinstance(data, np.ndarray):
......
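With MLUPlace accepted by the check above, creating a tensor directly on an MLU device is straightforward (a sketch; it assumes a Paddle build with MLU support):

import paddle

t = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.MLUPlace(0))
print(t.place)  # expected to report the MLU place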
......@@ -2,9 +2,9 @@
# Update CNTOOLKIT_VERSION, CNNL_VERSION and CNCL_VERSION if using other versions
#
# Build:
# - CNTOOLKIT_VERSION 2.6.5-1
# - CNNL_VERSION 1.8.3-1
# - CNCL_VERSION 1.0.2-1
# - CNTOOLKIT_VERSION 2.8.1-1
# - CNNL_VERSION 1.9.3-1
# - CNCL_VERSION 1.0.4-1
#
# Download three packages from FTP (contact Cambricon AE to get the FTP url)
# - cntoolkit_2.6.5-1.ubuntu18.04_amd64.deb
......@@ -21,9 +21,9 @@
# (get cncl pkg)
#
# docker build -f Dockerfile.mlu \
# --build-arg CNTOOLKIT_VERSION=2.6.5-1 \
# --build-arg CNNL_VERSION=1.8.3-1 \
# --build-arg CNCL_VERSION=1.0.2-1 \
# --build-arg CNTOOLKIT_VERSION=2.8.1-1 \
# --build-arg CNNL_VERSION=1.9.3-1 \
# --build-arg CNCL_VERSION=1.0.4-1 \
# -t paddlepaddle/paddle:latest-dev-mlu .
#
# without mlu device:
......@@ -40,9 +40,9 @@ MAINTAINER PaddlePaddle Authors <paddle-dev@baidu.com>
ENV WITH_GPU=OFF
ARG CNTOOLKIT_VERSION=2.6.5-1
ARG CNNL_VERSION=1.8.3-1
ARG CNCL_VERSION=1.0.2-1
ARG CNTOOLKIT_VERSION=2.8.1-1
ARG CNNL_VERSION=1.9.3-1
ARG CNCL_VERSION=1.0.4-1
ARG CNTOOLKIT_PKG=cntoolkit_$CNTOOLKIT_VERSION.ubuntu18.04_amd64.deb
ARG CNNL_PKG=cnnl_$CNNL_VERSION.ubuntu18.04_amd64.deb
ARG CNCL_PKG=cncl_$CNCL_VERSION.ubuntu18.04_amd64.deb
......