提交 023c9327 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5823 [MSLITE] Support exponent tensor broadcast for power op

Merge pull request !5823 from zhanyuan/dev
......@@ -64,7 +64,9 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_OK;
}
if (exp_tensor != nullptr) {
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
if ((exp_tensor->shape().size() > 1 && exp_tensor->shape() != x_tensor->shape()) ||
(exp_tensor->shape().size() == 1 && exp_tensor->shape()[0] != 1) ||
exp_tensor->data_type() != x_tensor->data_type()) {
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
......
......@@ -64,11 +64,11 @@ int PowerCPUKernel::RunImpl(int task_id) {
bool broadcast = true;
if (in_tensors_.size() == 2) {
exp_addr = reinterpret_cast<float *>(in_tensors_[1]->Data());
broadcast = false;
broadcast = in_tensors_[0]->shape() == in_tensors_[1]->shape() ? false : true;
}
float *cur_exp = nullptr;
if (broadcast) {
cur_exp = &power_;
cur_exp = in_tensors_.size() == 2 ? exp_addr : &power_;
} else {
cur_exp = exp_addr + stride * task_id;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册