Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c1c5e166
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1c5e166
编写于
3月 29, 2018
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix cpplint errors
上级
64242c5d
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
70 addition
and
45 deletion
+70
-45
paddle/fluid/operators/detail/bytebuffer_stream.cc
paddle/fluid/operators/detail/bytebuffer_stream.cc
+1
-1
paddle/fluid/operators/detail/bytebuffer_stream.h
paddle/fluid/operators/detail/bytebuffer_stream.h
+5
-3
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+10
-6
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+7
-8
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+5
-2
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+6
-3
paddle/fluid/operators/detail/grpc_service.h
paddle/fluid/operators/detail/grpc_service.h
+1
-1
paddle/fluid/operators/detail/proto_encoder_helper.h
paddle/fluid/operators/detail/proto_encoder_helper.h
+6
-4
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+7
-5
paddle/fluid/operators/detail/sendrecvop_utils.h
paddle/fluid/operators/detail/sendrecvop_utils.h
+1
-1
paddle/fluid/operators/detail/serde_test.cc
paddle/fluid/operators/detail/serde_test.cc
+4
-4
paddle/fluid/operators/detail/simple_block_queue.h
paddle/fluid/operators/detail/simple_block_queue.h
+2
-2
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+11
-3
paddle/fluid/operators/detail/variable_response.h
paddle/fluid/operators/detail/variable_response.h
+4
-2
未找到文件。
paddle/fluid/operators/detail/bytebuffer_stream.cc
浏览文件 @
c1c5e166
...
...
@@ -17,7 +17,7 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "bytebuffer_stream.h"
#include "
paddle/fluid/operators/detail/
bytebuffer_stream.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/detail/bytebuffer_stream.h
浏览文件 @
c1c5e166
...
...
@@ -19,9 +19,11 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <vector>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "grpc++/grpc++.h"
namespace
grpc
{
// A ZeroCopyInputStream that reads from grpc_byte_buffer
...
...
@@ -56,7 +58,7 @@ class GrpcBufferReader final
*
data
=
GRPC_SLICE_START_PTR
(
slice_
)
+
GRPC_SLICE_LENGTH
(
slice_
)
-
backup_count_
;
GPR_CODEGEN_ASSERT
(
backup_count_
<=
INT_MAX
);
*
size
=
(
int
)
backup_count_
;
*
size
=
static_cast
<
int
>
(
backup_count_
)
;
backup_count_
=
0
;
return
true
;
}
...
...
@@ -68,7 +70,7 @@ class GrpcBufferReader final
*
data
=
GRPC_SLICE_START_PTR
(
slice_
);
// On win x64, int is only 32bit
GPR_CODEGEN_ASSERT
(
GRPC_SLICE_LENGTH
(
slice_
)
<=
INT_MAX
);
byte_count_
+=
*
size
=
(
int
)
GRPC_SLICE_LENGTH
(
slice_
);
byte_count_
+=
*
size
=
static_cast
<
int
>
(
GRPC_SLICE_LENGTH
(
slice_
)
);
return
true
;
}
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
c1c5e166
...
...
@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include <sys/time.h>
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
namespace
paddle
{
...
...
@@ -52,7 +56,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendVariable"
,
req
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
)
);
});
req_count_
++
;
...
...
@@ -64,7 +68,7 @@ void ProcGetResponse(const VarHandle& var_h,
// const sendrecv::VariableMessage& ret_msg) {
const
::
grpc
::
ByteBuffer
&
ret_msg
)
{
framework
::
Variable
*
outvar
=
NULL
;
DeserializeFromByteBuffer
(
ret_msg
,
*
var_h
.
ctx
,
var_h
.
scope
,
outvar
);
DeserializeFromByteBuffer
(
ret_msg
,
*
var_h
.
ctx
,
var_h
.
scope
,
&
outvar
);
}
template
<
typename
T
>
...
...
@@ -109,7 +113,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/GetVariable"
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
)
);
});
req_count_
++
;
...
...
@@ -126,7 +130,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
BATCH_BARRIER_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
)
);
req_count_
++
;
}
...
...
@@ -138,7 +142,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
)
);
req_count_
++
;
}
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
c1c5e166
...
...
@@ -14,10 +14,9 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <grpc/support/log.h>
#include <time.h>
#include <chrono>
#include <chrono> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
...
...
@@ -25,11 +24,11 @@ limitations under the License. */
#include <string>
#include <vector>
#include
<grpc++/generic/generic_stub.h>
#include
<grpc++/grpc++.h>
#include
<grpc++/support/byte_buffer.h>
#include
<grpc++/support/slice.h>
#include
"grpc++/generic/generic_stub.h"
#include
"grpc++/grpc++.h"
#include
"grpc++/support/byte_buffer.h"
#include
"grpc++/support/slice.h"
#include "grpc/support/log.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
c1c5e166
...
...
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using
::
grpc
::
ServerAsyncResponseWriter
;
namespace
paddle
{
...
...
@@ -205,7 +208,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
// FIXME(typhoonzero): change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
TryToRegisterNewOne
();
...
...
@@ -222,7 +225,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
if
(
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
cq_name
==
"cq_send"
)
WaitCond
(
0
);
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
)
;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
c1c5e166
...
...
@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
...
...
@@ -62,7 +64,8 @@ class AsyncGRPCServer final {
void
ShutDown
();
protected:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewGetOne
();
...
...
paddle/fluid/operators/detail/grpc_service.h
浏览文件 @
c1c5e166
...
...
@@ -114,5 +114,5 @@ class GrpcService final {
};
}
// namespace detail
}
// namespace operator
}
// namespace operator
s
}
// namespace paddle
paddle/fluid/operators/detail/proto_encoder_helper.h
浏览文件 @
c1c5e166
...
...
@@ -19,7 +19,9 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <string>
#include "grpc++/grpc++.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
...
...
@@ -142,6 +144,6 @@ class ProtoEncodeHelper {
char
*
limit_
;
// Just for CHECKs
};
}
// detail
}
// operators
}
// paddle
}
//
namespace
detail
}
//
namespace
operators
}
//
namespace
paddle
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
c1c5e166
...
...
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include <sys/time.h>
#include <thread>
#include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
...
...
@@ -42,7 +44,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void
*
buf
=
malloc
(
1024
);
void
*
payload
=
nullptr
;
size_t
payload_size
;
ProtoEncodeHelper
e
(
(
char
*
)
buf
,
1024
);
ProtoEncodeHelper
e
(
static_cast
<
char
*>
(
buf
)
,
1024
);
e
.
WriteString
(
VarMsg
::
kVarnameFieldNumber
,
name
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
e
.
WriteUint64
(
VarMsg
::
kTypeFieldNumber
,
0
);
...
...
@@ -152,7 +154,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
framework
::
proto
::
VarType_Type_SELECTED_ROWS
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
(
char
*
)
buf
,
128
);
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
)
,
128
);
// NOTE: rows is of type int64_t
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
...
...
@@ -181,10 +183,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
*
&
var
)
{
framework
::
Variable
*
*
var
)
{
operators
::
detail
::
VariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
var
=
resp
.
GetVar
();
*
var
=
resp
.
GetVar
();
}
}
// namespace detail
...
...
paddle/fluid/operators/detail/sendrecvop_utils.h
浏览文件 @
c1c5e166
...
...
@@ -51,7 +51,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
*
&
var
);
framework
::
Variable
*
*
var
);
inline
std
::
type_index
ToTypeIndex
(
sendrecv
::
VariableMessage
::
Type
type
)
{
switch
(
type
)
{
...
...
paddle/fluid/operators/detail/serde_test.cc
浏览文件 @
c1c5e166
...
...
@@ -14,9 +14,9 @@ limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include <thread>
// NOLINT
#include
<google/protobuf/text_format.h>
#include
"google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
...
...
@@ -102,12 +102,12 @@ void RunSerdeTestSelectedRows(platform::Place place) {
}
else
{
tensor_data2
=
const_cast
<
float
*>
(
tensor2
->
data
<
float
>
());
}
const
int64
_t
*
rows_data2
=
rows2
->
data
();
const
size
_t
*
rows_data2
=
rows2
->
data
();
for
(
int
i
=
0
;
i
<
tensor_numel
;
++
i
)
{
EXPECT_FLOAT_EQ
(
tensor_data2
[
i
],
32.7
);
}
for
(
in
t
i
=
0
;
i
<
rows2
->
size
();
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
rows2
->
size
();
++
i
)
{
EXPECT_EQ
(
rows_data2
[
i
],
i
);
}
EXPECT_EQ
(
slr2
->
height
(),
1000
);
...
...
paddle/fluid/operators/detail/simple_block_queue.h
浏览文件 @
c1c5e166
...
...
@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once
#include <condition_variable>
#include <condition_variable>
// NOLINT
#include <deque>
#include <mutex>
#include <mutex>
// NOLINT
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
c1c5e166
...
...
@@ -13,7 +13,13 @@
// limitations under the License.
#include "paddle/fluid/operators/detail/variable_response.h"
#include <string.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...
...
@@ -108,7 +114,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
bool
VariableResponse
::
CopyLodTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
)
{
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
int
length
)
{
auto
var
=
scope_
->
FindVar
(
meta_
.
varname
());
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
(
dims
);
...
...
@@ -144,14 +151,15 @@ inline framework::DDim GetDims(
bool
VariableResponse
::
CopySelectRowsTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
)
{
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
DDim
&
dims
,
int
length
)
{
auto
var
=
scope_
->
FindVar
(
meta_
.
varname
());
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
slr
->
set_height
(
meta_
.
slr_height
());
auto
*
tensor
=
slr
->
mutable_value
();
tensor
->
Resize
(
dims
);
PADDLE_ENFORCE_EQ
(
tensor
->
numel
(
),
static_cast
<
size_t
>
(
tensor
->
numel
()
),
length
/
framework
::
SizeOfType
(
paddle
::
operators
::
detail
::
ToTypeIndex
(
meta_
.
data_type
())));
void
*
tensor_data
=
tensor
->
mutable_data
(
...
...
paddle/fluid/operators/detail/variable_response.h
浏览文件 @
c1c5e166
...
...
@@ -14,6 +14,8 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -60,14 +62,14 @@ class VariableResponse {
private:
bool
CopySelectRowsTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
);
const
framework
::
DDim
&
dims
,
int
length
);
bool
CopySelectRowsData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
int
length
);
bool
CopyLodTensorData
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
ctx
,
framework
::
DDim
&
dims
,
int
length
);
const
framework
::
DDim
&
dims
,
int
length
);
private:
const
framework
::
Scope
*
scope_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录