提交 618b734e 编写于 作者: S sunsuodong

applymomentum: register an ApplyMomentum CPU kernel variant with two float32 outputs and make kernel selection skip (instead of throw on) attrs whose output count mismatches; fix AddN output offset accumulation; refactor Slice/SliceGrad dim expansion into ExpandAllMemberDims

上级 ea37dc76
......@@ -71,9 +71,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (kernel_attr.GetOutputSize() != output_num) {
MS_LOG(EXCEPTION) << "Output num is not equal!";
}
for (size_t output_index = 0; output_index < output_num; ++output_index) {
output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second);
auto dtype = kernel_attr.GetOutputAttr(output_index).first;
......@@ -145,6 +142,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
ExpandKernelAttr(kernel_node, &kernel_attr);
}
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (kernel_attr.GetOutputSize() != output_num) {
MS_LOG(DEBUG) << "Output num is not equal!";
continue;
}
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
......
......@@ -32,17 +32,17 @@ bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
size_t offset = 0;
for (size_t i = 0; i < output_shape_[0]; ++i) {
for (size_t j = 0; j < output_shape_[1]; ++j) {
for (size_t k = 0; k < output_shape_[2]; ++k) {
for (size_t m = 0; m < output_shape_[3]; ++m) {
auto offset = CPUKernelUtils::CalcOffset(output_shape_, i, j, k, m);
float sum = 0;
for (size_t index = 0; index < input_num_; ++index) {
auto input_addr = reinterpret_cast<float *>(inputs[index]->addr);
sum += input_addr[offset];
}
output_addr[offset] = sum;
output_addr[offset++] = sum;
}
}
}
......
......@@ -42,6 +42,16 @@ MS_REG_CPU_KERNEL(ApplyMomentum,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ApplyMomentumCPUKernel);
// Additional ApplyMomentum CPU kernel registration: five float32 inputs and
// TWO float32 outputs, so graphs whose ApplyMomentum node is inferred with two
// outputs can still match a kernel (the registration above declares only one
// output, and kernel selection rejects attrs whose output count differs from
// the node's). NOTE(review): input order presumably follows the ApplyMomentum
// operator definition (e.g. variable/accum/lr/grad/momentum) — not visible
// here, confirm against the op's primitive declaration.
MS_REG_CPU_KERNEL(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ApplyMomentumCPUKernel);
} // namespace kernel
} // namespace mindspore
......
......@@ -23,7 +23,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
for (size_t i = 0; i < begin_.size(); i++) {
......@@ -61,6 +60,15 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.emplace_back(begin_[i] + sizes[i]);
}
}
ExpandAllMemberDims();
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
void SliceCPUKernel::ExpandAllMemberDims() {
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
auto input_len = input_shape_.size();
if (input_len < 4) {
for (size_t i = 0; i < 4 - input_len; ++i) {
......@@ -70,8 +78,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.insert(end_.begin(), 1);
}
}
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
......
......@@ -33,6 +33,7 @@ class SliceCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
void ExpandAllMemberDims();
bool CanCopyMemoryOnAxis(size_t dim) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
......
......@@ -23,7 +23,6 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
for (size_t i = 0; i < begin_.size(); i++) {
......@@ -63,6 +62,14 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
}
ExpandAllMemberDims();
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
void SliceGradCPUKernel::ExpandAllMemberDims() {
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
auto output_len = output_shape_.size();
if (output_len < 4) {
for (size_t i = 0; i < 4 - output_len; ++i) {
......@@ -72,8 +79,6 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.insert(end_.begin(), 1);
}
}
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
......
......@@ -33,6 +33,7 @@ class SliceGradCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
void ExpandAllMemberDims();
bool CanCopyMemoryOnAxis(size_t dim) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册