提交 023c9327 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5823 [MSLITE] Support exponent tensor broadcast for power op

Merge pull request !5823 from zhanyuan/dev
...@@ -64,7 +64,9 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:: ...@@ -64,7 +64,9 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_OK; return RET_OK;
} }
if (exp_tensor != nullptr) { if (exp_tensor != nullptr) {
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) { if ((exp_tensor->shape().size() > 1 && exp_tensor->shape() != x_tensor->shape()) ||
(exp_tensor->shape().size() == 1 && exp_tensor->shape()[0] != 1) ||
exp_tensor->data_type() != x_tensor->data_type()) {
MS_LOG(ERROR) << "Power inputs shape or type is not equal!"; MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
return RET_INPUT_TENSOR_ERROR; return RET_INPUT_TENSOR_ERROR;
} }
......
...@@ -64,11 +64,11 @@ int PowerCPUKernel::RunImpl(int task_id) { ...@@ -64,11 +64,11 @@ int PowerCPUKernel::RunImpl(int task_id) {
bool broadcast = true; bool broadcast = true;
if (in_tensors_.size() == 2) { if (in_tensors_.size() == 2) {
exp_addr = reinterpret_cast<float *>(in_tensors_[1]->Data()); exp_addr = reinterpret_cast<float *>(in_tensors_[1]->Data());
broadcast = false; broadcast = in_tensors_[0]->shape() == in_tensors_[1]->shape() ? false : true;
} }
float *cur_exp = nullptr; float *cur_exp = nullptr;
if (broadcast) { if (broadcast) {
cur_exp = &power_; cur_exp = in_tensors_.size() == 2 ? exp_addr : &power_;
} else { } else {
cur_exp = exp_addr + stride * task_id; cur_exp = exp_addr + stride * task_id;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册