// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/fluid/framework/fleet/box_wrapper.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" namespace paddle { namespace operators { template static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) { auto inputs = ctx.MultiInput("Ids"); auto outputs = ctx.MultiOutput("Out"); auto hidden_size = ctx.Attr("size"); const auto slot_size = inputs.size(); std::vector all_keys(slot_size); // BoxPS only supports float now std::vector all_values(slot_size); std::vector slot_lengths(slot_size); for (size_t i = 0; i < slot_size; i++) { const auto *slot = inputs[i]; const uint64_t *single_slot_keys = reinterpret_cast(slot->data()); all_keys[i] = single_slot_keys; slot_lengths[i] = slot->numel(); auto *output = outputs[i]->mutable_data(ctx.GetPlace()); all_values[i] = output; } auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths, hidden_size); } template static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) { auto inputs = ctx.MultiInput("Ids"); auto d_output = ctx.MultiInput(framework::GradVarName("Out")); auto hidden_size = ctx.Attr("size"); const auto slot_size = inputs.size(); std::vector all_keys(slot_size); std::vector all_grad_values(slot_size); std::vector slot_lengths(slot_size); for (size_t i = 0; i < slot_size; i++) { const auto *slot = inputs[i]; const uint64_t *single_slot_keys = reinterpret_cast(slot->data()); all_keys[i] = single_slot_keys; slot_lengths[i] = slot->numel(); const float *grad_value = d_output[i]->data(); all_grad_values[i] = grad_value; } auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values, slot_lengths, hidden_size); } using LoDTensor = framework::LoDTensor; template class PullBoxSparseCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PullBoxSparseFunctor(ctx); } }; template class PushBoxSparseCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { PushBoxSparseFunctor(ctx); } }; } // namespace operators } // namespace paddle