/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/version.h" #include "paddle/fluid/operators/distributed/communicator_common.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/parameter_recv.h" #include "paddle/fluid/string/string_helper.h" namespace paddle { namespace operators { class RecvSaveOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override {} protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::proto::VarType::Type(ctx.Attr("dtype")), platform::CPUPlace()); } }; class RecvSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddComment(R"DOC( Recv Save operator This operator will serialize and write LoDTensor variable to file on disk. )DOC"); AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::proto::VarType::FP32); AddAttr("overwrite", "(boolean, default true)" "Overwrite the output file if exist") .SetDefault(true); AddAttr("file_path", "(string)" "The \"file_path\" where the variable will be saved.") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); AddAttr>("shape", "(vector) The shape of the output") .SetDefault({}); AddAttr>( "slice_varnames", "(string vector, default {}) " "sometimes we need to put received var in another name " "for example: we need var named 'moment_1@127.0.0.1:1001', " "and it real name on parameter server is 'moment_1'. ") .SetDefault({}); AddAttr>( "remote_varnames", "(string vector, default {}) " "sometimes we need to put received var in another name " "for example: we need var named 'moment_1@127.0.0.1:1001', " "and it real name on parameter server is 'moment_1'. ") .SetDefault({}); AddAttr>("slice_shapes", "(vector) " "the length of each output along the " "specified axis.") .SetDefault({}); AddAttr>("endpoints", "(string vector, default 127.0.0.1:6164)" "Server endpoints in the order of input " "variables for mapping") .SetDefault({}); AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr("is_sparse", "sparse or dense param"); AddAttr("pserver_num", "the number of pserver").SetDefault(0); AddAttr("is_distributed", "sparse id range [0, N) or [0, INT64]") .SetDefault(false); } }; template class RecvSaveOpKernel : public framework::OpKernel { private: void SerializeVersionToStream(std::ostream &os) const { { // the 1st field, uint32_t version for LoDTensor os.write(reinterpret_cast(&framework::kCurTensorVersion), sizeof(framework::kCurTensorVersion)); } // the 2st field, LoD information // in this scene, skip LoD information. uint64_t size = 0; os.write(reinterpret_cast(&size), sizeof(size)); } void SerializeTensorHeaderToStream( std::ostream &os, const framework::proto::VarType::Type &type, const framework::DDim &dims) const { { // the 1st field, uint32_t version constexpr uint32_t version = 0; os.write(reinterpret_cast(&version), sizeof(version)); } { // the 2nd field, tensor description // int32_t size // void* protobuf message framework::proto::VarType::TensorDesc desc; desc.set_data_type(type); auto tensor_dims = framework::vectorize(dims); auto *pb_dims = desc.mutable_dims(); pb_dims->Resize(static_cast(tensor_dims.size()), 0); std::copy(tensor_dims.begin(), tensor_dims.end(), pb_dims->begin()); int32_t size = desc.ByteSize(); os.write(reinterpret_cast(&size), sizeof(size)); auto out = desc.SerializeAsString(); os.write(out.data(), size); } } void SerializeTensorAppendToStream(std::ostream &os, const framework::Tensor &tensor) const { uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type()); auto *data_ptr = tensor.data(); PADDLE_ENFORCE_LT(size, std::numeric_limits::max(), platform::errors::ResourceExhausted( "tensor size %d overflow when writing tensor", size)); os.write(static_cast(data_ptr), static_cast(size)); } public: void Compute(const framework::ExecutionContext &ctx) const override { auto filename = ctx.Attr("file_path"); auto overwrite = ctx.Attr("overwrite"); if (FileExists(filename) && !overwrite) { PADDLE_THROW(platform::errors::AlreadyExists( "%s is existed, cannot save to it when overwrite=false", filename)); } MkDirRecursively(DirName(filename).c_str()); auto origin_shape = ctx.Attr>("shape"); auto slice_shapes = ctx.Attr>("slice_shapes"); auto slice_varnames = ctx.Attr>("slice_varnames"); auto remote_varnames = ctx.Attr>("remote_varnames"); auto endpoints = ctx.Attr>("endpoints"); auto trainer_id = ctx.Attr("trainer_id"); auto is_sparse = ctx.Attr("is_sparse"); auto pserver_num = ctx.Attr("pserver_num"); // auto is_distributed = ctx.Attr("is_distributed"); PADDLE_ENFORCE_EQ(slice_shapes.size(), slice_varnames.size(), platform::errors::InvalidArgument( "Expected attr len(slice_shapes) must be equal to " "len(slice_varnames)")); PADDLE_ENFORCE_EQ( slice_shapes.size(), endpoints.size(), platform::errors::InvalidArgument( "Expected attr len(slice_shapes) must be equal to len(endpoints)")); auto data_type = static_cast(ctx.Attr("dtype")); // it to save an output stream. std::ofstream fout(filename, std::ios::binary); PADDLE_ENFORCE_EQ( static_cast(fout), true, platform::errors::NotFound("Cannot open %s to write", filename)); SerializeVersionToStream(fout); SerializeTensorHeaderToStream(fout, data_type, framework::make_ddim(origin_shape)); framework::Scope &local_scope = ctx.scope().NewScope(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto place = ctx.GetPlace(); auto &device_ctx = *pool.Get(place); distributed::RPCClient *rpc_client = distributed::RPCClient::GetInstance(trainer_id); if (!is_sparse) { for (size_t i = 0; i < slice_varnames.size(); i++) { auto &varname = slice_varnames[i]; auto *var = local_scope.Var(varname); auto *tensor = var->GetMutable(); auto slice_string = string::split_string(slice_shapes[i], ","); std::vector slice_shape; for (auto &dim : slice_string) { slice_shape.push_back(static_cast(std::stoull(dim))); } tensor->Resize(framework::make_ddim(slice_shape)); distributed::VarHandlePtr ret; ret = rpc_client->AsyncGetVarNoBarrier( endpoints[i], device_ctx, local_scope, remote_varnames[i], varname); PADDLE_ENFORCE_NE( ret->Wait(), 0U, platform::errors::ExecutionTimeout( "rpc error when communication with %s", endpoints[i])); auto &c_tensor = var->Get(); SerializeTensorAppendToStream(fout, c_tensor); local_scope.EraseVars({varname}); } } else { PADDLE_ENFORCE_GT( pserver_num, 0, platform::errors::InvalidArgument( "Expected attr len(pserver_num) must gather than 0")); std::vector varnames; auto *var = local_scope.Var("tmp_for_sparse_merge"); auto *o_t = var->GetMutable(); o_t->Resize(framework::make_ddim(origin_shape)); auto *out_d = o_t->mutable_data(place); varnames.push_back("tmp_for_sparse_merge"); for (size_t i = 0; i < slice_varnames.size(); i++) { varnames.push_back(slice_varnames[i]); } std::vector tensors; for (size_t i = 0; i < slice_varnames.size(); i++) { auto &varname = slice_varnames[i]; auto *local_var = local_scope.Var(varname); auto *tensor = local_var->GetMutable(); auto slice_string = string::split_string(slice_shapes[i], ","); std::vector slice_shape; for (auto &dim : slice_string) { slice_shape.push_back(static_cast(std::stoull(dim))); } tensor->Resize(framework::make_ddim(slice_shape)); distributed::VarHandlePtr ret; ret = rpc_client->AsyncGetVarNoBarrier( endpoints[i], device_ctx, local_scope, remote_varnames[i], varname); PADDLE_ENFORCE_NE( ret->Wait(), 0U, platform::errors::ExecutionTimeout( "rpc error when communication with %s", endpoints[i])); const auto *value = local_var->Get().data(); tensors.push_back(value); } auto dims1 = origin_shape[1]; for (int j = 0; j < origin_shape[0]; ++j) { auto id = j % pserver_num; auto idx = j / pserver_num; std::memcpy(out_d + j * dims1, tensors[id] + idx * dims1, sizeof(float) * dims1); } auto &c_tensor = var->Get(); SerializeTensorAppendToStream(fout, c_tensor); local_scope.EraseVars(varnames); } fout.close(); ctx.scope().DeleteScope(&local_scope); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(recv_save, ops::RecvSaveOp, ops::RecvSaveOpProtoMaker); REGISTER_OP_CPU_KERNEL( recv_save, ops::RecvSaveOpKernel, ops::RecvSaveOpKernel, ops::RecvSaveOpKernel, ops::RecvSaveOpKernel);