Unverified commit 5061d3db, authored by fwenguang, committed by GitHub

[MLU] fix sync copy bugs (#44127)

Parent d55ee95f
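
Summary, inferred from the hunks below (the commit itself carries no description): the MLU randperm and where_index paths moved data between host and device either by copying whole tensors or, judging by the title, by copies that were not synchronized with the short-lived host buffers they touched. The fix switches these spots to the blocking framework::TensorCopySync, transfers only what the host actually needs, and reworks MLUCnnl::NumTrue around cnnlNumTrue_v2 so the true-element count is produced directly into a device tensor. The core of the pattern, condensed from the where_index hunk at the bottom of the diff:

    // The count produced on the MLU is copied back as a single int, synchronously,
    // before the host uses it to resize the output tensor.
    Tensor local_true_num;
    paddle::framework::TensorCopySync(
        num_true, platform::CPUPlace(), &local_true_num);
    auto true_num = *local_true_num.data<int>();
    out->Resize(phi::make_ddim({true_num, rank}));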
......
@@ -4274,21 +4274,12 @@ MLURNNDesc::~MLURNNDesc() {
 /* static */ void MLUCnnl::NumTrue(const ExecutionContext& ctx,
                                    const cnnlTensorDescriptor_t x_desc,
                                    const void* x,
-                                   Tensor index,
-                                   uint32_t* num_true) {
+                                   const cnnlTensorDescriptor_t num_true_desc,
+                                   void* num_true) {
   cnnlHandle_t handle = GetHandleFromCTX(ctx);
-  size_t workspace_size = 0;
   PADDLE_ENFORCE_MLU_SUCCESS(
-      cnnlGetNumTrueWorkspaceSize(handle, x_desc, &workspace_size));
-  auto& dev_ctx = GetDevCtxFromCTX(ctx);
-  index = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
-      {static_cast<int64_t>(workspace_size)}, dev_ctx);
-  void* index_ptr = index.mutable_data(ctx.GetPlace());
-  PADDLE_ENFORCE_MLU_SUCCESS(cnnlNumTrue(
-      handle, x_desc, x, static_cast<uint32_t*>(index_ptr), num_true));
+      cnnlNumTrue_v2(handle, x_desc, x, num_true_desc, num_true));
 }
 
 /* static */ void MLUCnnl::Where(const ExecutionContext& ctx,
......
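
Note on the hunk above: the old MLUCnnl::NumTrue queried cnnlGetNumTrueWorkspaceSize, allocated the workspace through the by-value index tensor, and wrote the count through a raw uint32_t*; the new version simply forwards an output tensor descriptor and pointer to cnnlNumTrue_v2, so the count lands in a device tensor that callers copy back only when they need it. Caller-side usage as it appears in the where_index hunk further down, repeated here only to make the new signature concrete:

    Tensor num_true;
    num_true.mutable_data<int>({1}, context.GetPlace());  // one-element int tensor on the device
    MLUCnnlTensorDesc con_desc(*condition);
    MLUCnnlTensorDesc num_true_desc(num_true);
    MLUCnnl::NumTrue(context,
                     con_desc.get(),
                     GetBasePtr(condition),
                     num_true_desc.get(),
                     GetBasePtr(&num_true));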
......
@@ -1703,8 +1703,8 @@ class MLUCnnl {
   static void NumTrue(const ExecutionContext& ctx,
                       const cnnlTensorDescriptor_t x_desc,
                       const void* x,
-                      Tensor index,
-                      uint32_t* num_true);
+                      const cnnlTensorDescriptor_t num_true_desc,
+                      void* num_true);
 
   static void Where(const ExecutionContext& ctx,
                     const cnnlTensorDescriptor_t x_desc,
......
......
@@ -15,9 +15,32 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/randperm_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class RandpermMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    int n = ctx.Attr<int>("n");
+    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
+    framework::Variable* out_var = ctx.OutputVar("Out");
+    framework::Tensor* out_tensor =
+        framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
+
+    framework::Tensor tmp_tensor;
+    tmp_tensor.Resize(phi::make_ddim({n}));
+    T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
+    random_permate<T>(tmp_data, n, seed);
+    framework::TensorCopySync(tmp_tensor, ctx.GetPlace(), out_tensor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 template <typename T>
-using kernel =
-    paddle::operators::RandpermKernel<paddle::platform::MLUDeviceContext, T>;
+using kernel = paddle::operators::RandpermMLUKernel<T>;
 
 REGISTER_OP_MLU_KERNEL(
     randperm, kernel<int64_t>, kernel<int>, kernel<float>, kernel<double>);
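
Note on the hunk above: the MLU build previously reused the generic RandpermKernel<MLUDeviceContext, T>; the new RandpermMLUKernel builds the permutation on the CPU with random_permate<T> and pushes it to the device with framework::TensorCopySync. The blocking copy is presumably the point of the change: an asynchronous copy could still be in flight when the CPU-side tmp_tensor is destroyed at the end of Compute. A minimal sketch of that host-then-sync-copy shape with the operator boilerplate stripped; the helper name FillRandpermOnDevice is hypothetical and not part of the commit:

    #include "paddle/fluid/framework/op_registry.h"
    #include "paddle/fluid/operators/randperm_op.h"

    // Fill a CPU buffer first, then do a blocking host-to-device copy so the
    // temporary can safely be destroyed as soon as the call returns.
    template <typename T>
    void FillRandpermOnDevice(int n, unsigned int seed,
                              const paddle::platform::Place& place,
                              paddle::framework::Tensor* out_tensor) {
      paddle::framework::Tensor tmp_tensor;
      tmp_tensor.Resize(phi::make_ddim({n}));
      T* tmp_data = tmp_tensor.mutable_data<T>(paddle::platform::CPUPlace());
      paddle::operators::random_permate<T>(tmp_data, n, seed);           // host-side permutation
      paddle::framework::TensorCopySync(tmp_tensor, place, out_tensor);  // synchronous H2D copy
    }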
......
@@ -30,30 +30,36 @@ class MLUWhereIndexKernel : public framework::OpKernel<T> {
     auto* out = context.Output<Tensor>("Out");
     auto dims = condition->dims();
     const int rank = dims.size();
 
-    std::vector<int> true_num = {0};
-    std::vector<T> vec_condition;
-    paddle::framework::TensorToVector(
-        *condition, context.device_context(), &vec_condition);
-    int vec_con_size = vec_condition.size();
-    for (int i = 0; i < vec_con_size; ++i) {
-      if (vec_condition[i] > 0) true_num[0]++;
-    }
-    out->Resize(phi::make_ddim({true_num[0], rank}));
+    Tensor num_true;
+    num_true.mutable_data<int>({1}, context.GetPlace());
+    MLUCnnlTensorDesc con_desc(*condition);
+    MLUCnnlTensorDesc num_true_desc(num_true);
+    MLUCnnl::NumTrue(context,
+                     con_desc.get(),
+                     GetBasePtr(condition),
+                     num_true_desc.get(),
+                     GetBasePtr(&num_true));
+
+    Tensor local_true_num;
+    paddle::framework::TensorCopySync(
+        num_true, platform::CPUPlace(), &local_true_num);
+    auto true_num = *local_true_num.data<int>();
+
+    out->Resize(phi::make_ddim({true_num, rank}));
     out->mutable_data<int64_t>(context.GetPlace());
+
+    if (true_num == 0) {
+      return;
+    }
+
     auto& dev_ctx = context.template device_context<MLUDeviceContext>();
     framework::Tensor out_int32 =
         context.AllocateTmpTensor<int32_t, MLUDeviceContext>(out->dims(),
                                                              dev_ctx);
-    Tensor num_true;
-    paddle::framework::TensorFromVector(
-        true_num, context.device_context(), &num_true);
-    num_true.mutable_data<int>(context.GetPlace());
-    bool as_tuple = false;
-    MLUCnnlTensorDesc con_desc(*condition);
-    MLUCnnlTensorDesc num_true_desc(num_true);
     MLUCnnlTensorDesc out_int32_desc(out_int32);
     MLUCnnlTensorDesc out_desc(*out);
+    bool as_tuple = false;
     MLUCnnl::Where(context,
                    con_desc.get(),
                    GetBasePtr(condition),
......
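
Note on the hunk above: the old kernel pulled the entire condition tensor back to the host with TensorToVector and counted the non-zero entries in a CPU loop; the new kernel lets the device count them via MLUCnnl::NumTrue, copies back a single int with the blocking TensorCopySync, and returns early when the count is zero so Where is never called on an empty output. The change in host/device traffic, in short:

    // before: condition (whole tensor)  -> host std::vector<T>  -> CPU loop counts true elements
    // after:  condition stays on device -> MLUCnnl::NumTrue     -> one int copied back via TensorCopySync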