提交 c1c5e166 编写于 作者: Y Yi Wang

Fix cpplint errors

上级 64242c5d
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC // file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data. // requests without too much copying of the tensor data.
#include "bytebuffer_stream.h" #include "paddle/fluid/operators/detail/bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -19,9 +19,11 @@ limitations under the License. */ ...@@ -19,9 +19,11 @@ limitations under the License. */
#pragma once #pragma once
#include <grpc++/grpc++.h> #include <vector>
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "grpc++/grpc++.h"
namespace grpc { namespace grpc {
// A ZeroCopyInputStream that reads from grpc_byte_buffer // A ZeroCopyInputStream that reads from grpc_byte_buffer
...@@ -56,7 +58,7 @@ class GrpcBufferReader final ...@@ -56,7 +58,7 @@ class GrpcBufferReader final
*data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) - *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) -
backup_count_; backup_count_;
GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX); GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX);
*size = (int)backup_count_; *size = static_cast<int>(backup_count_);
backup_count_ = 0; backup_count_ = 0;
return true; return true;
} }
...@@ -68,7 +70,7 @@ class GrpcBufferReader final ...@@ -68,7 +70,7 @@ class GrpcBufferReader final
*data = GRPC_SLICE_START_PTR(slice_); *data = GRPC_SLICE_START_PTR(slice_);
// On win x64, int is only 32bit // On win x64, int is only 32bit
GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX);
byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_); byte_count_ += * size = static_cast<int>(GRPC_SLICE_LENGTH(slice_));
return true; return true;
} }
......
...@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include <sys/time.h> #include <sys/time.h>
#include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
...@@ -52,7 +56,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -52,7 +56,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -64,7 +68,7 @@ void ProcGetResponse(const VarHandle& var_h, ...@@ -64,7 +68,7 @@ void ProcGetResponse(const VarHandle& var_h,
// const sendrecv::VariableMessage& ret_msg) { // const sendrecv::VariableMessage& ret_msg) {
const ::grpc::ByteBuffer& ret_msg) { const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = NULL; framework::Variable* outvar = NULL;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, outvar); DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
} }
template <typename T> template <typename T>
...@@ -109,7 +113,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -109,7 +113,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -126,7 +130,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -126,7 +130,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
} }
...@@ -138,7 +142,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -138,7 +142,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
} }
......
...@@ -14,10 +14,9 @@ limitations under the License. */ ...@@ -14,10 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <time.h> #include <time.h>
#include <chrono>
#include <chrono> // NOLINT
#include <ctime> #include <ctime>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
...@@ -25,11 +24,11 @@ limitations under the License. */ ...@@ -25,11 +24,11 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include <grpc++/generic/generic_stub.h> #include "grpc++/generic/generic_stub.h"
#include <grpc++/grpc++.h> #include "grpc++/grpc++.h"
#include <grpc++/support/byte_buffer.h> #include "grpc++/support/byte_buffer.h"
#include <grpc++/support/slice.h> #include "grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
......
...@@ -14,6 +14,9 @@ limitations under the License. */ ...@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
namespace paddle { namespace paddle {
...@@ -205,7 +208,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -205,7 +208,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
// FIXME(typhoonzero): change cq_name to enum. // FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
std::string cq_name, const std::string& cq_name,
std::function<void()> TryToRegisterNewOne) { std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne(); TryToRegisterNewOne();
...@@ -222,7 +225,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -222,7 +225,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
if (cq_name == "cq_get") WaitCond(1); if (cq_name == "cq_get") WaitCond(1);
if (cq_name == "cq_send") WaitCond(0); if (cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag; RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
......
...@@ -14,9 +14,11 @@ limitations under the License. */ ...@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once #pragma once
#include <grpc++/grpc++.h> #include <string>
#include <thread> #include <thread> // NOLINT
#include <utility>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
...@@ -62,7 +64,8 @@ class AsyncGRPCServer final { ...@@ -62,7 +64,8 @@ class AsyncGRPCServer final {
void ShutDown(); void ShutDown();
protected: protected:
void HandleRequest(::grpc::ServerCompletionQueue *cq, std::string cq_name, void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void()> TryToRegisterNewOne); std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne(); void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne(); void TryToRegisterNewGetOne();
......
...@@ -114,5 +114,5 @@ class GrpcService final { ...@@ -114,5 +114,5 @@ class GrpcService final {
}; };
} // namespace detail } // namespace detail
} // namespace operator } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -19,7 +19,9 @@ limitations under the License. */ ...@@ -19,7 +19,9 @@ limitations under the License. */
#pragma once #pragma once
#include <grpc++/grpc++.h> #include <string>
#include "grpc++/grpc++.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -142,6 +144,6 @@ class ProtoEncodeHelper { ...@@ -142,6 +144,6 @@ class ProtoEncodeHelper {
char* limit_; // Just for CHECKs char* limit_; // Just for CHECKs
}; };
} // detail } // namespace detail
} // operators } // namespace operators
} // paddle } // namespace paddle
...@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include <sys/time.h> #include <sys/time.h>
#include <thread> #include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -42,7 +44,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -42,7 +44,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* buf = malloc(1024); void* buf = malloc(1024);
void* payload = nullptr; void* payload = nullptr;
size_t payload_size; size_t payload_size;
ProtoEncodeHelper e((char*)buf, 1024); ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteString(VarMsg::kVarnameFieldNumber, name); e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0); e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
...@@ -152,7 +154,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -152,7 +154,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
framework::proto::VarType_Type_SELECTED_ROWS) { framework::proto::VarType_Type_SELECTED_ROWS) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2((char*)buf, 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
// NOTE: rows is of type int64_t // NOTE: rows is of type int64_t
size_t rows_memory_size = size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t)); slr->rows().size() * framework::SizeOfType(typeid(int64_t));
...@@ -181,10 +183,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -181,10 +183,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable*& var) { framework::Variable** var) {
operators::detail::VariableResponse resp(scope, &ctx); operators::detail::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
var = resp.GetVar(); *var = resp.GetVar();
} }
} // namespace detail } // namespace detail
......
...@@ -51,7 +51,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -51,7 +51,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable*& var); framework::Variable** var);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) { switch (type) {
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#include <string> #include <string>
#include <thread> #include <thread> // NOLINT
#include <google/protobuf/text_format.h> #include "google/protobuf/text_format.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -102,12 +102,12 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -102,12 +102,12 @@ void RunSerdeTestSelectedRows(platform::Place place) {
} else { } else {
tensor_data2 = const_cast<float*>(tensor2->data<float>()); tensor_data2 = const_cast<float*>(tensor2->data<float>());
} }
const int64_t* rows_data2 = rows2->data(); const size_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) { for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
} }
for (int i = 0; i < rows2->size(); ++i) { for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i); EXPECT_EQ(rows_data2[i], i);
} }
EXPECT_EQ(slr2->height(), 1000); EXPECT_EQ(slr2->height(), 1000);
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <condition_variable> #include <condition_variable> // NOLINT
#include <deque> #include <deque>
#include <mutex> #include <mutex> // NOLINT
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,13 @@ ...@@ -13,7 +13,13 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h" #include "paddle/fluid/operators/detail/variable_response.h"
#include <string.h> #include <string.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...@@ -108,7 +114,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -108,7 +114,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
bool VariableResponse::CopyLodTensorData( bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input, ::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) { const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto var = scope_->FindVar(meta_.varname()); auto var = scope_->FindVar(meta_.varname());
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims); tensor->Resize(dims);
...@@ -144,14 +151,15 @@ inline framework::DDim GetDims( ...@@ -144,14 +151,15 @@ inline framework::DDim GetDims(
bool VariableResponse::CopySelectRowsTensorData( bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input, ::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) { const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto var = scope_->FindVar(meta_.varname()); auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height()); slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
tensor->Resize(dims); tensor->Resize(dims);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor->numel(), static_cast<size_t>(tensor->numel()),
length / framework::SizeOfType( length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type()))); paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data( void* tensor_data = tensor->mutable_data(
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -60,14 +62,14 @@ class VariableResponse { ...@@ -60,14 +62,14 @@ class VariableResponse {
private: private:
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
framework::DDim& dims, int length); const framework::DDim& dims, int length);
bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input, bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length); const platform::DeviceContext& ctx, int length);
bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input, bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
framework::DDim& dims, int length); const framework::DDim& dims, int length);
private: private:
const framework::Scope* scope_; const framework::Scope* scope_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册