提交 03b4de2e 编写于 作者: Z zhanyuan

Support broadcast for Power OP

上级 f0cfc42e
......@@ -207,6 +207,7 @@ class Power : public Primitive {
public:
explicit Power(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::Power *GetAttribute() const { return this->primitive->value_as_Power(); }
int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
};
class Range : public Primitive {
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace mindspore::lite {
// Infers the output tensor's format/shape/dtype for the Power op.
// inputs: [0] base tensor, [1] exponent tensor (same shape as base, or
// broadcastable); outputs: [0] result tensor, which mirrors the base.
// Returns RET_OK on success, RET_INPUT_TENSOR_ERROR on invalid inputs.
int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
  MS_ASSERT(this->primitive != nullptr);
  // Bug fix: the original dereferenced inputs[0]/inputs[1]/outputs[0] before
  // validating the container sizes, which is undefined behavior when fewer
  // tensors are supplied. Validate sizes first.
  if (inputs.size() < 2) {
    MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
    return RET_INPUT_TENSOR_ERROR;
  }
  if (outputs.empty()) {
    MS_LOG(ERROR) << "output size" << outputs.size() << " is error!";
    return RET_INPUT_TENSOR_ERROR;
  }
  auto x_tensor = inputs[0];
  MS_ASSERT(x_tensor != nullptr);
  auto exp_tensor = inputs[1];
  MS_ASSERT(exp_tensor != nullptr);
  auto output_tensor = outputs[0];
  MS_ASSERT(output_tensor != nullptr);
  // Exponent must either match the base shape exactly or be 1-D (broadcast).
  // NOTE(review): shape().size() != 1 accepts any 1-D exponent, while the CPU
  // kernel broadcasts only when ElementsNum() == 1 — confirm intended contract.
  if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) {
    MS_LOG(ERROR) << "Power inputs shape is not equal!";
    return RET_INPUT_TENSOR_ERROR;
  }
  int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1, std::multiplies<int>());
  // A dtype mismatch is tolerated only for a single-element (scalar) exponent.
  if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) {
    MS_LOG(ERROR) << "Exponent tensor's shape is wrong";
    return RET_INPUT_TENSOR_ERROR;
  }
  // Output keeps the base tensor's layout, shape and dtype (broadcasting the
  // exponent never changes the result shape).
  output_tensor->SetFormat(x_tensor->GetFormat());
  output_tensor->set_shape(x_tensor->shape());
  output_tensor->set_data_type(x_tensor->data_type());
  return RET_OK;
}
} // namespace mindspore::lite
......@@ -50,13 +50,20 @@ int PowerCPUKernel::Run() {
}
int PowerCPUKernel::RunImpl(int task_id) {
auto input_addr = reinterpret_cast<float *>(inputs_.at(0)->Data());
auto output_addr = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto size = inputs_.at(0)->Size();
auto x_addr = reinterpret_cast<float *>(inputs_[0]->Data());
auto exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
auto output_addr = reinterpret_cast<float *>(outputs_[0]->Data());
auto size = inputs_[0]->ElementsNum();
int stride = UP_DIV(size, thread_count_);
int len = MSMIN(stride, size - stride * task_id);
Power(input_addr + stride * task_id, output_addr + stride * task_id, len, power_, scale_, shift_);
bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false;
float *cur_exp;
if (broadcast) {
cur_exp = exp_addr;
} else {
cur_exp = exp_addr + stride * task_id;
}
Power(x_addr + stride * task_id, cur_exp, output_addr + stride * task_id, len, scale_, shift_, broadcast);
return RET_OK;
}
......@@ -67,7 +74,7 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Power);
auto *kernel =
new (std::nothrow) PowerCPUKernel(reinterpret_cast<PowerParameter *>(opParameter), inputs, outputs, ctx);
new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PowerCPUKernel fail!";
return nullptr;
......
......@@ -18,20 +18,20 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_POWER_H_
#include <vector>
#include "include/context.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/power.h"
namespace mindspore::kernel {
class PowerCPUKernel : public LiteKernel {
public:
PowerCPUKernel(PowerParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
PowerCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
: LiteKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs),
: LiteKernel(param, inputs, outputs),
ctx_(ctx),
thread_count_(ctx->thread_num_),
power_(param->power_),
scale_(param->scale_),
shift_(param->shift_) {}
scale_(reinterpret_cast<PowerParameter *>(opParameter)->scale_),
shift_(reinterpret_cast<PowerParameter *>(opParameter)->shift_) {}
~PowerCPUKernel() override = default;
int Init() override;
......@@ -40,8 +40,8 @@ class PowerCPUKernel : public LiteKernel {
int RunImpl(int task_id);
private:
const lite::Context *ctx_;
int thread_count_;
float power_;
float scale_;
float shift_;
};
......
......@@ -36,7 +36,8 @@ int PowerGradCPUKernel::Run() {
auto dx_addr = reinterpret_cast<float *>(outputs_.at(0)->Data());
auto size = inputs_.at(0)->ElementsNum();
Power(x_addr, dx_addr, size, power_ - 1, scale_, shift_);
float exp = power_ - 1;
Power(x_addr, &exp, dx_addr, size, scale_, shift_, true);
ElementMul(dx_addr, dy_addr, dx_addr, size);
float scale = scale_ * power_;
for (int i = 0; i < size; i++) {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/power.h"
// Returns true when f has no fractional part, i.e. the fast integer
// exponentiation path can be used.
bool CheckInteger(float f) { return floorf(f) == f; }

// Exponentiation by squaring for integer exponents (O(log |exponent|)).
// Negative exponents are handled by inverting the positive-exponent result.
float OptimizedPowerImpl(float x, int exponent) {
  int exp = abs(exponent);
  float result = 1;
  float iterator = x;
  while (exp) {
    if (exp % 2) {
      result *= iterator;
    }
    iterator *= iterator;
    exp = exp / 2;
  }
  return exponent >= 0 ? result : 1 / result;
}

// General fallback for non-integer exponents.
float StdPowerImpl(float x, float exponent) { return pow(x, exponent); }

// Computes output[i] = (scale * input[i] + shift) ^ e_i for len elements.
// broadcast == true: a single exponent (*exponent) applies to every element.
// broadcast == false: exponent is an array parallel to input.
void Power(const float *input, const float *exponent, float *output, int len, float scale, float shift,
           bool broadcast) {
  if (broadcast) {
    // The integer-exponent test is loop-invariant here; hoist it (and the
    // float->int truncation) out of the loop.
    if (CheckInteger(*exponent)) {
      int int_exp = (int)(*exponent);
      for (int i = 0; i < len; ++i) {
        output[i] = OptimizedPowerImpl(scale * input[i] + shift, int_exp);
      }
    } else {
      for (int i = 0; i < len; ++i) {
        output[i] = StdPowerImpl(scale * input[i] + shift, *exponent);
      }
    }
  } else {
    for (int i = 0; i < len; ++i) {
      // Bug fix: the original tested CheckInteger(*exponent) (element 0 only)
      // yet truncated exponent[i] to int, silently corrupting results whenever
      // a later exponent was non-integer. Test each element's own exponent.
      if (CheckInteger(exponent[i])) {
        output[i] = OptimizedPowerImpl(scale * input[i] + shift, (int)exponent[i]);
      } else {
        output[i] = StdPowerImpl(scale * input[i] + shift, exponent[i]);
      }
    }
  }
}
......@@ -26,11 +26,6 @@ struct PowerParameter {
float shift_;
};
// Element-wise power with affine pre-transform:
// output_data[i] = (scale * input_data[i] + shift) ^ power.
inline void Power(const float *input_data, float *output_data, int len, float power, float scale, float shift) {
  for (int idx = 0; idx < len; ++idx) {
    const float base = scale * input_data[idx] + shift;
    output_data[idx] = pow(base, power);
  }
}
void Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_POWER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/power.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
namespace mindspore {
class TestPowerFp32 : public mindspore::Common {
public:
TestPowerFp32() {}
};
// Builds the base/exponent input tensors and the output tensor for a Power
// kernel test. a_ptr/b_ptr are copied into freshly allocated tensor buffers;
// ownership of every created tensor passes to the caller through
// inputs_/outputs_. Returns the element count of the output tensor.
int PowerTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
                  float *a_ptr, float *b_ptr, std::vector<int> a_shape, std::vector<int> b_shape,
                  std::vector<int> c_shape) {
  // All tensors share dtype/format/node-type; only the shape differs.
  auto make_tensor = [](const std::vector<int> &shape) {
    return new lite::tensor::Tensor(kNumberTypeFloat, shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  };

  auto base_tensor = make_tensor(a_shape);
  base_tensor->MallocData();
  memcpy(base_tensor->Data(), a_ptr, sizeof(float) * base_tensor->ElementsNum());
  inputs_->push_back(base_tensor);

  auto exp_tensor = make_tensor(b_shape);
  exp_tensor->MallocData();
  memcpy(exp_tensor->Data(), b_ptr, sizeof(float) * exp_tensor->ElementsNum());
  inputs_->push_back(exp_tensor);

  auto result_tensor = make_tensor(c_shape);
  result_tensor->MallocData();
  outputs_->push_back(result_tensor);
  return result_tensor->ElementsNum();
}
// Element-wise power: each output element uses its own exponent from the
// second input tensor, single-threaded.
TEST_F(TestPowerFp32, Simple) {
  std::vector<lite::tensor::Tensor *> inputs_;
  std::vector<lite::tensor::Tensor *> outputs_;
  auto param = new PowerParameter();
  param->scale_ = 1;
  param->shift_ = 0;
  float a[] = {1, 2, 3, 4};
  float b[] = {5, 6, 7, 8};
  std::vector<int> a_shape = {2, 2};
  std::vector<int> b_shape = {2, 2};
  std::vector<int> c_shape = {2, 2};
  int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
  auto ctx = new lite::Context;
  ctx->thread_num_ = 1;
  auto op = new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx);
  op->Init();
  op->Run();
  float correct[] = {1, 64, 2187, 65536};
  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
  delete op;
  // Fix: removed the stray debug printf loop and freed `ctx`, which the
  // original leaked (the kernel only borrows it, so freeing after `delete op`
  // is safe). Whether LiteKernel owns and frees `param` is not visible here —
  // TODO confirm before adding `delete param`.
  delete ctx;
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
}
// Broadcast power: a single scalar exponent applies to every element,
// exercised with two threads to cover the per-task stride path.
TEST_F(TestPowerFp32, Broadcast) {
  std::vector<lite::tensor::Tensor *> inputs_;
  std::vector<lite::tensor::Tensor *> outputs_;
  auto param = new PowerParameter();
  param->scale_ = 1;
  param->shift_ = 0;
  float a[] = {1, 2, 3, 4};
  float b[] = {2};
  std::vector<int> a_shape = {2, 2};
  std::vector<int> b_shape = {1};
  std::vector<int> c_shape = {2, 2};
  int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
  auto ctx = new lite::Context;
  ctx->thread_num_ = 2;
  auto op = new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx);
  op->Init();
  op->Run();
  float correct[] = {1, 4, 9, 16};
  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
  delete op;
  // Fix: freed `ctx`, which the original leaked; safe once the kernel is
  // destroyed. Ownership of `param` by LiteKernel is not visible here —
  // TODO confirm before freeing it as well.
  delete ctx;
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册