提交 2601ed30 编写于 作者: S sunsuodong

review

上级 c8f69f5d
......@@ -45,6 +45,7 @@ bool ConcatCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto buff_size = outputs[0]->size;
size_t dim0 = output_shape_[0];
size_t dim1 = output_shape_[1];
size_t dim2 = output_shape_[2];
......@@ -53,28 +54,28 @@ bool ConcatCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
for (size_t k = 0; k < dim2; ++k) {
CopyDataToOutput(inputs, i, j, k, &output_addr);
CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
}
}
}
} else if (axis_ == 2) {
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
CopyDataToOutput(inputs, i, j, 0, &output_addr);
CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
}
}
} else if (axis_ == 1) {
for (size_t i = 0; i < dim0; ++i) {
CopyDataToOutput(inputs, i, 0, 0, &output_addr);
CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
}
} else if (axis_ == 0) {
CopyDataToOutput(inputs, 0, 0, 0, &output_addr);
CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
}
return true;
}
void ConcatCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
size_t dim2, float **output_addr) {
size_t dim2, float **output_addr, size_t *buff_size) {
for (size_t i = 0; i < input_shape_list_.size(); ++i) {
auto input_i_shape = input_shape_list_[i];
auto input_i_addr = reinterpret_cast<float *>(inputs[i]->addr);
......@@ -82,11 +83,12 @@ void ConcatCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &in
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_);
num *= input_i_shape[axis_];
auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0);
auto ret = memcpy_s(*output_addr, num * sizeof(float), input_i_addr + pos, num * sizeof(float));
auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed.";
}
*output_addr += num;
*buff_size -= num * sizeof(float);
}
}
......
......@@ -24,7 +24,7 @@ namespace mindspore {
namespace kernel {
class ConcatCPUKernel : public CPUKernel {
public:
ConcatCPUKernel() = default;
ConcatCPUKernel() : axis_(0) {}
~ConcatCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
......@@ -35,7 +35,7 @@ class ConcatCPUKernel : public CPUKernel {
private:
void CheckParam(const CNodePtr &kernel_node);
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
float **output_addr);
float **output_addr, size_t *buff_size);
int axis_;
std::vector<std::vector<size_t>> input_shape_list_;
std::vector<size_t> output_shape_;
......
......@@ -40,7 +40,7 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
auto buff_size = outputs[0]->size;
size_t dim0 = input_shape_[0];
size_t dim1 = input_shape_[1];
size_t dim2 = input_shape_[2];
......@@ -49,29 +49,29 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
for (size_t k = 0; k < dim2; ++k) {
CopyDataToOutput(inputs, i, j, k, &output_addr);
CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
}
}
}
} else if (axis_ == 2) {
for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) {
CopyDataToOutput(inputs, i, j, 0, &output_addr);
CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
}
}
} else if (axis_ == 1) {
for (size_t i = 0; i < dim0; ++i) {
CopyDataToOutput(inputs, i, 0, 0, &output_addr);
CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
}
} else if (axis_ == 0) {
CopyDataToOutput(inputs, 0, 0, 0, &output_addr);
CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
}
return true;
}
void GatherV2CPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
size_t dim2, float **output_addr) {
size_t dim2, float **output_addr, size_t *buff_size) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
......@@ -88,11 +88,12 @@ void GatherV2CPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &
pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0);
}
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_);
auto ret = memcpy_s(*output_addr, num * sizeof(float), input_addr + pos, num * sizeof(float));
auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed.";
}
*output_addr += num;
*buff_size -= num * sizeof(float);
}
}
......
......@@ -24,7 +24,7 @@ namespace mindspore {
namespace kernel {
class GatherV2CPUKernel : public CPUKernel {
public:
GatherV2CPUKernel() = default;
GatherV2CPUKernel() : axis_(0) {}
~GatherV2CPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
......@@ -34,7 +34,7 @@ class GatherV2CPUKernel : public CPUKernel {
private:
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
float **output_addr);
float **output_addr, size_t *buff_size);
void CheckParam(const CNodePtr &kernel_node);
std::vector<size_t> input_shape_;
std::vector<size_t> indices_shape_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册