diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 175fa9f94470f86aa75ef63dde043edb0e705b20..95a365f459f18033b9712ed156efe9ef5e6a9faf 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -4274,21 +4274,12 @@ MLURNNDesc::~MLURNNDesc() { /* static */ void MLUCnnl::NumTrue(const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, const void* x, - Tensor index, - uint32_t* num_true) { + const cnnlTensorDescriptor_t num_true_desc, + void* num_true) { cnnlHandle_t handle = GetHandleFromCTX(ctx); - size_t workspace_size = 0; PADDLE_ENFORCE_MLU_SUCCESS( - cnnlGetNumTrueWorkspaceSize(handle, x_desc, &workspace_size)); - - auto& dev_ctx = GetDevCtxFromCTX(ctx); - index = ctx.AllocateTmpTensor( - {static_cast(workspace_size)}, dev_ctx); - void* index_ptr = index.mutable_data(ctx.GetPlace()); - - PADDLE_ENFORCE_MLU_SUCCESS(cnnlNumTrue( - handle, x_desc, x, static_cast(index_ptr), num_true)); + cnnlNumTrue_v2(handle, x_desc, x, num_true_desc, num_true)); } /* static */ void MLUCnnl::Where(const ExecutionContext& ctx, diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 0d4c7d2e5a3297ec0b17ac67ba55ef52c62cac84..72446f56a18dc89c4d0abdd4c21532431969e4a6 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -1703,8 +1703,8 @@ class MLUCnnl { static void NumTrue(const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, const void* x, - Tensor index, - uint32_t* num_true); + const cnnlTensorDescriptor_t num_true_desc, + void* num_true); static void Where(const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, diff --git a/paddle/fluid/operators/randperm_op_mlu.cc b/paddle/fluid/operators/randperm_op_mlu.cc index 0d4fbf2d12f7cfaa2933ef998478385dd488676c..a3ebf8f5c00fce76c9c2a623f77b4eba7e2be76c 100644 --- a/paddle/fluid/operators/randperm_op_mlu.cc +++ b/paddle/fluid/operators/randperm_op_mlu.cc @@ -15,9 +15,32 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/randperm_op.h" +namespace paddle { +namespace operators { + +template +class RandpermMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int n = ctx.Attr("n"); + unsigned int seed = static_cast(ctx.Attr("seed")); + framework::Variable* out_var = ctx.OutputVar("Out"); + framework::Tensor* out_tensor = + framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var); + + framework::Tensor tmp_tensor; + tmp_tensor.Resize(phi::make_ddim({n})); + T* tmp_data = tmp_tensor.mutable_data(platform::CPUPlace()); + random_permate(tmp_data, n, seed); + framework::TensorCopySync(tmp_tensor, ctx.GetPlace(), out_tensor); + } +}; + +} // namespace operators +} // namespace paddle + template -using kernel = - paddle::operators::RandpermKernel; +using kernel = paddle::operators::RandpermMLUKernel; REGISTER_OP_MLU_KERNEL( randperm, kernel, kernel, kernel, kernel); diff --git a/paddle/fluid/operators/where_index_op_mlu.cc b/paddle/fluid/operators/where_index_op_mlu.cc index d0699521aa46e41a120d1f2347516d93322285f4..389f7960bcdc1595ab9a0da40c0f0ae3f7049b42 100644 --- a/paddle/fluid/operators/where_index_op_mlu.cc +++ b/paddle/fluid/operators/where_index_op_mlu.cc @@ -30,30 +30,36 @@ class MLUWhereIndexKernel : public framework::OpKernel { auto* out = context.Output("Out"); auto dims = condition->dims(); const int rank = dims.size(); - std::vector true_num = {0}; - std::vector vec_condition; - paddle::framework::TensorToVector( - *condition, context.device_context(), &vec_condition); - int vec_con_size = vec_condition.size(); - for (int i = 0; i < vec_con_size; ++i) { - if (vec_condition[i] > 0) true_num[0]++; - } - out->Resize(phi::make_ddim({true_num[0], rank})); + Tensor num_true; + num_true.mutable_data({1}, context.GetPlace()); + MLUCnnlTensorDesc con_desc(*condition); + MLUCnnlTensorDesc num_true_desc(num_true); + MLUCnnl::NumTrue(context, + con_desc.get(), + GetBasePtr(condition), + num_true_desc.get(), + GetBasePtr(&num_true)); + + Tensor local_true_num; + paddle::framework::TensorCopySync( + num_true, platform::CPUPlace(), &local_true_num); + auto true_num = *local_true_num.data(); + + out->Resize(phi::make_ddim({true_num, rank})); out->mutable_data(context.GetPlace()); + + if (true_num == 0) { + return; + } + auto& dev_ctx = context.template device_context(); framework::Tensor out_int32 = context.AllocateTmpTensor(out->dims(), dev_ctx); - Tensor num_true; - paddle::framework::TensorFromVector( - true_num, context.device_context(), &num_true); - num_true.mutable_data(context.GetPlace()); - bool as_tuple = false; - MLUCnnlTensorDesc con_desc(*condition); - MLUCnnlTensorDesc num_true_desc(num_true); MLUCnnlTensorDesc out_int32_desc(out_int32); MLUCnnlTensorDesc out_desc(*out); + bool as_tuple = false; MLUCnnl::Where(context, con_desc.get(), GetBasePtr(condition),