Commit c00e07cd authored by Yu Yang

Fix distribute compile

test=develop
Parent 81520a24
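
The hunks below all adapt to one underlying change: tensor type queries now return the framework::proto::VarType::Type enum instead of a std::type_index, so typeid/IsType comparisons become enum comparisons, framework::ToDataType wrappers disappear, and the ToTypeIndex helper becomes ToVarType. A minimal, self-contained sketch of the resulting dispatch style (VarType, Gather, and Dispatch are illustrative stand-ins, not the real Paddle definitions):

#include <cstdio>

// Stand-in for framework::proto::VarType::Type.
enum class VarType { FP32, FP64 };

// Stand-in for the templated GatherSelectedRows<DeviceContext, T>.
template <typename T>
void Gather() { std::printf("gathering %zu-byte elements\n", sizeof(T)); }

void Dispatch(VarType type) {
  // Compare against the proto enum instead of typeid()/std::type_index,
  // mirroring the new branches in ReduceOpHandle::RunImpl below.
  if (type == VarType::FP32) {
    Gather<float>();
  } else if (type == VarType::FP64) {
    Gather<double>();
  } else {
    std::printf("only support double or float\n");
  }
}

int main() { Dispatch(VarType::FP32); }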
@@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() {
       }
 #if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
-      if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
+      if (in_selected_rows[0]->value().type() ==
+          framework::proto::VarType::FP32) {
         GatherSelectedRows<platform::CUDADeviceContext, float>(
             in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
             out_var->GetMutable<framework::SelectedRows>());
-      } else if (framework::IsType<const double>(
-                     in_selected_rows[0]->value().type())) {
+      } else if (in_selected_rows[0]->value().type() ==
+                 framework::proto::VarType::FP64) {
         GatherSelectedRows<platform::CUDADeviceContext, double>(
             in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
             out_var->GetMutable<framework::SelectedRows>());
       } else {
-        PADDLE_ENFORCE(false,
-                       "only support double or float when gahter SelectedRows");
+        PADDLE_THROW("only support double or float when gather SelectedRows");
       }
 #endif
     });
......
@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    size_t rows_memory_size =
-        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
+    size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
     slices[2] = ::grpc::Slice(e2.size());
     memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
   auto tensor = var->Get<framework::LoDTensor>();
   // FIXME(wuyi): data types in send_recv.proto is copied from
   // framework.proto
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
   for (auto& dim : framework::vectorize(tensor.dims())) {
     request->add_dims(dim);
   }
@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
                                      const platform::DeviceContext& ctx,
                                      VarMsg* request) {
   auto* slr = var->GetMutable<framework::SelectedRows>();
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
   request->set_lod_level(0);
   request->set_slr_height(slr->height());
......
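
Why the bare static_cast is enough in the two payload hunks above: as the FIXME notes, the type enum in send_recv.proto is copied from framework.proto, so once tensor.type() and slr->value().type() return the proto enum directly, its integer value can be cast straight to VarMsg::Type with no ToDataType translation step. A toy illustration with stand-in enums (enumerator values are illustrative, not Paddle's):

#include <cassert>

enum FrameworkVarType { FW_FP32 = 5, FW_FP64 = 6 };  // illustrative values
enum WireVarType { WIRE_FP32 = 5, WIRE_FP64 = 6 };   // copied layout

int main() {
  FrameworkVarType t = FW_FP64;
  // The cast is value-preserving because the enumerators line up.
  WireVarType w = static_cast<WireVarType>(t);
  assert(w == WIRE_FP64);
  return 0;
}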
@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
                                      const platform::DeviceContext& ctx,
                                      VarMsg* request);
-inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
+inline framework::proto::VarType::Type ToVarType(
+    sendrecv::VariableMessage::Type type) {
   switch (type) {
     case sendrecv::VariableMessage::FP32:
-      return typeid(float);  // NOLINT
+      return framework::proto::VarType::FP32;  // NOLINT
     case sendrecv::VariableMessage::FP64:
-      return typeid(double);  // NOLINT
+      return framework::proto::VarType::FP64;  // NOLINT
     case sendrecv::VariableMessage::INT32:
-      return typeid(int);  // NOLINT
+      return framework::proto::VarType::INT32;  // NOLINT
     case sendrecv::VariableMessage::INT64:
-      return typeid(int64_t);  // NOLINT
+      return framework::proto::VarType::INT64;  // NOLINT
     case sendrecv::VariableMessage::BOOL:
-      return typeid(bool);  // NOLINT
+      return framework::proto::VarType::BOOL;  // NOLINT
     default:
       PADDLE_THROW("Not support type %d", type);
   }
......
@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
   tensor->set_lod(lod);
   void* tensor_data =
-      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
+      tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
   VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
           << ", Buffer Size = " << length;
@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
   slr->set_height(meta_.slr_height());
   auto* tensor = slr->mutable_value();
   tensor->Resize(dims);
-  PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
-                    length / framework::SizeOfType(
-                                 paddle::operators::distributed::ToTypeIndex(
-                                     meta_.data_type())));
+  PADDLE_ENFORCE_EQ(
+      static_cast<size_t>(tensor->numel()),
+      length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
+                   meta_.data_type())));
   void* tensor_data = tensor->mutable_data(
       ctx.GetPlace(),
-      paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
+      paddle::operators::distributed::ToVarType(meta_.data_type()));
   if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
     return false;
@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
     const platform::DeviceContext& ctx, int length) {
   auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
   slr->mutable_rows()->clear();
-  slr->mutable_rows()->resize(length /
-                              framework::SizeOfType(typeid(int64_t)));  // int64
+  slr->mutable_rows()->resize(length / sizeof(int64_t));  // int64
   int64_t* rows_data = slr->mutable_rows()->data();
   // copy rows CPU data, GPU data will be copied lazily.
......
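
The sizeof(int64_t) changes in the serialization and deserialization hunks above rely on SelectedRows row indices always being int64_t: the element size is a compile-time constant, so no typeid-based framework::SizeOfType lookup is needed. A hedged sketch of the same arithmetic (DeserializeRows is a hypothetical helper for illustration, not Paddle API):

#include <cstdint>
#include <cstring>
#include <vector>

// Mirrors CopySelectRowsData's math: `length` is a raw byte count and each
// row index is a fixed-size int64_t, so the row count is length / sizeof.
std::vector<int64_t> DeserializeRows(const char* bytes, size_t length) {
  std::vector<int64_t> rows(length / sizeof(int64_t));
  std::memcpy(rows.data(), bytes, rows.size() * sizeof(int64_t));
  return rows;
}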
@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X").front()->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
   }
 };
......
@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X")[0]->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
   }
 };
......
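
The two operator hunks above are the same mechanical cleanup: since Tensor::type() now yields the proto VarType enum directly, the first argument of OpKernelType no longer needs the framework::ToDataType wrapper. A minimal stand-in sketch of the pattern (Tensor, VarType, and OpKernelType here are mock-ups, not the real Paddle classes):

#include <cstdio>

enum class VarType { FP32, FP64 };  // stand-in for proto::VarType::Type

struct Tensor {
  VarType type() const { return VarType::FP32; }  // now returns the enum
};

struct OpKernelType {
  OpKernelType(VarType t, int p) : type(t), place(p) {}
  VarType type;
  int place;
};

// Before: OpKernelType(framework::ToDataType(x.type()), place);
// After: the tensor's type feeds OpKernelType directly.
OpKernelType GetExpectedKernelType(const Tensor& x, int place) {
  return OpKernelType(x.type(), place);
}

int main() {
  Tensor x;
  OpKernelType k = GetExpectedKernelType(x, /*place=*/0);
  std::printf("kernel data type: %d\n", static_cast<int>(k.type));
}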