Unverified commit 9ad940fd, authored by tangwei12, committed by GitHub

fix memory leak for CPU (#21174)

* add fake init for the trainer, to fix large memory being held in the trainer
* do not merge recv vars from a remote endpoint, test=develop
* add recv and save op, merging slice vars in one op to save memory
* remove hsigmoid with pull sparse, test=develop
Parent 03133c2c
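The heart of the change below: the old trainer-side save path received every remote slice, concatenated them into the full parameter in trainer memory, and only then saved, so the trainer briefly held roughly two full copies of each distributed parameter; the new recv_save op streams one slice at a time straight into the output file. A minimal sketch of the new program layout (hypothetical names, endpoints, and path; the real wiring is in _save_distributed_persistables further down):

import os
import paddle.fluid as fluid

# Sketch only: a [10, 8] parameter "w" split row-wise across two pservers.
prog = fluid.Program()
prog.global_block().append_op(
    type='recv_save',
    attrs={
        "trainer_id": 0,
        "shape": [10, 8],                # full parameter shape
        "slice_shapes": ["5,8", "5,8"],  # "rows,cols" per slice
        "slice_varnames": ["w.slice.0", "w.slice.1"],
        "remote_varnames": ["w", "w"],
        "endpoints": ["127.0.0.1:6174", "127.0.0.1:6175"],
        "file_path": os.path.join("./model", "w")
    })
# Executing prog fetches each slice in turn and appends it to the file,
# so only one slice is resident on the trainer at any moment.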
@@ -243,15 +243,48 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
  TensorToStream(os, static_cast<Tensor>(tensor), dev_ctx);
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
const platform::DeviceContext &dev_ctx,
const size_t &seek,
const std::vector<int64_t> &shape) {
{
// the 1st field, uint32_t version for LoDTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(framework::IsTensorVersionSupported(version), true,
platform::errors::InvalidArgument(
"tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"tensor version %u is not supported, Only version 0 is supported",
version));
}
{
// the 2nd field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
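    // NOTE: only the LoD level count is read and used to resize the LoD;
    // the writer of this format (see RecvSaveOp::SerializeVersionToStream
    // in this change) emits 0 for it, so no per-level LoD data follows.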
}
// the 3rd field, Tensor
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx, seek, shape);
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
                           const platform::DeviceContext &dev_ctx) {
  {
    // the 1st field, uint32_t version for LoDTensor
    uint32_t version;
    is.read(reinterpret_cast<char *>(&version), sizeof(version));
-   PADDLE_ENFORCE(framework::IsTensorVersionSupported(version),
-                  "tensor version %u is not supported.", version);
-   PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+   PADDLE_ENFORCE_EQ(framework::IsTensorVersionSupported(version), true,
+                     platform::errors::InvalidArgument(
+                         "tensor version %u is not supported.", version));
+   PADDLE_ENFORCE_EQ(
+       version, 0U,
+       platform::errors::InvalidArgument(
+           "tensor version %u is not supported; only version 0 is supported",
+           version));
  }
  {
    // the 2nd field, LoD information
......
@@ -209,6 +209,10 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
                       const platform::DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
                           const platform::DeviceContext& dev_ctx);
+void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
+                           const platform::DeviceContext& dev_ctx,
+                           const size_t& seek,
+                           const std::vector<int64_t>& shape);
/*
 * Convert between length-based LoD and offset-based LoD.
......
@@ -342,8 +342,9 @@ bool LoadTensorFromDisk(
    std::unique_ptr<char[]> buf(new char[size]);
    fin.read(reinterpret_cast<char*>(buf.get()), size);
    CheckInStreamState(fin, sizeof(size));
-   PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
-                  "Cannot parse tensor desc");
+   PADDLE_ENFORCE_EQ(
+       desc.ParseFromArray(buf.get(), size), true,
+       platform::errors::InvalidArgument("Cannot parse tensor desc"));
  }
  {  // read tensor
......
@@ -404,8 +404,9 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
  uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
  auto* data_ptr = tensor.data<void>();
- PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
-                "Index overflow when writing tensor");
+ PADDLE_ENFORCE_LT(size, std::numeric_limits<std::streamsize>::max(),
+                   platform::errors::ResourceExhausted(
+                       "tensor size %d overflow when writing tensor", size));
  if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
    constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB
@@ -426,7 +427,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
      size -= size_to_write;
    }
#else
-   PADDLE_THROW("Unexpected branch");
+   PADDLE_THROW(platform::errors::Unimplemented(
+       "CUDAPlace is not supported when not compiled with CUDA"));
#endif
  } else {
    os.write(static_cast<const char*>(data_ptr),
@@ -450,11 +452,69 @@ struct DeserializedDataFunctor {
  platform::Place place_;
};
void TensorFromStream(std::istream& is, Tensor* tensor,
const platform::DeviceContext& dev_ctx,
const size_t& seek, const std::vector<int64_t>& shape) {
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"tensor version %u is not supported, Only version 0 is supported",
version));
proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char*>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char*>(buf.get()), size);
PADDLE_ENFORCE_EQ(
desc.ParseFromArray(buf.get(), size), true,
platform::errors::InvalidArgument("Cannot parse tensor desc"));
}
{ // read tensor
tensor->Resize(framework::make_ddim(shape));
size_t seekg = seek * framework::SizeOfType(desc.data_type());
is.seekg(seekg, is.cur);
void* buf;
auto ctx = platform::CPUDeviceContext();
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(shape));
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported when not compiled with CUDA"));
#endif
} else {
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), size);
}
}
}
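For reference, seek counts elements rather than bytes: the byte offset used above is seek * SizeOfType(desc.data_type()), applied relative to the current stream position (after the version and TensorDesc headers), and shape decides how many elements are then read. A small numpy sketch of the same arithmetic on a headerless payload (hypothetical file w.bin; the real op additionally skips the headers first):

import numpy as np

# Hypothetical raw payload of a float32 [10, 8] tensor, no headers.
full = np.arange(80, dtype=np.float32).reshape(10, 8)
full.tofile("w.bin")

seek = 2 * 8    # element offset: start row * row width
shape = (3, 8)  # rows 2..4
itemsize = np.dtype(np.float32).itemsize

with open("w.bin", "rb") as f:
    f.seek(seek * itemsize)  # mirrors seekg = seek * SizeOfType(dtype)
    sliced = np.fromfile(
        f, dtype=np.float32, count=int(np.prod(shape))).reshape(shape)

np.testing.assert_equal(sliced, full[2:5])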
void TensorFromStream(std::istream& is, Tensor* tensor,
                      const platform::DeviceContext& dev_ctx) {
  uint32_t version;
  is.read(reinterpret_cast<char*>(&version), sizeof(version));
- PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
+ PADDLE_ENFORCE_EQ(
+     version, 0U,
+     platform::errors::InvalidArgument(
+         "tensor version %u is not supported; only version 0 is supported",
+         version));
  proto::VarType::TensorDesc desc;
  {  // int32_t size
    // proto buffer
@@ -462,8 +522,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
    is.read(reinterpret_cast<char*>(&size), sizeof(size));
    std::unique_ptr<char[]> buf(new char[size]);
    is.read(reinterpret_cast<char*>(buf.get()), size);
-   PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
-                  "Cannot parse tensor desc");
+   PADDLE_ENFORCE_EQ(
+       desc.ParseFromArray(buf.get(), size), true,
+       platform::errors::InvalidArgument("Cannot parse tensor desc"));
  }
  {  // read tensor
    std::vector<int64_t> dims;
@@ -484,7 +545,8 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
      auto dst_place = dev_ctx.GetPlace();
      framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
-     PADDLE_THROW("Unexpected branch");
+     PADDLE_THROW(platform::errors::Unimplemented(
+         "CUDAPlace is not supported when not compiled with CUDA"));
#endif
    } else {
      framework::VisitDataType(
......
@@ -72,6 +72,9 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
                    const platform::DeviceContext& dev_ctx);
void TensorFromStream(std::istream& is, Tensor* tensor,
                      const platform::DeviceContext& dev_ctx);
+void TensorFromStream(std::istream& is, Tensor* tensor,
+                      const platform::DeviceContext& dev_ctx,
+                      const size_t& seek, const std::vector<int64_t>& shape);
// convert dlpack's DLTensor to tensor
void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst);
......
@@ -183,9 +183,18 @@ void prefetchs(const std::vector<std::string>& id_var_names,
  PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
  PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
- auto* reconstruct_var =
-     scope.FindVar(persistable_var_name)->GetMutable<framework::LoDTensor>();
- const auto vec_dim_1 = reconstruct_var->dims()[1];
+ auto vec_dim_1 = 0;
+ framework::Variable* var = scope.FindVar(persistable_var_name);
+
+ PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
+                   platform::errors::InvalidArgument(
+                       "prefetch only supports LoDTensor"));
+
+ vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
+
+ PADDLE_ENFORCE_GT(vec_dim_1, 0,
+                   platform::errors::InvalidArgument(
+                       "lookup table var's dim must be greater than 0"));
  const auto place =
      scope.FindVar(id_var_names[0])->Get<framework::LoDTensor>().place();
@@ -251,16 +260,6 @@ void prefetchs(const std::vector<std::string>& id_var_names,
      }
    }
  }
-
- if (backfill) {
-   VLOG(3) << "backfill persistable var's id with vecs";
-
-   auto* reconstruct_d = reconstruct_var->data<float>();
-   for (auto& id : ids_union) {
-     std::copy(recved_vec_map[id].begin(), recved_vec_map[id].end(),
-               reconstruct_d + id * vec_dim_1);
-   }
- }
}
};  // namespace distributed
......
@@ -72,8 +72,9 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
-   return framework::OpKernelType(data_type, ctx.device_context());
+   return framework::OpKernelType(
+       framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
+       ctx.GetPlace());
  }
};
@@ -139,6 +140,10 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                 "Otherwise the given value indicates padding the output "
                 "with zeros whenever lookup encounters it in Ids.")
        .SetDefault(distributed::kNoPadding);
+   AddAttr<int>("dtype",
+                "(int, default 5 (FP32)) "
+                "Output data type")
+       .SetDefault(framework::proto::VarType::FP32);
    AddComment(R"DOC(
Lookup Table Prefetch Operator.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -69,10 +66,8 @@ class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
             "with the specified value");
    AddComment(R"DOC(
FakeInit Operator.

Init a variable without allocating memory for it; it is used to init the
table parameter at the trainer side in distributed lookup table.
)DOC");
  }
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace operators {
class RecvSaveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class RecvSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
Recv Save operator
This operator will serialize and write a LoDTensor variable to a file on disk.
)DOC");
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("overwrite",
"(boolean, default true)"
"Overwrite the output file if exist")
.SetDefault(true);
AddAttr<std::string>("file_path",
"(string)"
"The \"file_path\" where the variable will be saved.")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output")
.SetDefault({});
    AddAttr<std::vector<std::string>>(
        "slice_varnames",
        "(string vector, default {}) "
        "the local names used to hold the received slices; for example, a "
        "slice of 'moment_1' fetched from 127.0.0.1:1001 may be stored "
        "locally as 'moment_1@127.0.0.1:1001'. ")
        .SetDefault({});

    AddAttr<std::vector<std::string>>(
        "remote_varnames",
        "(string vector, default {}) "
        "the real names of the slice vars on the parameter servers; for "
        "the example above, the remote name is 'moment_1'. ")
        .SetDefault({});

    AddAttr<std::vector<std::string>>("slice_shapes",
                                      "(vector<string>) "
                                      "the shape of each slice, given as "
                                      "comma-joined dims, e.g. '5,8'.")
        .SetDefault({});
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
}
};
template <typename DeviceContext, typename T>
class RecvSaveOpKernel : public framework::OpKernel<T> {
private:
void SerializeVersionToStream(std::ostream &os) const {
{ // the 1st field, uint32_t version for LoDTensor
os.write(reinterpret_cast<const char *>(&framework::kCurTensorVersion),
sizeof(framework::kCurTensorVersion));
}
// the 2nd field, LoD information
// in this scene, skip LoD information.
uint64_t size = 0;
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
}
void SerializeTensorHeaderToStream(
std::ostream &os, const framework::proto::VarType::Type &type,
const framework::DDim &dims) const {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::proto::VarType::TensorDesc desc;
desc.set_data_type(type);
auto tensor_dims = framework::vectorize(dims);
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(tensor_dims.size()), 0);
std::copy(tensor_dims.begin(), tensor_dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
}
void SerializeTensorAppendToStream(std::ostream &os,
const framework::Tensor &tensor) const {
uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
auto *data_ptr = tensor.data<void>();
PADDLE_ENFORCE_LT(size, std::numeric_limits<std::streamsize>::max(),
platform::errors::ResourceExhausted(
"tensor size %d overflow when writing tensor", size));
os.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
    if (FileExists(filename) && !overwrite) {
      PADDLE_THROW(platform::errors::AlreadyExists(
          "%s exists and cannot be saved to when overwrite is false",
          filename));
    }
MkDirRecursively(DirName(filename).c_str());
auto origin_shape = ctx.Attr<std::vector<int64_t>>("shape");
auto slice_shapes = ctx.Attr<std::vector<std::string>>("slice_shapes");
auto slice_varnames = ctx.Attr<std::vector<std::string>>("slice_varnames");
auto remote_varnames =
ctx.Attr<std::vector<std::string>>("remote_varnames");
auto endpoints = ctx.Attr<std::vector<std::string>>("endpoints");
    PADDLE_ENFORCE_EQ(slice_shapes.size(), slice_varnames.size(),
                      platform::errors::InvalidArgument(
                          "Expected len(slice_shapes) to be equal to "
                          "len(slice_varnames)"));

    PADDLE_ENFORCE_EQ(
        slice_shapes.size(), endpoints.size(),
        platform::errors::InvalidArgument(
            "Expected len(slice_shapes) to be equal to len(endpoints)"));
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
    // open the output file stream that the variable will be saved to.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout), true,
platform::errors::NotFound("Cannot open %s to write", filename));
SerializeVersionToStream(fout);
SerializeTensorHeaderToStream(fout, data_type,
framework::make_ddim(origin_shape));
framework::Scope &local_scope = ctx.scope().NewScope();
auto trainer_id = ctx.Attr<int>("trainer_id");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &device_ctx = *pool.Get(place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
for (size_t i = 0; i < slice_varnames.size(); i++) {
auto &varname = slice_varnames[i];
auto *var = local_scope.Var(varname);
auto *tensor = var->GetMutable<framework::LoDTensor>();
auto slice_string =
string::split_string<std::string>(slice_shapes[i], ",");
std::vector<int64_t> slice_shape;
for (auto &dim : slice_string) {
slice_shape.push_back(static_cast<int64_t>(std::stoull(dim)));
}
tensor->Resize(framework::make_ddim(slice_shape));
distributed::VarHandlePtr ret;
ret = rpc_client->AsyncGetVarNoBarrier(
endpoints[i], device_ctx, local_scope, remote_varnames[i], varname);
PADDLE_ENFORCE_NE(
ret->Wait(), 0U,
platform::errors::ExecutionTimeout(
"rpc error when communication with %s", endpoints[i]));
auto &c_tensor = var->Get<framework::LoDTensor>();
SerializeTensorAppendToStream(fout, c_tensor);
local_scope.EraseVars({varname});
}
fout.close();
ctx.scope().DeleteScope(&local_scope);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(recv_save, ops::RecvSaveOp, ops::RecvSaveOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
recv_save, ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::RecvSaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
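The op can be driven directly through paddle.fluid.op.Operator, the same way the unit test below exercises it; slice_shapes are the comma-joined dims of each slice. A minimal sketch, assuming two already-running pservers each hold a [5, 8] slice of a [10, 8] table named 'table' (endpoints are placeholders):

import paddle.fluid as fluid
from paddle.fluid.op import Operator

scope = fluid.core.Scope()
place = fluid.CPUPlace()

# Assumed endpoints of the two pservers.
op = Operator(
    "recv_save",
    trainer_id=0,
    shape=[10, 8],
    slice_shapes=["5,8", "5,8"],
    slice_varnames=["table", "table"],
    remote_varnames=["table", "table"],
    endpoints=["127.0.0.1:6174", "127.0.0.1:6175"],
    file_path="./model/table")

op.run(scope, place)  # fetch slice by slice, append each to the file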
@@ -68,46 +68,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    // for remote prefetch
-   auto remote_prefetch = ctx.Attr<bool>("remote_prefetch");
-   auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
-   if (remote_prefetch && !epmap.empty()) {
-     // if epmap is not empty, then the parameter will be fetched from remote
-     // parameter server
-     auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
-     auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
-     std::vector<int64_t> real_rows = PathToRows(*path);
-     framework::Scope& local_scope = ctx.scope().NewScope();
-     auto* ids = local_scope.Var("Ids@Prefetch");
-     auto* x_tensor = ids->GetMutable<framework::LoDTensor>();
-
-     x_tensor->mutable_data<int64_t>(
-         framework::make_ddim({static_cast<int64_t>(real_rows.size()), 1}),
-         ctx.GetPlace());
-     // copy.
-     std::memcpy(x_tensor->data<int64_t>(), real_rows.data(),
-                 real_rows.size() * sizeof(int64_t));
-
-     framework::DDim w_dims = ctx.Input<Tensor>("W")->dims();
-     w_dims[0] = x_tensor->dims()[0];
-     auto* w_tensor =
-         local_scope.Var("W@Prefetch")->GetMutable<framework::LoDTensor>();
-     w_tensor->Resize(w_dims);
-
-#ifdef PADDLE_WITH_DISTRIBUTE
-     // w_Out is set to used by prefetch, never change it in other cases
-     auto weight = ctx.OutputNames("W_Out").front();
-     operators::distributed::prefetch("Ids@Prefetch", "W@Prefetch", weight,
-                                      true, table_names, epmap,
-                                      height_sections, ctx, local_scope);
-#else
-     PADDLE_THROW(
-         "paddle is not compiled with distribute support, can not do "
-         "parameter prefetch!");
-#endif
-   }
-
    bool is_custom = false;
    if (path) {
      is_custom = true;
......
@@ -48,8 +48,14 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
             R"(Variable will be loaded from "file_path")")
        .AddCustomChecker(
            [](const std::string &path) { return !path.empty(); });
+   AddAttr<int64_t>("seek",
+                    "(int64_t) the element offset to start loading the "
+                    "tensor from")
+       .SetDefault(-1);
+   AddAttr<std::vector<int64_t>>("shape",
+                                 "(vector<int64_t>) The shape of the output")
+       .SetDefault({});
    AddComment(
-       "Load operator will load a LoDTensor / SelectedRows variable from disk "
+       "Load operator will load a LoDTensor / SelectedRows variable from "
+       "disk "
        "file.");
  }
};
......
@@ -16,6 +16,7 @@ limitations under the License. */
#include <fstream>
#include <string>
+#include <vector>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -63,7 +64,18 @@ class LoadOpKernel : public framework::OpKernel<T> {
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(place);
    auto *tensor = var->GetMutable<framework::LoDTensor>();
-   DeserializeFromStream(fin, tensor, dev_ctx);
+
+   auto seek = ctx.Attr<int64_t>("seek");
+
+   if (seek != -1) {
+     PADDLE_ENFORCE_GE(seek, 0,
+                       platform::errors::InvalidArgument(
+                           "seek of the tensor must be greater than or "
+                           "equal to 0"));
+     auto shape = ctx.Attr<std::vector<int64_t>>("shape");
+     DeserializeFromStream(fin, tensor, dev_ctx, seek, shape);
+   } else {
+     DeserializeFromStream(fin, tensor, dev_ctx);
+   }
    auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
    auto in_dtype = tensor->type();
......
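A file written by recv_save can then be read back whole or slice by slice through the extended load op; seek is the element offset of the slice within the flattened tensor. A sketch, assuming "./model/table" holds a float32 [10, 8] tensor saved as above:

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()

# Target variable for rows 5..9 of the saved [10, 8] tensor.
slice1 = block.create_var(
    name="table.slice1",
    type=fluid.core.VarDesc.VarType.LOD_TENSOR,
    shape=[5, 8],
    dtype="float32",
    persistable=True)

block.append_op(
    type='load',
    inputs={},
    outputs={'Out': [slice1]},
    attrs={
        'file_path': "./model/table",
        'seek': 5 * 8,  # element offset: start row * row width
        'shape': slice1.shape
    })

fluid.Executor(fluid.CPUPlace()).run(prog)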
@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <type_traits>
#include <typeindex>
+#include <vector>

namespace paddle {
namespace string {
......
@@ -30,7 +30,8 @@ from paddle.reader import *
from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator
-from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, program_guard
+from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \
+    program_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger
from . import reader
@@ -383,72 +384,42 @@ def _save_distributed_persistables(executor, dirname, main_program):
        # recv optimize vars from pserver
        for name, remote_params in remote_params_map.items():
-           origin_var = None
-           is_slice = False
-           slice_vars = [0] * len(remote_params)
-           slice_var_names = [""] * len(remote_params)
-           endpoints = [""] * len(remote_params)
+           origin = remote_params[0].origin
+           is_slice = remote_params[0].is_slice
+
+           slices = [None] * len(remote_params)
+           slice_varnames = [None] * len(remote_params)
+           remote_varnames = [None] * len(remote_params)
+           endpoints = [None] * len(remote_params)

            for idx, optimizer in enumerate(remote_params):
-               origin = optimizer.origin
-               slice = optimizer.slice
-               is_slice = optimizer.is_slice
                block_id = optimizer.block_id
+               slice = optimizer.slice
                endpoint = optimizer.endpoint
-
-               if idx == 0:
-                   origin_var = block.create_var(
-                       name=origin.name,
-                       type=origin.type,
-                       shape=origin.shape,
-                       dtype=origin.dtype,
-                       persistable=True)
-
-               slice_var = block.create_var(
-                   name="{}.slice.{}".format(slice.name, idx),
-                   type=slice.type,
-                   shape=slice.shape,
-                   dtype=slice.dtype,
-                   persistable=True)

                index = block_id if is_slice else idx
-               slice_vars[index] = slice_var
-               slice_var_names[index] = slice.name
+               slices[index] = slice
+               slice_varnames[index] = "{}.slice.{}".format(slice.name, idx)
+               remote_varnames[index] = slice.name
                endpoints[index] = endpoint

-           if is_slice:
-               block.append_op(
-                   type='recv',
-                   inputs={"X": []},
-                   outputs={"Out": slice_vars},
-                   attrs={
-                       "epmap": endpoints,
-                       "with_barrier": False,
-                       "varnames": slice_var_names,
-                       "sync_mode": True
-                   })
-               block.append_op(
-                   type='concat',
-                   inputs={'X': slice_vars},
-                   outputs={'Out': origin_var},
-                   attrs={})
-           else:
-               block.append_op(
-                   type='recv',
-                   inputs={"X": []},
-                   outputs={"Out": [origin_var]},
-                   attrs={
-                       "epmap": endpoints[:1],
-                       "with_barrier": False,
-                       "varnames": slice_var_names,
-                       "sync_mode": True
-                   })
-           block.append_op(
-               type='save',
-               inputs={'X': [origin_var]},
-               outputs={},
-               attrs={'file_path': os.path.join(dirname, origin_var.name)})
-           block.append_op(type='delete_var', inputs={'X': slice_vars})
+           slice_shapes = []
+           for slice in slices:
+               tmp = [str(dim) for dim in slice.shape]
+               slice_shapes.append(",".join(tmp))
+
+           block.append_op(
+               type='recv_save',
+               attrs={
+                   "trainer_id": 0,
+                   "shape": origin.shape,
+                   "slice_shapes": slice_shapes,
+                   "slice_varnames": slice_varnames,
+                   "remote_varnames": remote_varnames,
+                   "endpoints": endpoints,
+                   "file_path": os.path.join(dirname, origin.name)
+               })

        executor.run(prog)

    def __save_distributed_lookup_tables(executor, dirname,
@@ -478,8 +449,8 @@ def _save_distributed_persistables(executor, dirname, main_program):
        if var.name in exclude_var_names:
            return False
        if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
                var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
                var.desc.type() == core.VarDesc.VarType.READER:
            return False
        return var.persistable
@@ -690,7 +661,7 @@ def load_vars(executor,
        if not isinstance(main_program, Program):
            raise TypeError("program should be of Program type or None")

        # save origin param shape
        orig_para_shape = {}
        load_var_map = {}
        for each_var in vars:
@@ -725,7 +696,7 @@ def load_vars(executor,
                attrs={'file_path': os.path.join(load_dirname, filename)})

        executor.run(load_prog)

        # check var shape
        for each_var in vars:
            if not isinstance(each_var, Parameter):
                continue
@@ -893,21 +864,6 @@ def _load_distributed_persistables(executor, dirname, main_program=None):
            offset = param.offset

            if is_slice:
-               origin = load_block.create_var(
-                   name="{}.load".format(origin_var.name),
-                   type=origin_var.type,
-                   shape=origin_var.shape,
-                   dtype=origin_var.dtype,
-                   persistable=True)
-
-               load_block.append_op(
-                   type='load',
-                   inputs={},
-                   outputs={'Out': [origin]},
-                   attrs={
-                       'file_path': os.path.join(dirname, origin_var.name)
-                   })
-
                slice = load_block.create_var(
                    name=slice_var.name,
                    type=slice_var.type,
@@ -915,22 +871,15 @@ def _load_distributed_persistables(executor, dirname, main_program=None):
                    dtype=slice_var.dtype,
                    persistable=True)

-               dim1_flatten = 1
-               if len(slice.shape) >= 2:
-                   dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
-
-               start = int(offset / dim1_flatten)
-               end = int(offset / dim1_flatten + slice.shape[0])
-
-               load_block.append_op(
-                   type="slice",
-                   inputs={'Input': origin},
-                   outputs={'Out': slice},
-                   attrs={'axes': [0],
-                          'starts': [start],
-                          'ends': [end]})
-
-               need_delete_vars.append(origin)
+               load_block.append_op(
+                   type='load',
+                   inputs={},
+                   outputs={'Out': [slice]},
+                   attrs={
+                       'file_path': os.path.join(dirname, origin_var.name),
+                       'seek': offset,
+                       'shape': slice.shape
+                   })
            else:
                origin = load_block.create_var(
                    name="{}".format(origin_var.name),
@@ -1517,7 +1466,7 @@ def save(program, model_path):
    base_name = os.path.basename(model_path)
    assert base_name != "", \
        "model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"

    def get_tensor(var):
        t = global_scope().find_var(var.name).get_tensor()
@@ -1574,7 +1523,7 @@ def load(program, model_path, executor=None):
    parameter_file_name = model_path + ".pdparams"
    assert os.path.exists(parameter_file_name), \
        "Parameter file [{}] not exists".format(parameter_file_name)

    def set_var(var, ndarray):
        t = global_scope().find_var(var.name).get_tensor()
@@ -1610,7 +1559,7 @@ def load(program, model_path, executor=None):
    if len(optimizer_var_list) > 0:
        opt_file_name = model_path + ".pdopt"
        assert os.path.exists(opt_file_name), \
            "Optimizer file [{}] not exists".format(opt_file_name)

        if executor:
            paddle.fluid.core._create_loaded_parameter(
@@ -1655,7 +1604,7 @@ def load_program_state(model_path):
    """
    parameter_file_name = model_path + ".pdparams"
    assert os.path.exists(parameter_file_name), \
        "Parameter file [{}] not exists".format(parameter_file_name)

    with open(parameter_file_name, 'rb') as f:
        para_dict = pickle.load(f)
@@ -1707,25 +1656,25 @@ def set_program_state(program, state_dict):
    for para in parameter_list:
        var_temp = paddle.fluid.global_scope().find_var(para.name)
        assert var_temp != None, \
            "Variable [ {} ] Not found, Please make sure to run the startup program".format(para.name)
        if para.name in state_dict:
            # set value from state dict
            orig_para_np = np.array(var_temp.get_tensor())
            new_para_np = state_dict[para.name]
            assert orig_para_np.shape == new_para_np.shape, \
                "Shape not matching: the Program requires a parameter with a shape of ({}), " \
                "while the loaded parameter (namely [ {} ]) has a shape of ({})." \
                .format(orig_para_np.shape, para.name, new_para_np.shape)
            assert orig_para_np.dtype == new_para_np.dtype, \
                "Dtype not matching: the Program requires a parameter with a dtype of ({}), " \
                "while the loaded parameter (namely [ {} ]) has a dtype of ({})." \
                .format(orig_para_np.dtype, para.name, new_para_np.dtype)

            ten = var_temp.get_tensor()
            ten_place = ten._place()

            assert ten_place.is_gpu_place() or ten_place.is_cpu_place(), \
                "Place not supported: only CPUPlace and GPUPlace are supported, now is {}".format(str(ten_place))
            py_place = paddle.fluid.CPUPlace()
            if ten_place.is_cuda_pinned_place():
                place = paddle.fluid.CUDAPinnedPlace()
......
@@ -15,7 +15,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_dgc_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_simple_dist_transpiler)
list(APPEND MIXED_DIST_TEST_OPS test_listen_and_serv_op)
list(APPEND MIXED_DIST_TEST_OPS test_nce_remote_table_op)
-list(APPEND MIXED_DIST_TEST_OPS test_hsigmoid_remote_table_op)
+list(APPEND MIXED_DIST_TEST_OPS test_recv_save_op)
+list(APPEND MIXED_DIST_TEST_OPS test_transpiler_ops)
list(APPEND MIXED_DIST_TEST_OPS test_lookup_remote_table_op)
list(APPEND MIXED_DIST_TEST_OPS test_launch)
list(APPEND MIXED_DIST_TEST_OPS test_launch_ps)
@@ -252,8 +253,9 @@ if(WITH_DISTRIBUTE)
        list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_base")
        py_test_modules(test_lookup_remote_table_op MODULES test_lookup_remote_table_op ENVS ${dist_ENVS})
-       py_test_modules(test_hsigmoid_remote_table_op MODULES test_hsigmoid_remote_table_op ENVS ${dist_ENVS})
        py_test_modules(test_nce_remote_table_op MODULES test_nce_remote_table_op ENVS ${dist_ENVS})
+       py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS})
+       py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS})
        if(WITH_DGC)
            py_test_modules(test_dgc_op MODULES test_dgc_op)
            py_test_modules(test_dgc_momentum_op MODULES test_dgc_momentum_op)
......
@@ -23,6 +23,7 @@ import unittest
import numpy as np

import gc
gc.set_debug(gc.DEBUG_COLLECTABLE)

import paddle.fluid as fluid
@@ -274,6 +275,115 @@ class TestLRDecay(TranspilerTest):
        ])
class TestFakeInit(TranspilerTest):
def net_conf(self):
dict_size, embedding_size, neg_num = 10000, 8, 5
input_word = fluid.layers.data(
name="input_word", shape=[1], dtype='int64', lod_level=1)
true_word = fluid.layers.data(
name='true_label', shape=[1], dtype='int64', lod_level=1)
neg_word = fluid.layers.data(
name="neg_label", shape=[1], dtype='int64', lod_level=1)
inputs = [input_word, true_word, neg_word]
init_width = 0.5 / embedding_size
input_emb = fluid.layers.embedding(
input=inputs[0],
is_sparse=True,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='emb',
initializer=fluid.initializer.Uniform(-init_width, init_width)))
true_emb_w = fluid.layers.embedding(
input=inputs[1],
is_sparse=True,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='emb_w',
initializer=fluid.initializer.Constant(value=0.0)))
true_emb_b = fluid.layers.embedding(
input=inputs[1],
is_sparse=True,
size=[dict_size, 1],
param_attr=fluid.ParamAttr(
name='emb_b',
initializer=fluid.initializer.Constant(value=0.0)))
neg_word_reshape = fluid.layers.reshape(inputs[2], shape=[-1, 1])
neg_word_reshape.stop_gradient = True
neg_emb_w = fluid.layers.embedding(
input=neg_word_reshape,
is_sparse=True,
size=[dict_size, embedding_size],
param_attr=fluid.ParamAttr(
name='emb_w', learning_rate=1.0))
neg_emb_w_re = fluid.layers.reshape(
neg_emb_w, shape=[-1, neg_num, embedding_size])
neg_emb_b = fluid.layers.embedding(
input=neg_word_reshape,
is_sparse=True,
size=[dict_size, 1],
param_attr=fluid.ParamAttr(
name='emb_b', learning_rate=1.0))
neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])
true_logits = fluid.layers.elementwise_add(
fluid.layers.reduce_sum(
fluid.layers.elementwise_mul(input_emb, true_emb_w),
dim=1,
keep_dim=True),
true_emb_b)
input_emb_re = fluid.layers.reshape(
input_emb, shape=[-1, 1, embedding_size])
neg_matmul = fluid.layers.matmul(
input_emb_re, neg_emb_w_re, transpose_y=True)
neg_matmul_re = fluid.layers.reshape(neg_matmul, shape=[-1, neg_num])
neg_logits = fluid.layers.elementwise_add(neg_matmul_re, neg_emb_b_vec)
# nce loss
label_ones = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, 1], value=1.0, dtype='float32')
label_zeros = fluid.layers.fill_constant_batch_size_like(
true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')
true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
label_ones)
neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
label_zeros)
cost = fluid.layers.elementwise_add(
fluid.layers.reduce_sum(
true_xent, dim=1),
fluid.layers.reduce_sum(
neg_xent, dim=1))
avg_cost = fluid.layers.reduce_mean(cost)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=1.0,
decay_steps=2100,
decay_rate=0.1,
staircase=True))
sgd_optimizer.minimize(avg_cost)
def transpiler_test_impl(self):
trainer, startup = self.get_trainer()
fake_init_ops = []
for op in startup.global_block().ops:
if op.type == "fake_init":
fake_init_ops.append(op)
self.assertEqual(len(fake_init_ops), 3)
class TestDecayedAdagrad(TranspilerTest):
    def net_conf(self):
        x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
@@ -788,7 +898,7 @@ class TestLoadSliceVar(TranspilerTest):
class TestNCCL2Transpile(TranspilerTest):
    def test_nccl2_transpile(self):
        if fluid.core.is_compiled_with_cuda():  # test nccl2 only with cuda
            main = fluid.Program()
            startup = fluid.Program()
            with fluid.program_guard(main, startup):
......
@@ -17,7 +17,9 @@ from __future__ import print_function
import os
import signal
import time
+import shutil

import unittest
from multiprocessing import Process

import numpy as np
@@ -25,17 +27,18 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard
+from paddle.fluid.transpiler.details import VarStruct, VarsDistributed

from dist_test_utils import *


-def run_pserver(pserver_id, use_cuda, sync_mode):
+def run_pserver(pserver_id):
    remove_ps_flag(os.getpid())
    scope = fluid.core.Scope()
    program = Program()
    with fluid.scope_guard(scope):
        with program_guard(program, startup_program=Program()):
            # create table parameter in scope
-           place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+           place = fluid.CPUPlace()
            # create and initialize Param Variable
            param = scope.var('table').get_tensor()
@@ -65,8 +68,8 @@ class TestListenAndServOp(unittest.TestCase):
    def setUp(self):
        self.ps_timeout = 5

-   def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
-       p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
+   def _start_pserver(self, pserver_id, pserver_func):
+       p = Process(target=pserver_func, args=(pserver_id, ))
        p.daemon = True
        p.start()
        return p
@@ -90,175 +93,163 @@ class TestListenAndServOp(unittest.TestCase):
            port = int(f.read().strip())
        return port

-   def _run_hsigmoid_op_one_pserver(self, place, port):
-       scope = fluid.core.Scope()
-       program = Program()
-       with fluid.scope_guard(scope):
-           with program_guard(program, startup_program=Program()):
-               x = scope.var('X').get_tensor()
-               x_array = np.random.random((4, 8)).astype("float32") * 2
-               x.set(x_array, place)
-               # create and initialize Param Variable
-               param = scope.var('W').get_tensor()
-               param_array = np.zeros((5, 8)).astype("float32") * 2
-               param.set(param_array, place)
-
-               path_table = scope.var('PathTable').get_tensor()
-               path_table_array = np.array(
-                   [(0, 2, -1, -1, -1), (0, 1, 2, -1, -1), (0, 1, 4, -1, -1),
-                    (0, 2, -1, -1, -1)]).astype(
-                       "int64")  # np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-               path_table.set(path_table_array, place)
-
-               path_code = scope.var('PathCode').get_tensor()
-               path_code_array = np.array(
-                   [(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (1, 0, 0, -1, -1),
-                    (0, 1, -1, -1, -1)]).astype("int64")  # np.array to store
-               path_code.set(path_code_array, place)
-
-               label = scope.var('Label').get_tensor()
-               label_array = np.array([0, 1, 4, 5])
-               label.set(label_array, place)
-
-               bias = scope.var('Bias').get_tensor()
-               bias_array = np.random.random((5, 1)).astype("float32")
-               bias.set(bias_array, place)
-
-               out = scope.var('Out').get_tensor()
-
-               pre_out = scope.var('PreOut').get_tensor
-
-               w_out = scope.var('W_Out').get_tensor()
-               w_out.set(param_array, place)
-
-               emaps = ['127.0.0.1:' + str(port)]
-               table_names = ['table']
-               height_sections = [2]
-
-               # create and run sgd operator
-               hsigmoid_op = Operator(
-                   "hierarchical_sigmoid",
-                   X='X',
-                   W='W',
-                   PathTable='PathTable',
-                   PathCode='PathCode',
-                   Label='Label',
-                   Bias='Bias',
-                   Out='Out',
-                   PreOut='PreOut',
-                   W_Out='W_Out',
-                   remote_prefetch=True,
-                   epmap=emaps,
-                   table_names=table_names,
-                   height_sections=height_sections)
-
-               hsigmoid_op.run(scope, place)
-
-               # get and compare result
-               result_array = np.array(w_out)
-               self.assertEqual(list(result_array.shape), [5, 8])
-               correct = None
-               for i in range(5):
-                   if i != 3:
-                       correct = np.full((1, 8), i + 1).astype("float32")
-                       self.assertTrue((result_array[i] == correct).all())
-                   else:
-                       correct = np.full((1, 8), 0).astype("float32")
-                       self.assertTrue((result_array[i] == correct).all())
-
-   def _run_hsigmoid_op_two_pserver(self, place, port0, port1):
-       scope = fluid.core.Scope()
-       program = Program()
-       with fluid.scope_guard(scope):
-           with program_guard(program, startup_program=Program()):
-               x = scope.var('X').get_tensor()
-               x_array = np.random.random((4, 8)).astype("float32") * 2
-               x.set(x_array, place)
-               # create and initialize Param Variable
-               param = scope.var('W').get_tensor()
-               param_array = np.zeros((5, 8)).astype("float32") * 2
-               param.set(param_array, place)
-
-               path_table = scope.var('PathTable').get_tensor()
-               path_table_array = np.array(
-                   [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
-                    (0, 2, -1, -1, -1)]).astype(
-                       "int64")  # np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
-               path_table.set(path_table_array, place)
-
-               path_code = scope.var('PathCode').get_tensor()
-               path_code_array = np.array(
-                   [(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (1, 0, 0, -1, -1),
-                    (0, 1, -1, -1, -1)]).astype("int64")  # np.array to store
-               path_code.set(path_code_array, place)
-
-               label = scope.var('Label').get_tensor()
-               label_array = np.array([0, 1, 4, 5])
-               label.set(label_array, place)
-
-               bias = scope.var('Bias').get_tensor()
-               bias_array = np.random.random((5, 1)).astype("float32")
-               bias.set(bias_array, place)
-
-               out = scope.var('Out').get_tensor()
-
-               pre_out = scope.var('PreOut').get_tensor
-
-               w_out = scope.var('W_Out').get_tensor()
-               w_out.set(param_array, place)
-
-               emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
-               table_names = ['table', 'table']
-               height_sections = [2, 3]
-
-               # create and run sgd operator
-               hsigmoid_op = Operator(
-                   "hierarchical_sigmoid",
-                   X='X',
-                   W='W',
-                   PathTable='PathTable',
-                   PathCode='PathCode',
-                   Label='Label',
-                   Bias='Bias',
-                   Out='Out',
-                   PreOut='PreOut',
-                   W_Out='W_Out',
-                   remote_prefetch=True,
-                   epmap=emaps,
-                   table_names=table_names,
-                   height_sections=height_sections)
-
-               hsigmoid_op.run(scope, place)
-
-               # get and compare result
-               result_array = np.array(w_out)
-               self.assertEqual(list(result_array.shape), [5, 8])
-               correct = None
-               for i in range(5):
-                   if i < 2:
-                       correct = np.full((1, 8), i + 1).astype("float32")
-                       self.assertTrue((result_array[i] == correct).all())
-                   else:
-                       correct = np.full((1, 8), i + 9).astype("float32")
-                       self.assertTrue((result_array[i] == correct).all())
-
-   def test_hsigmoid_op_remote(self):
-       os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
+   def _run_nce_op_two_pserver(self, place, port0, port1, model_file):
+       scope = fluid.core.Scope()
+       program = Program()
+       with fluid.scope_guard(scope):
+           with program_guard(program, startup_program=Program()):
+               emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
+
+               # create and run recv and save operator
+               remote_recv_op = Operator(
+                   "recv_save",
+                   trainer_id=0,
+                   shape=[10, 8],
+                   slice_shapes=["5,8", "5,8"],
+                   slice_varnames=["table", "table"],
+                   remote_varnames=['table', 'table'],
+                   endpoints=emaps,
+                   file_path=model_file)
+
+               remote_recv_op.run(scope, place)
+
+   def _load_slice_var(self, model_file):
+       load_prog = fluid.Program()
+       load_block = load_prog.global_block()
+
+       origin = load_block.create_var(
+           name="var.origin",
+           type=fluid.core.VarDesc.VarType.LOD_TENSOR,
+           shape=[10, 8],
+           dtype="float32",
+           persistable=True)
+
+       slice0 = load_block.create_var(
+           name="var.slice0",
+           type=fluid.core.VarDesc.VarType.LOD_TENSOR,
+           shape=[3, 8],
+           dtype="float32",
+           persistable=True)
+
+       slice1 = load_block.create_var(
+           name="var.slice1",
+           type=fluid.core.VarDesc.VarType.LOD_TENSOR,
+           shape=[5, 8],
+           dtype="float32",
+           persistable=True)
+
+       load_block.append_op(
+           type='load',
+           inputs={},
+           outputs={'Out': [origin]},
+           attrs={'file_path': model_file})
+
+       load_block.append_op(
+           type='load',
+           inputs={},
+           outputs={'Out': [slice0]},
+           attrs={
+               'file_path': model_file,
+               'seek': 2 * 8,
+               'shape': slice0.shape
+           })
+
+       load_block.append_op(
+           type='load',
+           inputs={},
+           outputs={'Out': [slice1]},
+           attrs={
+               'file_path': model_file,
+               'seek': 5 * 8,
+               'shape': slice1.shape
+           })
+
+       exe = fluid.Executor(place=fluid.CPUPlace())
+       exe.run(load_prog)
+
+       origin_var = fluid.global_scope().find_var("var.origin")
+       slice0_var = fluid.global_scope().find_var("var.slice0")
+       slice1_var = fluid.global_scope().find_var("var.slice1")
+
+       origin = np.array(origin_var.get_tensor())
+       slice0 = np.array(slice0_var.get_tensor())
+       slice1 = np.array(slice1_var.get_tensor())
+
+       np.testing.assert_equal(origin[2:5], slice0)
+       np.testing.assert_equal(origin[5:10], slice1)
+
+   def _save_by_io_persistables(self, place, port0, port1, dirname, var_name):
+       exe = fluid.Executor(place=place)
+
+       vars_overview = VarsDistributed()
+
+       orig_var = VarStruct(
+           name=var_name,
+           type=fluid.core.VarDesc.VarType.LOD_TENSOR,
+           shape=[10, 8],
+           dtype="float32",
+           lod_level=0,
+           persistable=True)
+
+       slice_0_var = VarStruct(
+           name=var_name,
+           type=fluid.core.VarDesc.VarType.LOD_TENSOR,
+           shape=[5, 8],
+           dtype="float32",
+           lod_level=0,
+           persistable=True)
+
+       slice_1_var = VarStruct(
+           name=var_name,
+           type=fluid.core.VarDesc.VarType.LOD_TENSOR,
+           shape=[5, 8],
+           dtype="float32",
+           lod_level=0,
+           persistable=True)
+
+       vars_overview.add_distributed_var(
+           origin_var=orig_var,
+           slice_var=slice_0_var,
+           block_id=0,
+           offset=0,
+           is_slice=True,
+           vtype="RemotePrefetch",
+           endpoint="{}:{}".format("127.0.0.1", port0))
+
+       vars_overview.add_distributed_var(
+           origin_var=orig_var,
+           slice_var=slice_1_var,
+           block_id=1,
+           offset=40,
+           is_slice=True,
+           vtype="RemotePrefetch",
+           endpoint="{}:{}".format("127.0.0.1", port1))
+
+       program = Program()
+       program._is_distributed = True
+       program._is_chief = True
+       program._parameters_on_pservers = vars_overview
+
+       fluid.io.save_persistables(exe, dirname, program)
+
+   def test_recv_save_op_remote(self):
        # run pserver on CPU in sync mode
-       p0 = self._start_pserver(0, False, True, run_pserver)
+       p0 = self._start_pserver(0, run_pserver)
        self._wait_ps_ready(p0.pid)
        port0 = self._get_pserver_port(p0.pid)

-       p1 = self._start_pserver(1, False, True, run_pserver)
+       p1 = self._start_pserver(1, run_pserver)
        self._wait_ps_ready(p1.pid)
        port1 = self._get_pserver_port(p1.pid)

        places = [core.CPUPlace()]

+       param_dir = "./model_for_test_recv_save_op/"
+       param_name = "table"
+
        for place in places:
-           self._run_hsigmoid_op_one_pserver(place, port0)
-           self._run_hsigmoid_op_two_pserver(place, port0, port1)
+           self._save_by_io_persistables(place, port0, port1, param_dir,
+                                         param_name)

        # raise SIGTERM to pserver
        os.kill(p0.pid, signal.SIGINT)
@@ -266,6 +257,9 @@ class TestListenAndServOp(unittest.TestCase):
        os.kill(p1.pid, signal.SIGINT)
        p1.join()

+       self._load_slice_var(param_dir + param_name)
+       shutil.rmtree(param_dir)
+
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import traceback
import math
import collections
import six
import unittest
import numpy as np
import gc
gc.set_debug(gc.DEBUG_COLLECTABLE)
import paddle.fluid as fluid
from test_dist_transpiler import TranspilerTest
class TestFakeInit(TranspilerTest):
    def net_conf(self):
        dict_size, embedding_size, neg_num = 10000, 8, 5

        input_word = fluid.layers.data(
            name="input_word", shape=[1], dtype='int64', lod_level=1)
        true_word = fluid.layers.data(
            name='true_label', shape=[1], dtype='int64', lod_level=1)
        neg_word = fluid.layers.data(
            name="neg_label", shape=[1], dtype='int64', lod_level=1)
        inputs = [input_word, true_word, neg_word]

        init_width = 0.5 / embedding_size
        input_emb = fluid.layers.embedding(
            input=inputs[0],
            is_sparse=True,
            size=[dict_size, embedding_size],
            param_attr=fluid.ParamAttr(
                name='emb',
                initializer=fluid.initializer.Uniform(-init_width, init_width)))

        true_emb_w = fluid.layers.embedding(
            input=inputs[1],
            is_sparse=True,
            size=[dict_size, embedding_size],
            param_attr=fluid.ParamAttr(
                name='emb_w',
                initializer=fluid.initializer.Constant(value=0.0)))

        true_emb_b = fluid.layers.embedding(
            input=inputs[1],
            is_sparse=True,
            size=[dict_size, 1],
            param_attr=fluid.ParamAttr(
                name='emb_b',
                initializer=fluid.initializer.Constant(value=0.0)))

        neg_word_reshape = fluid.layers.reshape(inputs[2], shape=[-1, 1])
        neg_word_reshape.stop_gradient = True

        neg_emb_w = fluid.layers.embedding(
            input=neg_word_reshape,
            is_sparse=True,
            size=[dict_size, embedding_size],
            param_attr=fluid.ParamAttr(
                name='emb_w', learning_rate=1.0))

        neg_emb_w_re = fluid.layers.reshape(
            neg_emb_w, shape=[-1, neg_num, embedding_size])

        neg_emb_b = fluid.layers.embedding(
            input=neg_word_reshape,
            is_sparse=True,
            size=[dict_size, 1],
            param_attr=fluid.ParamAttr(
                name='emb_b', learning_rate=1.0))

        neg_emb_b_vec = fluid.layers.reshape(neg_emb_b, shape=[-1, neg_num])

        true_logits = fluid.layers.elementwise_add(
            fluid.layers.reduce_sum(
                fluid.layers.elementwise_mul(input_emb, true_emb_w),
                dim=1,
                keep_dim=True),
            true_emb_b)

        input_emb_re = fluid.layers.reshape(
            input_emb, shape=[-1, 1, embedding_size])

        neg_matmul = fluid.layers.matmul(
            input_emb_re, neg_emb_w_re, transpose_y=True)
        neg_matmul_re = fluid.layers.reshape(neg_matmul, shape=[-1, neg_num])
        neg_logits = fluid.layers.elementwise_add(neg_matmul_re, neg_emb_b_vec)

        # nce loss
        label_ones = fluid.layers.fill_constant_batch_size_like(
            true_logits, shape=[-1, 1], value=1.0, dtype='float32')
        label_zeros = fluid.layers.fill_constant_batch_size_like(
            true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')

        true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
                                                                   label_ones)
        neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
                                                                  label_zeros)
        cost = fluid.layers.elementwise_add(
            fluid.layers.reduce_sum(
                true_xent, dim=1),
            fluid.layers.reduce_sum(
                neg_xent, dim=1))
        avg_cost = fluid.layers.reduce_mean(cost)

        sgd_optimizer = fluid.optimizer.SGD(
            learning_rate=fluid.layers.exponential_decay(
                learning_rate=1.0,
                decay_steps=2100,
                decay_rate=0.1,
                staircase=True))
        sgd_optimizer.minimize(avg_cost)

    def transpiler_test_impl(self):
        trainer, startup = self.get_trainer()

        fake_init_ops = []
        for op in startup.global_block().ops:
            if op.type == "fake_init":
                fake_init_ops.append(op)

        self.assertEqual(len(fake_init_ops), 3)
if __name__ == "__main__":
unittest.main()
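As a quick illustration of what the assertion above checks (an inspection snippet meant to run inside `transpiler_test_impl`, where `startup` is the trainer-side startup program; not part of the patch):

# Each is_sparse embedding table in the net, 'emb', 'emb_w' and 'emb_b',
# should have had its real initializer swapped for a fake_init op, so the
# trainer never fills the full [dict_size, width] tables at startup.
for op in startup.global_block().ops:
    if op.type == "fake_init":
        print(op.output("Out"))  # expected: ['emb'], ['emb_w'], ['emb_b']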
@@ -403,7 +403,7 @@ class DistributeTranspiler(object):
     def _get_all_remote_sparse_update_op(self, main_program):
         sparse_update_ops = []
-        sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
+        sparse_update_op_types = ["lookup_table", "nce"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
                     'remote_prefetch') is True:
@@ -607,6 +607,7 @@ class DistributeTranspiler(object):
             self.origin_program)
         # use_sparse_update_param_name -> split_height_section
         self.sparse_param_to_height_sections = dict()
+        self.need_delete_optimize_vars = []

         # add distributed attrs to program
         self.origin_program._is_distributed = True
@@ -861,6 +862,78 @@ class DistributeTranspiler(object):
             self._get_distributed_optimizer_vars()
         self.origin_program._parameters_on_pservers = self.vars_overview
+    def _get_sparse_table_names(self):
+        sparse_update_op_types = ["lookup_table", "nce"]
+
+        sparse_table_names = []
+        for op in self.origin_program.global_block().ops:
+            if op.type in sparse_update_op_types and op.attr(
+                    'is_sparse') is True:
+                sparse_table_names.append(op.input("W")[0])
+            if op.type == "distributed_lookup_table":
+                sparse_table_names.append(op.input("W")[0])
+
+        if self.has_distributed_lookup_table:
+            sparse_table_names.append(self.table_name)
+
+        return list(set(sparse_table_names))
+
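For intuition, a minimal sketch of the kind of graph this method walks: an `is_sparse` embedding layer lowers to a `lookup_table` op whose `W` input names the table being collected. The program and names below are illustrative only:

import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    words = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    emb = fluid.layers.embedding(
        input=words,
        is_sparse=True,
        size=[10000, 8],
        param_attr=fluid.ParamAttr(name="emb"))

# The scan above then reduces to something like:
tables = [
    op.input("W")[0] for op in prog.global_block().ops
    if op.type == "lookup_table" and op.attr("is_sparse")
]
print(tables)  # ['emb']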
+    def _fake_init_sparsetable(self, sparse_table_names):
+        # delete table init op
+        for table_name in sparse_table_names:
+            table_var = self.startup_program.global_block().vars[table_name]
+            table_param_init_op = []
+            for op in self.startup_program.global_block().ops:
+                if table_name in op.output_arg_names:
+                    table_param_init_op.append(op)
+            init_op_num = len(table_param_init_op)
+            if init_op_num != 1:
+                raise ValueError("table init op num should be 1, now is " + str(
+                    init_op_num))
+            table_init_op = table_param_init_op[0]
+            self.startup_program.global_block().append_op(
+                type="fake_init",
+                inputs={},
+                outputs={"Out": table_var},
+                attrs={"shape": table_init_op.attr('shape')})
+            delete_ops(self.startup_program.global_block(), table_param_init_op)
+
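The point of the swap: fake_init declares its output initialized without writing the table's contents, so the trainer-side startup program keeps a correctly shaped placeholder for each sparse table instead of allocating and filling an embedding matrix whose rows actually live on the parameter servers.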
+    def _delete_trainer_optimizer(self, is_startup):
+        optimize_vars = []
+        optimize_op_role_vars = []
+        optimize_need_delete_vars = []
+
+        for op in self.optimize_ops:
+            optimize_vars.extend(op.input_arg_names)
+            optimize_op_role_vars.extend(op.attr("op_role_var"))
+
+        optimize_vars = list(set(optimize_vars))
+        optimize_op_role_vars = list(set(optimize_op_role_vars))
+
+        for var in optimize_vars:
+            if var not in optimize_op_role_vars:
+                optimize_need_delete_vars.append(var)
+        need_delete_optimize_vars = list(set(optimize_need_delete_vars))
+
+        if is_startup:
+            init_ops = []
+            for var in need_delete_optimize_vars:
+                param_init_op = []
+                for op in self.startup_program.global_block().ops:
+                    if var in op.output_arg_names:
+                        param_init_op.append(op)
+                init_ops.extend(param_init_op)
+            delete_ops(self.startup_program.global_block(), init_ops)
+
+            for var in need_delete_optimize_vars:
+                if self.startup_program.global_block().has_var(var):
+                    self.startup_program.global_block()._remove_var(var)
+        else:
+            delete_ops(self.origin_program.global_block(), self.optimize_ops)
+            for var in need_delete_optimize_vars:
+                if self.origin_program.global_block().has_var(var):
+                    self.origin_program.global_block()._remove_var(var)
+
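The filter above keeps everything named in op_role_var, the parameter/gradient pairs, and schedules the optimizer's remaining inputs (learning rate, moment accumulators and the like) for deletion, since those updates now happen on the pservers. A small sketch of the set logic, with made-up variable names:

optimize_vars = {"learning_rate_0", "emb", "emb@GRAD", "moment_0.emb"}
optimize_op_role_vars = {"emb", "emb@GRAD"}  # param/grad pairs survive

# Optimizer-only state is dead weight on the trainer:
need_delete = optimize_vars - optimize_op_role_vars
print(sorted(need_delete))  # ['learning_rate_0', 'moment_0.emb']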
     def get_trainer_program(self, wait_port=True):
         """
         Get transpiled trainer side program. The program on trainer side compared with origin program
@@ -891,31 +964,16 @@ class DistributeTranspiler(object):
         # remove optimize ops and add a send op to main_program
         # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay?

+        self._delete_trainer_optimizer(is_startup=True)
+        sparse_table_names = self._get_sparse_table_names()
+        self._fake_init_sparsetable(sparse_table_names)
+
         lr_ops = self._get_lr_ops()
-        delete_ops(self.origin_program.global_block(), self.optimize_ops)
         delete_ops(self.origin_program.global_block(), lr_ops)
+        self._delete_trainer_optimizer(is_startup=False)

-        # delete table init op
-        if self.has_distributed_lookup_table:
-            table_var = self.startup_program.global_block().vars[
-                self.table_name]
-            table_param_init_op = []
-            for op in self.startup_program.global_block().ops:
-                if self.table_name in op.output_arg_names:
-                    table_param_init_op.append(op)
-            init_op_num = len(table_param_init_op)
-            if init_op_num != 1:
-                raise ValueError("table init op num should be 1, now is " + str(
-                    init_op_num))
-            table_init_op = table_param_init_op[0]
-            self.startup_program.global_block().append_op(
-                type="fake_init",
-                inputs={},
-                outputs={"Out": table_var},
-                attrs={"shape": table_init_op.attr('shape')})
-            delete_ops(self.startup_program.global_block(), table_param_init_op)

         self.origin_program.__str__()
+        self.startup_program.__str__()

         if wait_port:
             wait_server_ready(self.pserver_endpoints)
@@ -937,8 +995,14 @@ class DistributeTranspiler(object):
         # FIXME(gongwb): delete not need ops.
         # note that: some parameter is not trainable and those ops can't be deleted.

+        sparse_table_names = self._get_sparse_table_names()
+        # self._fake_init_sparsetable(sparse_table_names)
+        #self._delete_trainer_optimizer(is_startup=True)
+
         for varname, splited_var in six.iteritems(self.param_var_mapping):
+            if varname in sparse_table_names:
+                continue
             # Get the eplist of recv vars
             eps = []
             for var in splited_var:
@@ -980,6 +1044,8 @@ class DistributeTranspiler(object):
                 })

         for varname, splited_var in six.iteritems(self.param_var_mapping):
+            if varname in sparse_table_names:
+                continue
             # add concat ops to merge splited parameters received from parameter servers.
             if len(splited_var) <= 1:
                 continue
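With sparse table names skipped in both loops, the trainer program neither adds recv ops for the embedding shards nor concat ops to merge them back into one large parameter; those rows stay sharded on the parameter servers and are fetched on demand through remote prefetch, matching the fake_init placeholders in the startup program.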