Commit da9452ee authored by mindspore-ci-bot, committed by Gitee

!3020 Add embedding look up kernels

Merge pull request !3020 from ZPaC/add-ps-embedding-look-up-kernel
@@ -26,7 +26,10 @@ if (ENABLE_CPU)
         "cpu/*.cc"
       )
-  list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc" "cpu/ps/pull_kernel.cc")
+  list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc"
+                                "cpu/ps/pull_kernel.cc"
+                                "cpu/ps/embedding_look_up_ps_kernel.cc"
+                                "cpu/ps/embedding_look_up_proxy_kernel.cc")
   if (NOT ENABLE_MPI)
     list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc")
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/cpu/ps/embedding_look_up_proxy_kernel.h"
#include <vector>
#include "parallel/ps/worker.h"
namespace mindspore {
namespace kernel {
namespace ps {
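// Worker-side proxy for EmbeddingLookup in parameter-server mode: instead of
// gathering rows locally, it registers the embedding table with the parameter
// server and forwards every lookup request to it.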
void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
EmbeddingLookUpCPUKernel::InitKernel(kernel_node);
for (auto dim : input_shape_) {
input_dims_ *= dim;
}
if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
}
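// Flatten the input/indices/output shapes into a single key/value message;
// lens records how many values belong to each of the three shapes so the
// server can split them apart again.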
std::vector<size_t> keys{key_, key_, key_};
std::vector<size_t> values;
values.insert(values.end(), input_shape_.begin(), input_shape_.end());
values.insert(values.end(), indices_shape_.begin(), indices_shape_.end());
values.insert(values.end(), output_shape_.begin(), output_shape_.end());
std::vector<int> lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()),
SizeToInt(output_shape_.size())};
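// Only a process running in the worker role registers and initializes the
// embedding table on the parameter servers.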
const char *env_role = getenv(mindspore::parallel::ps::kEnvRole);
if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) {
parallel::ps::Worker<float>::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]);
parallel::ps::Worker<float>::GetInstance().InitPSEmbeddingTable(keys, values, lens);
}
}
bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
size_t input_size = inputs[1]->size;
size_t output_size = outputs[0]->size;
// The indices buffer holds int32 values; derive the element count from it.
size_t size = input_size / sizeof(int);
::ps::SArray<float> lookup_ids(size, 0);
::ps::SArray<int> lengths{SizeToInt(size)};
::ps::SArray<float> lookup_result;
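// Copy the indices into the transport buffer, run the remote lookup, then
// copy the gathered rows into this kernel's output.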
auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
}
parallel::ps::Worker<float>::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result,
parallel::ps::kEmbeddingLookupCmd);
auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size);
if (ret2 != EOK) {
MS_LOG(EXCEPTION) << "Lookup result memcpy failed.";
}
return true;
}
} // namespace ps
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_
#include "kernel/cpu/embedding_look_up_cpu_kernel.h"
#include <vector>
#include "kernel/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
namespace ps {
class EmbeddingLookUpProxyKernel : public EmbeddingLookUpCPUKernel {
public:
EmbeddingLookUpProxyKernel() = default;
~EmbeddingLookUpProxyKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
size_t key_{0};
size_t input_dims_{1};
};
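// Register the kernel under the EmbeddingLookupProxy op: float32 embedding
// table, int32 indices, float32 output.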
MS_REG_CPU_KERNEL(
EmbeddingLookupProxy,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
EmbeddingLookUpProxyKernel);
} // namespace ps
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/cpu/ps/embedding_look_up_ps_kernel.h"
#include <functional>
#include <vector>
#include <memory>
#include "kernel/common_utils.h"
#include "parallel/ps/util.h"
namespace mindspore {
namespace kernel {
namespace ps {
using mindspore::parallel::ps::Util;
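// Server-side EmbeddingLookup kernel: each parameter server owns one shard of
// the embedding table (rows [offset_, offset_ + shard size)) and serves
// lookups against it.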
void EmbeddingLookUpPSKernel::InitKernel(
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
input_shape_ = *(shape_vec[0]);
input_lens_ = 1;
for (auto shape : input_shape_) {
input_lens_ = input_lens_ * shape;
}
indices_shape_ = *(shape_vec[1]);
indices_lens_ = 1;
for (auto shape : indices_shape_) {
indices_lens_ = indices_lens_ * shape;
}
output_shape_ = *(shape_vec[2]);
axis_ = 2;
reduce_scatter_flag_ = false;
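// Shapes are expanded to 4D below, so axis 2 is the row (vocabulary)
// dimension along which the table is sharded. offset_ is this server's first
// row: the sum of the shard sizes of all lower-ranked servers.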
size_t offset = 0;
for (size_t i = 0; i < rank_id_; i++) {
offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_);
}
offset_ = offset;
split_num_ = pserver_num_;
// The input shape must be sharded only after offset_ has been computed from
// the full (unsharded) shape.
Shard(input_shape_, axis_);
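// Output byte size: the product of all output dims, with sizeof(float) as
// the accumulator's initial value.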
size_t output_size =
std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies<size_t>());
output_size_list_.emplace_back(output_size);
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
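// ReInit adapts the kernel to a new batch of indices: only the indices count
// changes, so just recompute indices_lens_ and the output byte size.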
void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
const std::vector<std::shared_ptr<std::vector<size_t>>> &shape_vec = *shapes;
// Use a local name so that the member indices_shape_ is not shadowed.
const auto &indices_shape = *(shape_vec[0]);
indices_lens_ = indices_shape[0];
size_t output_size = sizeof(float) * indices_lens_;
for (size_t i = axis_ + 1; i < input_shape_.size(); i++) {
output_size *= input_shape_[i];
}
output_size_list_.clear();
output_size_list_.emplace_back(output_size);
}
bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
return Launch(inputs, workspace, outputs);
}
const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; }
const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); }
const std::vector<size_t> &EmbeddingLookUpPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); }
} // namespace ps
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_
#include <vector>
#include <memory>
#include "kernel/cpu/embedding_look_up_cpu_kernel.h"
#include "kernel/cpu/ps/pserver_kernel.h"
namespace mindspore {
namespace kernel {
namespace ps {
class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel {
public:
EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
~EmbeddingLookUpPSKernel() override = default;
void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<size_t> &input_sizes() const override;
const std::vector<size_t> &output_sizes() const override;
const std::vector<size_t> &workspace_sizes() const override;
};
} // namespace ps
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_