提交 618b734e 编写于 作者: S sunsuodong

applymomentum

上级 ea37dc76
...@@ -71,9 +71,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri ...@@ -71,9 +71,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) { std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (kernel_attr.GetOutputSize() != output_num) {
MS_LOG(EXCEPTION) << "Output num is not equal!";
}
for (size_t output_index = 0; output_index < output_num; ++output_index) { for (size_t output_index = 0; output_index < output_num; ++output_index) {
output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second);
auto dtype = kernel_attr.GetOutputAttr(output_index).first; auto dtype = kernel_attr.GetOutputAttr(output_index).first;
...@@ -145,6 +142,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) { ...@@ -145,6 +142,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
ExpandKernelAttr(kernel_node, &kernel_attr); ExpandKernelAttr(kernel_node, &kernel_attr);
} }
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (kernel_attr.GetOutputSize() != output_num) {
MS_LOG(DEBUG) << "Output num is not equal!";
continue;
}
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
......
...@@ -32,17 +32,17 @@ bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, ...@@ -32,17 +32,17 @@ bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
size_t offset = 0;
for (size_t i = 0; i < output_shape_[0]; ++i) { for (size_t i = 0; i < output_shape_[0]; ++i) {
for (size_t j = 0; j < output_shape_[1]; ++j) { for (size_t j = 0; j < output_shape_[1]; ++j) {
for (size_t k = 0; k < output_shape_[2]; ++k) { for (size_t k = 0; k < output_shape_[2]; ++k) {
for (size_t m = 0; m < output_shape_[3]; ++m) { for (size_t m = 0; m < output_shape_[3]; ++m) {
auto offset = CPUKernelUtils::CalcOffset(output_shape_, i, j, k, m);
float sum = 0; float sum = 0;
for (size_t index = 0; index < input_num_; ++index) { for (size_t index = 0; index < input_num_; ++index) {
auto input_addr = reinterpret_cast<float *>(inputs[index]->addr); auto input_addr = reinterpret_cast<float *>(inputs[index]->addr);
sum += input_addr[offset]; sum += input_addr[offset];
} }
output_addr[offset] = sum; output_addr[offset++] = sum;
} }
} }
} }
......
...@@ -42,6 +42,16 @@ MS_REG_CPU_KERNEL(ApplyMomentum, ...@@ -42,6 +42,16 @@ MS_REG_CPU_KERNEL(ApplyMomentum,
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32), .AddOutputAttr(kNumberTypeFloat32),
ApplyMomentumCPUKernel); ApplyMomentumCPUKernel);
// Registers ApplyMomentumCPUKernel for the ApplyMomentum op with a kernel
// attribute of five float32 inputs and two float32 outputs.
// NOTE(review): presumably this exists so nodes that report two outputs can
// still match a registration; confirm against the kernel-selection logic.
MS_REG_CPU_KERNEL(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ApplyMomentumCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
......
...@@ -23,7 +23,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -23,7 +23,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node); CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
for (size_t i = 0; i < begin_.size(); i++) { for (size_t i = 0; i < begin_.size(); i++) {
...@@ -61,6 +60,15 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -61,6 +60,15 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.emplace_back(begin_[i] + sizes[i]); end_.emplace_back(begin_[i] + sizes[i]);
} }
} }
ExpandAllMemberDims();
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
void SliceCPUKernel::ExpandAllMemberDims() {
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
auto input_len = input_shape_.size(); auto input_len = input_shape_.size();
if (input_len < 4) { if (input_len < 4) {
for (size_t i = 0; i < 4 - input_len; ++i) { for (size_t i = 0; i < 4 - input_len; ++i) {
...@@ -70,8 +78,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -70,8 +78,6 @@ void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.insert(end_.begin(), 1); end_.insert(end_.begin(), 1);
} }
} }
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
} }
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
......
...@@ -33,6 +33,7 @@ class SliceCPUKernel : public CPUKernel { ...@@ -33,6 +33,7 @@ class SliceCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
void ExpandAllMemberDims();
bool CanCopyMemoryOnAxis(size_t dim) const; bool CanCopyMemoryOnAxis(size_t dim) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset, void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const; const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
......
...@@ -23,7 +23,6 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -23,7 +23,6 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node); CheckParam(kernel_node);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN); begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
for (size_t i = 0; i < begin_.size(); i++) { for (size_t i = 0; i < begin_.size(); i++) {
...@@ -63,6 +62,14 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -63,6 +62,14 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
} }
ExpandAllMemberDims();
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
}
void SliceGradCPUKernel::ExpandAllMemberDims() {
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
auto output_len = output_shape_.size(); auto output_len = output_shape_.size();
if (output_len < 4) { if (output_len < 4) {
for (size_t i = 0; i < 4 - output_len; ++i) { for (size_t i = 0; i < 4 - output_len; ++i) {
...@@ -72,8 +79,6 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { ...@@ -72,8 +79,6 @@ void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
end_.insert(end_.begin(), 1); end_.insert(end_.begin(), 1);
} }
} }
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
} }
bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
......
...@@ -33,6 +33,7 @@ class SliceGradCPUKernel : public CPUKernel { ...@@ -33,6 +33,7 @@ class SliceGradCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
void ExpandAllMemberDims();
bool CanCopyMemoryOnAxis(size_t dim) const; bool CanCopyMemoryOnAxis(size_t dim) const;
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset, void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const; const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册