提交 c3871d98 编写于 作者: Y yujianfeng

Add implementation of SparseApplyProximalAdagrad cpu kernel

上级 067616d0
......@@ -632,7 +632,7 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie
}
last_index = index;
}
unique_grad->indices_size_ = unique_indices_size;
unique_grad->indices_size_ = unique_indices_size + 1;
}
} // namespace kernel
} // namespace mindspore
......@@ -22,6 +22,13 @@ namespace {
constexpr size_t kSparseApplyAdamInputSize = 11;
} // namespace
void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
}
void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......@@ -50,7 +57,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
indices_size_ = indices_shape[0];
if (grad_shape[0] != indices_size_) {
MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
}
if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
......@@ -58,7 +65,7 @@ void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t,
float *v, float beta1, float beta2) {
float *v, float beta1, float beta2) const {
MS_EXCEPTION_IF_NULL(m);
MS_EXCEPTION_IF_NULL(m_t);
MS_EXCEPTION_IF_NULL(v);
......@@ -81,7 +88,7 @@ void SparseApplyAdamCPUKernel::UpdateSparseMomentum(const SparseGradient &unique
}
bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> & /*outputs*/) {
if (inputs.size() < kSparseApplyAdamInputSize) {
MS_LOG(EXCEPTION) << "Error input size!";
......@@ -101,14 +108,12 @@ bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
auto grad = reinterpret_cast<float *>(inputs[9]->addr);
auto indices = reinterpret_cast<int *>(inputs[10]->addr);
auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
std::vector<float> new_grad;
new_grad.reserve(indices_size_ * var_outer_dim_size_);
std::vector<int> new_indices;
new_indices.reserve(indices_size_);
SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_});
DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
var_outer_dim_size_);
SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
var_outer_dim_size_);
size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_;
// Update momentum
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
......
......@@ -30,13 +30,13 @@ class SparseApplyAdamCPUKernel : public CPUKernel {
~SparseApplyAdamCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
void InitInputOutputSize(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void UpdateSparseMomentum(const SparseGradient &unique_sparse_grad, float *m, float *m_t, float *v, float beta1,
float beta2);
float beta2) const;
size_t indices_size_{0};
size_t var_first_dim_size_{0};
size_t var_outer_dim_size_{1};
......
......@@ -58,7 +58,7 @@ void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
indices_size_ = indices_shape[0];
if (grad_shape[0] != indices_size_) {
MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
}
lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr");
if (lr_ <= 0) {
......
......@@ -23,6 +23,13 @@ namespace {
constexpr size_t kSparseApplyLazyAdamInputSize = 11;
} // namespace
void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
}
void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......@@ -51,7 +58,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
indices_size_ = indices_shape[0];
if (grad_shape[0] != indices_size_) {
MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices";
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
}
if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
......@@ -59,7 +66,7 @@ void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> & /*outputs*/) {
if (inputs.size() < kSparseApplyLazyAdamInputSize) {
MS_LOG(EXCEPTION) << "Error input size!";
......@@ -79,14 +86,12 @@ bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr>
auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
auto grad = reinterpret_cast<float *>(inputs[9]->addr);
auto indices = reinterpret_cast<int *>(inputs[10]->addr);
auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
std::vector<float> new_grad;
new_grad.reserve(indices_size_ * var_outer_dim_size_);
std::vector<int> new_indices;
new_indices.reserve(indices_size_);
SparseGradient unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size_});
DeduplicateIndexedSlices(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
var_outer_dim_size_);
SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
var_outer_dim_size_);
lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) {
......
......@@ -29,7 +29,7 @@ class SparseApplyLazyAdamCPUKernel : public CPUKernel {
~SparseApplyLazyAdamCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
void InitInputOutputSize(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h"
#include "kernel/common_utils.h"
#include "device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseApplyProximalAdagradInputSize = 7;
} // namespace
void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
}
void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
std::vector<size_t> accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
std::vector<size_t> lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
std::vector<size_t> l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
std::vector<size_t> l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5);
std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6);
if (!IsSameShape(var_shape, accum_shape)) {
MS_LOG(EXCEPTION) << "var and accum should have the same shape";
}
if (var_shape.empty()) {
MS_LOG(EXCEPTION) << "var must be at least 1D";
}
var_first_dim_size_ = var_shape[0];
for (size_t i = 1; i < var_shape.size(); ++i) {
if (var_shape[i] != grad_shape[i]) {
MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i;
}
var_outer_dim_size_ *= var_shape[i];
}
if (indices_shape.size() != 1) {
MS_LOG(EXCEPTION) << "indices must be a 1D vector";
}
indices_size_ = indices_shape[0];
if (grad_shape[0] != indices_size_) {
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
}
if (!lr_shape.empty()) {
MS_LOG(EXCEPTION) << "lr is not a scalar";
}
if (!l1_shape.empty()) {
MS_LOG(EXCEPTION) << "l1 is not a scalar";
}
if (!l2_shape.empty()) {
MS_LOG(EXCEPTION) << "l2 is not a scalar";
}
}
bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> & /*outputs*/) {
if (inputs.size() < kSparseApplyProximalAdagradInputSize) {
MS_LOG(EXCEPTION) << "Wrong input size!";
}
auto var = reinterpret_cast<float *>(inputs[0]->addr);
auto accum = reinterpret_cast<float *>(inputs[1]->addr);
auto lr = reinterpret_cast<float *>(inputs[2]->addr)[0];
auto l1 = reinterpret_cast<float *>(inputs[3]->addr)[0];
auto l2 = reinterpret_cast<float *>(inputs[4]->addr)[0];
auto grad = reinterpret_cast<float *>(inputs[5]->addr);
auto indices = reinterpret_cast<int *>(inputs[6]->addr);
auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
var_outer_dim_size_);
for (size_t i = 0; i < unique_sparse_grad.indices_size_; ++i) {
int index = unique_sparse_grad.indices_[i];
if (index < 0 || IntToSize(index) >= var_first_dim_size_) {
MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
}
size_t start_index = var_outer_dim_size_ * index;
size_t end_index = start_index + var_outer_dim_size_;
for (size_t j = start_index, k = var_outer_dim_size_ * i; j < end_index; ++j, ++k) {
accum[j] += grad[k] * grad[k];
auto learning_rate = lr * (1 / std::sqrt(accum[j]));
auto prox_v = var[j];
prox_v -= grad[k] * learning_rate;
if (l1 > 0) {
var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast<float>(0.0)) /
(1 + l2 * learning_rate);
} else {
var[j] = prox_v / (1 + l2 * learning_rate);
}
}
}
return true;
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "kernel/cpu/cpu_kernel.h"
#include "kernel/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
public:
SparseApplyProximalAdagradCPUKernel() = default;
~SparseApplyProximalAdagradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
void InitInputOutputSize(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
size_t indices_size_{0};
size_t var_first_dim_size_{0};
size_t var_outer_dim_size_{1};
};
MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
SparseApplyProximalAdagradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
......@@ -21,6 +21,13 @@ from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
beta1_power = 0.9
beta2_power = 0.999
lr = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
class Net(nn.Cell):
def __init__(self):
......@@ -30,7 +37,7 @@ class Net(nn.Cell):
self.m = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="m")
self.v = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="v")
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
def construct(self, grad, indices):
out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon,
grad, indices)
return out
......@@ -42,5 +49,5 @@ def test_net():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
sparse_apply_adam = Net()
output = sparse_apply_adam(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient, indices)
output = sparse_apply_adam(gradient, indices)
print(output[0].asnumpy())
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad()
self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
self.lr = 0.01
self.l1 = 0.0
self.l2 = 0.0
def construct(self, grad, indices):
out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad, indices)
return out
def test_net():
gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
indices = Tensor([0, 1, 2], mstype.int32)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
sparse_apply_proximal_adagrad = Net()
output = sparse_apply_proximal_adagrad(gradient, indices)
print(output.asnumpy()[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册