提交 62110bec 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1912 add support cpu op sub

Merge pull request !1912 from dengwentao/cpu_op_sub
......@@ -27,7 +27,7 @@ if (ENABLE_CPU)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/subscalar_cpu_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/sub_cpu_kernel.cc")
endif ()
endif ()
......
......@@ -48,19 +48,14 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
MS_LOG(DEBUG) << "output type: " << output_type;
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, "axis");
MS_LOG(DEBUG) << "axis: " << axis;
if (axis_ < 0) {
axis = axis + SizeToInt(input_shape_.size());
}
axis_ = 4 - input_shape_.size() + axis;
axis_ = 4 - input_shape_.size();
MS_LOG(DEBUG) << "axis_: " << axis_;
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
MS_LOG(DEBUG) << "reduce_scatter_flag: " << reduce_scatter_flag_;
if (reduce_scatter_flag_) {
size_t gatherv2_out_lens = 1;
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
if (i == axis) {
if (i == 0) {
for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
MS_LOG(DEBUG) << "gatherv2 out shape: " << indices_shape_[j];
gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
......@@ -76,7 +71,10 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (gather_v2_out_ == nullptr) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
}
memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_);
auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
}
split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
MS_LOG(DEBUG) << "split_num: " << split_num_;
......@@ -99,6 +97,12 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size;
float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
if (!reduce_scatter_flag_) {
auto ret = memset_s(gather_out_addr, outputs[0]->size, 0, outputs[0]->size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset out buff failed";
}
}
MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr;
size_t dim0 = input_shape_[0];
size_t dim1 = input_shape_[1];
......@@ -149,10 +153,10 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
return true;
}
void memcpy_task(std::vector<float *> mem_dest_addr_list, std::vector<float *> mem_src_addr_list, size_t start,
void memcpy_task(std::vector<float *> *mem_dest_addr_list, std::vector<float *> *mem_src_addr_list, size_t start,
size_t end, size_t lens) {
for (size_t i = start; i < end; i++) {
auto ret = memcpy_s(mem_dest_addr_list[i], lens, mem_src_addr_list[i], lens);
auto ret = memcpy_s((*mem_dest_addr_list)[i], lens, (*mem_src_addr_list)[i], lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memery copy failed.";
}
......@@ -204,7 +208,7 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr>
break;
}
auto end = (start + ones_copy_lens) > memcpy_lens ? memcpy_lens : start + ones_copy_lens;
threads[i] = std::thread(memcpy_task, mem_dest_addr_list, mem_src_addr_list, start, end, lens);
threads[i] = std::thread(memcpy_task, &mem_dest_addr_list, &mem_src_addr_list, start, end, lens);
start = start + ones_copy_lens;
}
for (size_t j = 0; j < i; j++) {
......
......@@ -14,14 +14,20 @@
* limitations under the License.
*/
#include <thread>
#include "kernel/cpu/subscalar_cpu_kernel.h"
#include "kernel/cpu/sub_cpu_kernel.h"
#include "device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void SubscalarCPUKernel::InitKernel(const CNodePtr &kernel_node) {
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_y");
MS_LOG(DEBUG) << "offset: " << offset_;
void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
if (shape.size() == 1) {
if (shape[0] != 1) {
MS_LOG(EXCEPTION) << "input 1 only support scalar";
}
} else {
MS_LOG(EXCEPTION) << "input 1 only support scalar";
}
}
void sub_task(int *in_addr, int *out_addr, size_t lens, int offset) {
......@@ -30,9 +36,9 @@ void sub_task(int *in_addr, int *out_addr, size_t lens, int offset) {
}
}
bool SubscalarCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
bool SubCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
......@@ -41,6 +47,8 @@ bool SubscalarCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
#endif
auto input_addr = reinterpret_cast<int *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<int *>(outputs[0]->addr);
offset_ = *reinterpret_cast<int *>(inputs[1]->addr);
MS_LOG(INFO) << "offset: " << offset_;
auto lens = inputs[0]->size / sizeof(int);
if (lens < 10000) {
for (size_t i = 0; i < lens; i++) {
......@@ -73,7 +81,7 @@ bool SubscalarCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
(void)gettimeofday(&end_time, nullptr);
uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(INFO) << "SubscalarCPUKernel, used time: " << time << " us";
MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us";
#endif
return true;
}
......
......@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUBSCALAR_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_SUBSCALAR_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "kernel/cpu/cpu_kernel.h"
......@@ -22,10 +22,10 @@
namespace mindspore {
namespace kernel {
class SubscalarCPUKernel : public CPUKernel {
class SubCPUKernel : public CPUKernel {
public:
SubscalarCPUKernel() : offset_(0) {}
~SubscalarCPUKernel() override = default;
SubCPUKernel() : offset_(0) {}
~SubCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
......@@ -36,9 +36,8 @@ class SubscalarCPUKernel : public CPUKernel {
int offset_;
};
MS_REG_CPU_KERNEL(Subscalar, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SubscalarCPUKernel);
MS_REG_CPU_KERNEL(Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SubCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUBSCALAR_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册