Unverified commit b849157e authored by gongweibao, committed by GitHub

Add size enforce (#14919)

Parent aa6e9c30
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <nccl.h>
 #endif
 #include <sys/time.h>
+#include <limits>
 #include <thread>  // NOLINT
 #include "paddle/fluid/framework/data_type.h"
@@ -31,7 +32,12 @@ namespace distributed {
 class IOBufWriter {
  public:
-  static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
+  static void Append(const std::string& varname, butil::IOBuf* iobuf, int k,
+                     const char* v, int64_t vlen) {
+    if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
+      LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
+    }
     iobuf->append(reinterpret_cast<char*>(&k), 4);
     iobuf->append(reinterpret_cast<char*>(&vlen), 8);
     iobuf->append(v, vlen);
@@ -87,6 +93,10 @@ class IOBufWriter {
                              int k, const char* v, int64_t vlen,
                              bool in_cuda_pinned, void (*destroy)(void*),
                              void* user_data) {
+    if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
+      LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
+    }
 #ifdef PADDLE_WITH_BRPC_RDMA
     IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
                                     destroy, user_data);
@@ -134,7 +144,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
     request->set_type(::sendrecv::NCCL_ID);
     const ncclUniqueId& uid = var->Get<ncclUniqueId>();
     // TODO(gongwb): use append_zero to avoid data copy.
-    IOBufWriter::Append(iobuf,
+    IOBufWriter::Append(name, iobuf,
                         sendrecv::VariableMessage::kSerializedFieldNumber,
                         uid.internal, NCCL_UNIQUE_ID_BYTES);
     return;
@@ -149,7 +159,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
   // FIXME(gongwb): it seems that can use zero copy.
   if (var_is_not_stable) {
     IOBufWriter::Append(
-        iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
+        name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
         static_cast<const char*>(payload->ptr()), payload->memory_size());
   } else {
     if (platform::is_gpu_place(ctx.GetPlace())) {
@@ -171,10 +181,11 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
-    size_t rows_memory_size =
-        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
-    IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
+    PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
+    size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
+    IOBufWriter::Append(name, iobuf,
+                        ::sendrecv::VariableMessage::kRowsFieldNumber,
                         reinterpret_cast<const char*>(slr->rows().data()),
                         static_cast<int64_t>(rows_memory_size));
   }
...
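Note (reviewer sketch, not part of the commit): the check added to IOBufWriter::Append and the zero-copy path rejects payload lengths that are negative or not strictly below std::numeric_limits<int>::max(), presumably because the length must remain representable as a 32-bit int downstream. A minimal standalone illustration of the same bound, using a hypothetical helper name LengthIsSerializable:

#include <cstdint>
#include <iostream>
#include <limits>

// Mirrors the guard condition in the diff above: a length is acceptable only
// if it is non-negative and strictly below INT_MAX.
bool LengthIsSerializable(int64_t vlen) {
  return vlen >= 0 && vlen < std::numeric_limits<int>::max();
}

int main() {
  std::cout << std::boolalpha;
  std::cout << LengthIsSerializable(1024) << "\n";              // true
  std::cout << LengthIsSerializable(-1) << "\n";                // false
  std::cout << LengthIsSerializable(int64_t{1} << 33) << "\n";  // false
}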
@@ -15,6 +15,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include <nccl.h>
 #endif
+#include <limits>
 #include <thread>  // NOLINT
 #include "google/protobuf/io/coded_stream.h"
@@ -102,6 +103,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
                             payload->memory_size());
+  if (payload->memory_size() >= std::numeric_limits<int>::max()) {
+    LOG(FATAL) << "AppendZeroCopy varname:" << name
+               << ", vlen:" << payload->memory_size();
+  }
   // steal reference of tensor data
   ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
   int num_slices = 2;       // only SelectedRows have rows buffer
@@ -115,7 +120,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
+    PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
     size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
     slices[2] = ::grpc::Slice(e2.size());
     memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
...
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <iostream>
 #include <string>
+#include <typeindex>
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
@@ -23,9 +24,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/platform/port.h"
 #include "paddle/fluid/operators/distributed/send_recv.pb.h"
+#include "paddle/fluid/platform/port.h"

 namespace paddle {
 namespace operators {
@@ -83,6 +83,11 @@ inline framework::proto::VarType::Type ToVarType(
   }
 }

+template <template <typename> class T, typename Elem>
+std::string VectorElemName(const T<Elem>& arg) {
+  return typeid(Elem).name();
+}
+
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle
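Note (reviewer sketch, not part of the commit): VectorElemName returns the typeid name of a container's element type, which the PADDLE_ENFORCE calls in the two serializers use to assert that SelectedRows rows really hold int64_t before computing rows().size() * sizeof(int64_t). A self-contained illustration, with a hypothetical single-parameter container RowsVector standing in for framework::Vector:

#include <cstdint>
#include <iostream>
#include <string>
#include <typeinfo>

// Stand-in for framework::Vector<T>; only its element type matters here.
template <typename T>
struct RowsVector {};

// Same helper as the one added in this header (namespaces omitted).
template <template <typename> class T, typename Elem>
std::string VectorElemName(const T<Elem>& arg) {
  return typeid(Elem).name();
}

int main() {
  RowsVector<int64_t> rows;
  // Mirrors the enforce at the call sites: element type must be int64_t.
  std::cout << std::boolalpha
            << (VectorElemName(rows) == typeid(int64_t).name()) << "\n";  // true
}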
@@ -118,7 +118,7 @@ bool VariableResponse::CopyLodTensorData(
   VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
           << ", Buffer Size = " << length;
-  PADDLE_ENFORCE_EQ(tensor->memory_size(), length);
+  PADDLE_ENFORCE_EQ(tensor->memory_size(), static_cast<unsigned int>(length));
   return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
 }
...
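Note (reviewer sketch, not part of the commit): the last hunk compares tensor->memory_size(), which is unsigned, with the signed length, presumably to silence the signed/unsigned comparison inside PADDLE_ENFORCE_EQ. Narrowing to unsigned int assumes the length fits in 32 bits, which the size guards added on the sending side are meant to ensure. A small standalone illustration with placeholder values:

#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
  // Stand-ins for tensor->memory_size() (size_t) and the received buffer
  // length (int64_t); the values are placeholders, not taken from the commit.
  std::size_t memory_size = 4096;
  int64_t length = 4096;

  // Comparing size_t directly with int64_t mixes signedness; the commit
  // narrows the signed side before the equality check instead.
  assert(memory_size == static_cast<unsigned int>(length));
  return 0;
}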