Commit c00e07cd authored by Yu Yang

Fix distribute compile

test=develop
Parent 81520a24
...
@@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() {
       }
 #if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
-      if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
+      if (in_selected_rows[0]->value().type() ==
+          framework::proto::VarType::FP32) {
         GatherSelectedRows<platform::CUDADeviceContext, float>(
             in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
             out_var->GetMutable<framework::SelectedRows>());
-      } else if (framework::IsType<const double>(
-                     in_selected_rows[0]->value().type())) {
+      } else if (in_selected_rows[0]->value().type() ==
+                 framework::proto::VarType::FP64) {
         GatherSelectedRows<platform::CUDADeviceContext, double>(
             in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
             out_var->GetMutable<framework::SelectedRows>());
       } else {
-        PADDLE_ENFORCE(false,
-                       "only support double or float when gahter SelectedRows");
+        PADDLE_THROW("only support double or float when gather SelectedRows");
       }
 #endif
     });
...
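The hunk above swaps RTTI-based dtype dispatch (`framework::IsType<T>` over a `std::type_index`) for a direct comparison against the `framework::proto::VarType` enum. A minimal standalone sketch of that enum-dispatch pattern, with hypothetical names standing in for the Paddle types:

#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for framework::proto::VarType::Type.
enum class VarType { FP32, FP64 };

// Hypothetical gather kernel, templated on the element type.
template <typename T>
void Gather() { std::printf("gathering %zu-byte elements\n", sizeof(T)); }

// Dispatch on the enum value instead of RTTI (typeid/std::type_index),
// mirroring the if/else chain in the hunk above.
void GatherByType(VarType type) {
  if (type == VarType::FP32) {
    Gather<float>();
  } else if (type == VarType::FP64) {
    Gather<double>();
  } else {
    throw std::runtime_error("only support double or float");
  }
}

int main() { GatherByType(VarType::FP32); }

Comparing enum values avoids the typeid machinery entirely and keeps the dtype representation identical to what the proto layer uses.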
...
@@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    size_t rows_memory_size =
-        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
+    size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
     slices[2] = ::grpc::Slice(e2.size());
     memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
...
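Since the `SelectedRows` row indices are always `int64_t`, the element size is a compile-time constant; the hunk replaces the runtime `SizeOfType(typeid(int64_t))` lookup with plain `sizeof`. A trivial sketch of the same computation:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Row indices are int64_t, so the buffer size is just
  // size() * sizeof(int64_t) -- no runtime type lookup needed.
  std::vector<int64_t> rows = {0, 3, 7};
  std::size_t rows_memory_size = rows.size() * sizeof(int64_t);
  return rows_memory_size == 3 * sizeof(int64_t) ? 0 : 1;
}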
...
@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
   auto tensor = var->Get<framework::LoDTensor>();
   // FIXME(wuyi): data types in send_recv.proto is copied from
   // framework.proto
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
   for (auto& dim : framework::vectorize(tensor.dims())) {
     request->add_dims(dim);
   }
...
@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
                                      const platform::DeviceContext& ctx,
                                      VarMsg* request) {
   auto* slr = var->GetMutable<framework::SelectedRows>();
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
   request->set_lod_level(0);
   request->set_slr_height(slr->height());
...
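Both hunks in this file rely on `type()` now returning a `framework::proto::VarType::Type` directly, so the `framework::ToDataType` adapter drops out and only the `static_cast` to `VarMsg::Type` remains. As the FIXME notes, that cast is sound only because the data-type values in send_recv.proto were copied from framework.proto. A sketch of how that assumption could be pinned down, using hypothetical stand-in enums (the numeric values here are illustrative, not the real proto values):

// Hypothetical mirrors of the two generated proto enums; the real values
// live in framework.proto and send_recv.proto.
enum class FrameworkType : int { FP32 = 5, FP64 = 6 };
enum class VarMsgType : int { FP32 = 5, FP64 = 6 };

// The bare static_cast in set_data_type() is sound only while the two
// enums stay numerically in sync, as the FIXME above notes.
static_assert(static_cast<int>(FrameworkType::FP32) ==
                  static_cast<int>(VarMsgType::FP32),
              "send_recv.proto data types must mirror framework.proto");

int main() {
  VarMsgType wire = static_cast<VarMsgType>(FrameworkType::FP64);
  return wire == VarMsgType::FP64 ? 0 : 1;
}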
...
@@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
                                      const platform::DeviceContext& ctx,
                                      VarMsg* request);
-inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
+inline framework::proto::VarType::Type ToVarType(
+    sendrecv::VariableMessage::Type type) {
   switch (type) {
     case sendrecv::VariableMessage::FP32:
-      return typeid(float);  // NOLINT
+      return framework::proto::VarType::FP32;  // NOLINT
     case sendrecv::VariableMessage::FP64:
-      return typeid(double);  // NOLINT
+      return framework::proto::VarType::FP64;  // NOLINT
     case sendrecv::VariableMessage::INT32:
-      return typeid(int);  // NOLINT
+      return framework::proto::VarType::INT32;  // NOLINT
     case sendrecv::VariableMessage::INT64:
-      return typeid(int64_t);  // NOLINT
+      return framework::proto::VarType::INT64;  // NOLINT
     case sendrecv::VariableMessage::BOOL:
-      return typeid(bool);  // NOLINT
+      return framework::proto::VarType::BOOL;  // NOLINT
     default:
       PADDLE_THROW("Not support type %d", type);
   }
...
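The helper's return type changes from `std::type_index` to the framework enum, and its name from `ToTypeIndex` to `ToVarType`; the callers below are updated to match. A simplified usage sketch of such a wire-to-framework mapping, again with hypothetical enums:

#include <stdexcept>

// Simplified, hypothetical analogues of the RPC and framework enums.
enum class WireType { FP32, FP64 };
enum class VarType { FP32, FP64 };

// Analogue of ToVarType(): translate the wire enum into the framework
// enum, throwing on anything unsupported.
inline VarType ToVarType(WireType type) {
  switch (type) {
    case WireType::FP32:
      return VarType::FP32;
    case WireType::FP64:
      return VarType::FP64;
    default:
      throw std::runtime_error("Not support type");
  }
}

int main() { return ToVarType(WireType::FP64) == VarType::FP64 ? 0 : 1; }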
...
@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
   tensor->set_lod(lod);
   void* tensor_data =
-      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
+      tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));
   VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
           << ", Buffer Size = " << length;
...
@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
   slr->set_height(meta_.slr_height());
   auto* tensor = slr->mutable_value();
   tensor->Resize(dims);
-  PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
-                    length / framework::SizeOfType(
-                                 paddle::operators::distributed::ToTypeIndex(
-                                     meta_.data_type())));
+  PADDLE_ENFORCE_EQ(
+      static_cast<size_t>(tensor->numel()),
+      length / framework::SizeOfType(paddle::operators::distributed::ToVarType(
+                   meta_.data_type())));
   void* tensor_data = tensor->mutable_data(
       ctx.GetPlace(),
-      paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
+      paddle::operators::distributed::ToVarType(meta_.data_type()));
   if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
     return false;
...
@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
     const platform::DeviceContext& ctx, int length) {
   auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
   slr->mutable_rows()->clear();
-  slr->mutable_rows()->resize(length /
-                              framework::SizeOfType(typeid(int64_t)));  // int64
+  slr->mutable_rows()->resize(length / sizeof(int64_t));  // int64
   int64_t* rows_data = slr->mutable_rows()->data();
   // copy rows CPU data, GPU data will be copied lazily.
...
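The `PADDLE_ENFORCE_EQ` in the middle hunk is a consistency check between the declared tensor shape and the received buffer: the element count must equal `length / SizeOfType(dtype)`, where the dtype now comes from `ToVarType(meta_.data_type())`. A standalone sketch of that invariant, with a hypothetical `SizeOfType`:

#include <cassert>
#include <cstddef>

enum class VarType { FP32, FP64 };

// Hypothetical analogue of framework::SizeOfType().
inline std::size_t SizeOfType(VarType t) {
  return t == VarType::FP32 ? sizeof(float) : sizeof(double);
}

int main() {
  const std::size_t numel = 12;                      // product of tensor dims
  const std::size_t length = numel * sizeof(float);  // received byte count
  // The deserializer rejects a payload whose byte length disagrees with
  // the declared shape and dtype.
  assert(numel == length / SizeOfType(VarType::FP32));
  return 0;
}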
...
@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X").front()->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
   }
 };
...
...
@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X")[0]->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
   }
 };
...
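The last two hunks are the same mechanical cleanup in two operators: `Tensor::type()` already yields the `proto::VarType::Type` that `framework::OpKernelType` expects, so the `framework::ToDataType` wrapper around the first "X" input's type is dropped. A simplified sketch of keying a kernel on the first input's dtype, with hypothetical stand-in types:

#include <vector>

enum class VarType { FP32, FP64 };

// Hypothetical stand-ins: a tensor that reports its dtype directly as the
// enum, and a kernel key built from that dtype.
struct Tensor {
  VarType dtype;
  VarType type() const { return dtype; }
};
struct OpKernelType {
  VarType dtype;
};

// The kernel is keyed by the dtype of the first input; no ToDataType()
// adapter is needed once type() returns the enum itself.
OpKernelType GetExpectedKernelType(const std::vector<const Tensor*>& x) {
  return OpKernelType{x.front()->type()};
}

int main() {
  Tensor t{VarType::FP64};
  std::vector<const Tensor*> x{&t};
  return GetExpectedKernelType(x).dtype == VarType::FP64 ? 0 : 1;
}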