Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
900b4cdd
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
900b4cdd
编写于
6月 28, 2019
作者:
W
Wei Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Expose CXXTrainer related API to python
上级
c998292e
变更
39
隐藏空白更改
内联
并排
Showing
39 changed file
with
870 addition
and
60 deletion
+870
-60
cmake/version.cmake
cmake/version.cmake
+2
-1
paddle/fluid/lite/api/cxx_api.cc
paddle/fluid/lite/api/cxx_api.cc
+11
-0
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+23
-0
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+6
-1
paddle/fluid/lite/core/mir/ssa_graph_test.cc
paddle/fluid/lite/core/mir/ssa_graph_test.cc
+3
-1
paddle/fluid/lite/core/program.cc
paddle/fluid/lite/core/program.cc
+2
-0
paddle/fluid/lite/core/scope.h
paddle/fluid/lite/core/scope.h
+6
-0
paddle/fluid/lite/core/variable.h
paddle/fluid/lite/core/variable.h
+6
-1
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+1
-1
paddle/fluid/lite/kernels/x86/CMakeLists.txt
paddle/fluid/lite/kernels/x86/CMakeLists.txt
+5
-3
paddle/fluid/lite/kernels/x86/elementwise_compute.cc
paddle/fluid/lite/kernels/x86/elementwise_compute.cc
+13
-10
paddle/fluid/lite/kernels/x86/elementwise_compute.h
paddle/fluid/lite/kernels/x86/elementwise_compute.h
+9
-3
paddle/fluid/lite/kernels/x86/sgd_compute.cc
paddle/fluid/lite/kernels/x86/sgd_compute.cc
+1
-0
paddle/fluid/lite/kernels/x86/uniform_random_compute.cc
paddle/fluid/lite/kernels/x86/uniform_random_compute.cc
+67
-0
paddle/fluid/lite/model_parser/compatible_pb.cc
paddle/fluid/lite/model_parser/compatible_pb.cc
+4
-0
paddle/fluid/lite/model_parser/cpp/op_desc.cc
paddle/fluid/lite/model_parser/cpp/op_desc.cc
+1
-0
paddle/fluid/lite/model_parser/cpp/op_desc.h
paddle/fluid/lite/model_parser/cpp/op_desc.h
+11
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+1
-0
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+6
-2
paddle/fluid/lite/operators/activation_ops.cc
paddle/fluid/lite/operators/activation_ops.cc
+15
-0
paddle/fluid/lite/operators/elementwise_ops.cc
paddle/fluid/lite/operators/elementwise_ops.cc
+11
-7
paddle/fluid/lite/operators/fill_constant_op.cc
paddle/fluid/lite/operators/fill_constant_op.cc
+1
-1
paddle/fluid/lite/operators/mean_op.cc
paddle/fluid/lite/operators/mean_op.cc
+2
-2
paddle/fluid/lite/operators/mul_op.cc
paddle/fluid/lite/operators/mul_op.cc
+28
-19
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+2
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+11
-1
paddle/fluid/lite/operators/sgd_op.cc
paddle/fluid/lite/operators/sgd_op.cc
+4
-3
paddle/fluid/lite/operators/sgd_op.h
paddle/fluid/lite/operators/sgd_op.h
+1
-1
paddle/fluid/lite/operators/uniform_random_op.cc
paddle/fluid/lite/operators/uniform_random_op.cc
+45
-0
paddle/fluid/lite/operators/uniform_random_op.h
paddle/fluid/lite/operators/uniform_random_op.h
+50
-0
paddle/fluid/lite/python/lite_test.py
paddle/fluid/lite/python/lite_test.py
+103
-0
paddle/fluid/lite/tools/build.sh
paddle/fluid/lite/tools/build.sh
+24
-0
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+10
-3
paddle/fluid/pybind/executor_lite.cc
paddle/fluid/pybind/executor_lite.cc
+189
-0
paddle/fluid/pybind/executor_lite.h
paddle/fluid/pybind/executor_lite.h
+26
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+5
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+1
-0
python/paddle/fluid/cxx_trainer.py
python/paddle/fluid/cxx_trainer.py
+163
-0
未找到文件。
cmake/version.cmake
浏览文件 @
900b4cdd
...
...
@@ -3,7 +3,8 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set
(
tmp_version
"HEAD"
)
set
(
TAG_VERSION_REGEX
"[0-9]+
\\
.[0-9]+
\\
.[0-9]+(
\\
.(a|b|rc)
\\
.[0-9]+)?"
)
set
(
COMMIT_VERSION_REGEX
"[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+"
)
set
(
LATEST_PADDLE_VERSION
"latest"
)
# set(LATEST_PADDLE_VERSION "latest")
set
(
LATEST_PADDLE_VERSION
"0.0.0"
)
while
(
"
${
PADDLE_VERSION
}
"
STREQUAL
""
)
# Check current branch name
...
...
paddle/fluid/lite/api/cxx_api.cc
浏览文件 @
900b4cdd
...
...
@@ -79,5 +79,16 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const {
return
&
var
->
Get
<
lite
::
Tensor
>
();
}
#ifdef LITE_WITH_X86
void
Predictor
::
FeedVars
(
const
std
::
vector
<
framework
::
Tensor
>
&
tensors
)
{
auto
var
=
scope_
->
FindVar
(
"feed"
);
auto
&
feed_list
=
*
(
var
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
());
feed_list
.
resize
(
tensors
.
size
());
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
feed_list
[
i
].
ShareDataWith
(
tensors
[
i
]);
}
#endif
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
900b4cdd
...
...
@@ -24,6 +24,10 @@
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#ifdef LITE_WITH_X86
#include "paddle/fluid/framework/program_desc.h"
#endif
namespace
paddle
{
namespace
lite
{
...
...
@@ -63,6 +67,15 @@ class Predictor {
// This method is disabled in mobile, for unnecessary dependencies required.
void
SaveModel
(
const
std
::
string
&
dir
);
#ifdef LITE_WITH_X86
void
Run
(
const
std
::
vector
<
framework
::
Tensor
>&
tensors
)
{
FeedVars
(
tensors
);
program_
->
Run
();
}
void
FeedVars
(
const
std
::
vector
<
framework
::
Tensor
>&
tensors
);
#endif
private:
Optimizer
optimizer_
;
framework
::
proto
::
ProgramDesc
program_desc_
;
...
...
@@ -105,6 +118,16 @@ class CXXTrainer {
return
main_program_executor_
;
}
#ifdef LITE_WITH_X86
Predictor
&
BuildMainProgramExecutor
(
framework
::
ProgramDesc
&
desc
)
{
// NOLINT
return
BuildMainProgramExecutor
(
*
desc
.
Proto
());
}
void
RunStartupProgram
(
framework
::
ProgramDesc
&
desc
)
{
// NOLINT
RunStartupProgram
(
*
desc
.
Proto
());
}
#endif
// Run the startup program. It just executes once, no cache needed.
void
RunStartupProgram
(
const
framework
::
proto
::
ProgramDesc
&
desc
,
int
block_id
=
0
)
{
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
900b4cdd
...
...
@@ -20,14 +20,19 @@ endif()
proto_library
(
framework_proto_lite SRCS framework.proto
)
cc_library
(
kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_lite op_params_lite framework_proto_lite
${
tensor_lite
}
)
if
(
LITE_WITH_X86
)
cc_library
(
variable_lite SRCS variable.cc DEPS framework_proto
)
cc_library
(
types_lite SRCS types.cc DEPS framework_proto
)
else
()
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
types_lite SRCS types.cc
)
endif
()
cc_library
(
op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite
)
cc_library
(
scope_lite SRCS scope.cc DEPS
${
tensor_lite
}
)
cc_library
(
cpu_info_lite SRCS cpu_info.cc
)
lite_cc_library
(
context_lite SRCS context.cc DEPS
${
tensor_lite
}
any_lite cpu_info_lite eigen3
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite kernel_lite
cpp_op_desc_lite
${
tensor_lite
}
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
lite_cc_library
(
program_lite SRCS program.cc
...
...
paddle/fluid/lite/core/mir/ssa_graph_test.cc
浏览文件 @
900b4cdd
...
...
@@ -52,4 +52,6 @@ TEST(SSAGraph, test) {
}
// namespace paddle
USE_LITE_OP
(
fc
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
def
);
#ifdef LITE_WITH_X86
// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
#endif
paddle/fluid/lite/core/program.cc
浏览文件 @
900b4cdd
...
...
@@ -64,6 +64,7 @@ void RuntimeProgram::SaveParams(const std::string &dir,
void
Program
::
Build
(
const
framework
::
proto
::
ProgramDesc
&
program
)
{
CHECK
(
ops_
.
empty
())
<<
"Executor duplicate Build found"
;
// Create operators.
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
lite
::
OpDesc
op_desc_dummy
(
proto_op_desc
);
...
...
@@ -98,6 +99,7 @@ void Program::PrepareWorkspace(const framework::proto::ProgramDesc &program) {
}
else
{
if
(
var_desc
.
Name
()
==
"feed"
||
var_desc
.
Name
()
==
"fetch"
)
continue
;
weights_
.
push_back
(
var_desc
.
Name
());
if
(
var_desc
.
Persistable
())
scope_
->
Var
(
var_desc
.
Name
());
}
}
}
...
...
paddle/fluid/lite/core/scope.h
浏览文件 @
900b4cdd
...
...
@@ -27,6 +27,12 @@ namespace lite {
class
Scope
final
{
public:
Scope
()
{}
// delete below two functions to allow pybind to recognise it cannot make a
// copy
// link:
// https://stackoverflow.com/questions/53807248/pybind11-returning-a-pointer-to-a-container-of-unique-ptr
Scope
(
const
Scope
&
)
=
delete
;
Scope
&
operator
=
(
const
Scope
&
)
=
delete
;
~
Scope
();
Scope
&
NewScope
()
const
;
...
...
paddle/fluid/lite/core/variable.h
浏览文件 @
900b4cdd
...
...
@@ -15,12 +15,15 @@
#pragma once
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
using
FeedFetchList
=
std
::
vector
<
lite
::
Tensor
>
;
class
Variable
{
public:
template
<
typename
T
>
...
...
@@ -40,7 +43,9 @@ class Variable {
}
private:
variant
<
int
,
float
,
std
::
string
,
lite
::
Tensor
>
blob_
;
// variant<int, float, std::string, lite::Tensor> blob_;
variant
<
int
,
float
,
std
::
string
,
lite
::
Tensor
,
std
::
vector
<
lite
::
Tensor
>>
blob_
;
};
}
// namespace lite
...
...
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
900b4cdd
...
...
@@ -29,7 +29,7 @@ class FeedCompute
auto
&
param
=
Param
<
operators
::
FeedParam
>
();
VLOG
(
4
)
<<
"feed_list.size: "
<<
param
.
feed_list
->
size
();
VLOG
(
4
)
<<
"col "
<<
param
.
col
;
const
lite
::
Tensor
&
feed_item
=
(
*
param
.
feed_list
)[
0
];
const
lite
::
Tensor
&
feed_item
=
(
*
param
.
feed_list
)[
param
.
col
];
param
.
out
->
ShareDataWith
(
feed_item
);
}
};
...
...
paddle/fluid/lite/kernels/x86/CMakeLists.txt
浏览文件 @
900b4cdd
...
...
@@ -18,6 +18,7 @@ cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
cc_library
(
conv_compute_x86 SRCS conv_compute.cc DEPS
${
lite_kernel_deps
}
blas im2col vol2col
)
cc_library
(
pool_compute_x86 SRCS pool_compute.cc DEPS
${
lite_kernel_deps
}
pooling
)
cc_library
(
batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS
${
lite_kernel_deps
}
)
lite_cc_test
(
test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86
)
lite_cc_test
(
test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86
)
...
...
@@ -47,6 +48,7 @@ set(x86_kernels
conv_compute_x86
pool_compute_x86
batch_norm_compute_x86
)
set
(
x86_kernels
"
${
x86_kernels
}
"
CACHE INTERNAL
"x86 kernels"
)
uniform_random_compute_x86
sgd_compute_x86
CACHE INTERNAL
"x86 kernels"
)
paddle/fluid/lite/kernels/x86/elementwise_compute.cc
浏览文件 @
900b4cdd
...
...
@@ -22,9 +22,19 @@ REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW,
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
elementwise_
sub_gra
d
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
Elementwise
Sub
Compute
<
float
>
,
REGISTER_LITE_KERNEL
(
elementwise_
ad
d
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
Elementwise
Add
Compute
<
float
>
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
#ifdef LITE_WITH_X86
REGISTER_LITE_KERNEL
(
elementwise_sub_grad
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
ElementwiseSubGradCompute
<
float
>
,
def
)
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
paddle
::
framework
::
GradVarName
(
"Out"
),
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
paddle
::
framework
::
GradVarName
(
"X"
),
...
...
@@ -32,11 +42,4 @@ REGISTER_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW,
.
BindOutput
(
paddle
::
framework
::
GradVarName
(
"Y"
),
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
REGISTER_LITE_KERNEL
(
elementwise_add
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
ElementwiseAddCompute
<
float
>
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindInput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
#endif
paddle/fluid/lite/kernels/x86/elementwise_compute.h
浏览文件 @
900b4cdd
...
...
@@ -68,6 +68,7 @@ struct SubGradDY {
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
-
dout
;
}
};
#ifdef LITE_WITH_X86
template
<
typename
T
>
class
ElementwiseSubGradCompute
:
public
KernelLite
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
{
...
...
@@ -79,20 +80,25 @@ class ElementwiseSubGradCompute
CHECK
(
context
.
x86_device_context
());
param
.
X_grad
->
template
mutable_data
<
T
>();
param
.
Y_grad
->
template
mutable_data
<
T
>();
// skip out, x, y
auto
dout
=
param
.
Out_grad
->
raw_tensor
();
auto
dx
=
param
.
X_grad
->
raw_tensor
();
auto
dy
=
param
.
Y_grad
->
raw_tensor
();
framework
::
Tensor
*
dy
=
nullptr
;
if
(
param
.
Y_grad
)
{
param
.
Y_grad
->
template
mutable_data
<
T
>();
dy
=
&
param
.
Y_grad
->
raw_tensor
();
}
auto
&
skip
=
dout
;
paddle
::
operators
::
ElemwiseExplicitGradCompute
<
platform
::
CPUDeviceContext
,
T
,
SubGradDX
<
T
>
,
SubGradDY
<
T
>>
(
*
context
.
x86_execution_context
(),
skip
,
skip
,
skip
,
dout
,
param
.
axis
,
&
dx
,
&
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
&
dx
,
dy
,
SubGradDX
<
T
>
(),
SubGradDY
<
T
>
());
}
virtual
~
ElementwiseSubGradCompute
()
=
default
;
};
#endif
template
<
typename
T
>
class
ElementwiseAddCompute
...
...
paddle/fluid/lite/kernels/x86/sgd_compute.cc
浏览文件 @
900b4cdd
...
...
@@ -49,6 +49,7 @@ class SGDCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
const
T
*
param_data
=
param
->
template
data
<
T
>();
const
T
*
grad_data
=
grad
->
template
data
<
T
>();
int64_t
rows_idx
=
0
;
T
*
out_data
=
param_out
->
template
mutable_data
<
T
>(
context
.
x86_device_context
()
->
GetPlace
());
...
...
paddle/fluid/lite/kernels/x86/uniform_random_compute.cc
0 → 100644
浏览文件 @
900b4cdd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
x86
{
template
<
typename
T
>
class
UniformRandomCompute
:
public
KernelLite
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
{
auto
&
context
=
ctx_
->
As
<
X86Context
>
();
auto
&
param
=
*
param_
.
get_mutable
<
operators
::
UniformRandomParam
>
();
CHECK
(
context
.
x86_device_context
());
auto
*
param_out
=
&
param
.
Out
->
raw_tensor
();
T
*
data
=
param_out
->
mutable_data
<
T
>
(
context
.
x86_device_context
()
->
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
param
.
seed
);
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
}
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
param
.
min
),
static_cast
<
T
>
(
param
.
max
));
int64_t
size
=
param_out
->
numel
();
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
virtual
~
UniformRandomCompute
()
=
default
;
};
}
// namespace x86
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
// float
REGISTER_LITE_KERNEL
(
uniform_random
,
kX86
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
x86
::
UniformRandomCompute
<
float
>
,
def
)
.
BindOutput
(
"Out"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/model_parser/compatible_pb.cc
浏览文件 @
900b4cdd
...
...
@@ -72,6 +72,10 @@ void AttrsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) {
cpp_desc
->
SetAttr
<
std
::
vector
<
std
::
string
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
std
::
string
>>
(
name
));
break
;
case
AttrType
::
LONGS
:
cpp_desc
->
SetAttr
<
std
::
vector
<
int64_t
>>
(
name
,
pb_desc
.
GetAttr
<
std
::
vector
<
int64_t
>>
(
name
));
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported attr type found "
<<
static_cast
<
int
>
(
type
);
}
...
...
paddle/fluid/lite/model_parser/cpp/op_desc.cc
浏览文件 @
900b4cdd
...
...
@@ -34,6 +34,7 @@ SET_ATTR_IMPL(bool, BOOLEAN);
SET_ATTR_IMPL
(
std
::
vector
<
int
>
,
INTS
);
SET_ATTR_IMPL
(
std
::
vector
<
float
>
,
FLOATS
);
SET_ATTR_IMPL
(
std
::
vector
<
std
::
string
>
,
STRINGS
);
SET_ATTR_IMPL
(
std
::
vector
<
int64_t
>
,
LONGS
);
std
::
pair
<
OpDesc
::
attrs_t
::
const_iterator
,
OpDesc
::
attr_types_t
::
const_iterator
>
FindAttr
(
const
cpp
::
OpDesc
&
desc
,
const
std
::
string
&
name
)
{
...
...
paddle/fluid/lite/model_parser/cpp/op_desc.h
浏览文件 @
900b4cdd
...
...
@@ -58,6 +58,12 @@ class OpDesc : public OpDescAPI {
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>*
mutable_outputs
()
{
return
&
outputs_
;
}
bool
HasInput
(
const
std
::
string
&
param
)
const
{
auto
it
=
inputs_
.
find
(
param
);
return
it
!=
inputs_
.
end
();
}
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
inputs_
.
find
(
param
);
CHECK
(
it
!=
inputs_
.
end
());
...
...
@@ -75,6 +81,11 @@ class OpDesc : public OpDescAPI {
return
res
;
}
bool
HasOutput
(
const
std
::
string
&
param
)
const
{
auto
it
=
outputs_
.
find
(
param
);
return
it
!=
outputs_
.
end
();
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
override
{
auto
it
=
outputs_
.
find
(
param
);
CHECK
(
it
!=
outputs_
.
end
());
...
...
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
900b4cdd
...
...
@@ -121,6 +121,7 @@ GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL
(
std
::
vector
<
float
>
,
floats
);
GET_ATTRS_IMPL
(
std
::
vector
<
std
::
string
>
,
strings
);
GET_ATTR_IMPL
(
std
::
string
,
s
);
GET_ATTRS_IMPL
(
std
::
vector
<
int64_t
>
,
longs
);
}
// namespace pb
}
// namespace lite
...
...
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
900b4cdd
...
...
@@ -17,7 +17,9 @@ cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS})
cc_library
(
fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite
${
op_DEPS
}
)
cc_library
(
mean_op_lite SRCS mean_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
fill_constant_op_lite SRCS fill_constant_op.cc DEPS
${
op_DEPS
}
)
#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
cc_library
(
sgd_op_lite SRCS sgd_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
uniform_random_op_lite SRCS uniform_random_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
op_params_lite SRCS op_params.cc DEPS
${
tensor_lite
}
any_lite framework_proto_lite
)
cc_library
(
dropout_op_lite SRCS dropout_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
concat_op_lite SRCS concat_op.cc DEPS
${
op_DEPS
}
)
...
...
@@ -52,7 +54,9 @@ set(ops_lite
transpose_op_lite
fake_quant
fake_dequant
PARENT_SCOPE
)
sgd_op_lite
uniform_random_op_lite
CACHE INTERNAL
"ops lite"
)
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite
...
...
paddle/fluid/lite/operators/activation_ops.cc
浏览文件 @
900b4cdd
...
...
@@ -72,6 +72,21 @@ class ActivationGradOp : public OpLite {
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_grad_name
);
param_
.
X_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
X_grad_name
);
if
(
opdesc
.
HasInput
(
"X"
))
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
}
else
{
param_
.
X
=
param_
.
X_grad
;
}
if
(
opdesc
.
HasInput
(
"Out"
))
{
auto
Out_name
=
opdesc
.
Input
(
"Out"
).
front
();
param_
.
Out
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
}
else
{
param_
.
Out
=
param_
.
Out_grad
;
}
return
true
;
}
...
...
paddle/fluid/lite/operators/elementwise_ops.cc
浏览文件 @
900b4cdd
...
...
@@ -48,31 +48,35 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
bool
ElementwiseGradExplicitOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
X_grad
);
CHECK_OR_FALSE
(
param_
.
Y_grad
);
CHECK_OR_FALSE
(
param_
.
Out_grad
);
return
true
;
}
bool
ElementwiseGradExplicitOp
::
InferShape
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
param_
.
Y_grad
->
Resize
(
param_
.
Y
->
dims
());
if
(
param_
.
Y_grad
)
param_
.
Y_grad
->
Resize
(
param_
.
Y
->
dims
());
return
true
;
}
bool
ElementwiseGradExplicitOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
1UL
);
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
2UL
);
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
auto
X_grad
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
if
(
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
size
()
>
0
)
{
auto
Y_grad
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_grad
);
}
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_grad
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
#endif
}
// namespace operators
...
...
paddle/fluid/lite/operators/fill_constant_op.cc
浏览文件 @
900b4cdd
...
...
@@ -36,7 +36,7 @@ class FillConstantOp : public OpLite {
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
Out
=
GetMutableVar
<
Tensor
>
(
scope
,
Out_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
dtype
=
opdesc
.
GetAttr
<
int
>
(
"dtype"
);
param_
.
shape
=
opdesc
.
GetAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
param_
.
value
=
opdesc
.
GetAttr
<
float
>
(
"value"
);
...
...
paddle/fluid/lite/operators/mean_op.cc
浏览文件 @
900b4cdd
...
...
@@ -51,7 +51,7 @@ class MeanOp : public OpLite {
std
::
string
DebugString
()
const
override
{
return
"mean"
;
}
private:
mutable
operators
::
Elementwise
Param
param_
;
mutable
operators
::
Mean
Param
param_
;
};
#ifdef LITE_WITH_X86
...
...
@@ -73,7 +73,7 @@ class MeanGradOp : public OpLite {
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
3
UL
);
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
2
UL
);
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Out_grad_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
paddle/fluid/lite/operators/mul_op.cc
浏览文件 @
900b4cdd
...
...
@@ -31,16 +31,18 @@ bool MulOpLite::CheckShape() const {
CHECK_GT_OR_FALSE
(
x_dims
.
size
(),
static_cast
<
size_t
>
(
param_
.
x_num_col_dims
));
CHECK_GT_OR_FALSE
(
y_dims
.
size
(),
static_cast
<
size_t
>
(
param_
.
y_num_col_dims
));
// auto x_mat_dims =
// framework::flatten_to_2d(x_dims.data(), param_.x_num_col_dims);
// auto y_mat_dims =
// framework::flatten_to_2d(y_dims.data(), param_.y_num_col_dims);
// PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0],
// "First matrix's width must be equal with second matrix's
// "
// "height. %s, %s",
// x_mat_dims[1], y_mat_dims[0]);
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
auto
x_mat_dims
=
framework
::
flatten_to_2d
(
x_dims
.
data
(),
param_
.
x_num_col_dims
);
auto
y_mat_dims
=
framework
::
flatten_to_2d
(
y_dims
.
data
(),
param_
.
y_num_col_dims
);
PADDLE_ENFORCE_EQ
(
x_mat_dims
[
1
],
y_mat_dims
[
0
],
"First matrix's width must be equal with second matrix's"
"height. %s, %s"
,
x_mat_dims
[
1
],
y_mat_dims
[
0
]);
#endif
return
true
;
}
...
...
@@ -73,30 +75,34 @@ bool MulGradOpLite::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
y
);
CHECK_OR_FALSE
(
param_
.
output_grad
);
CHECK_OR_FALSE
(
param_
.
x_grad
);
CHECK_OR_FALSE
(
param_
.
y_grad
);
return
true
;
}
bool
MulGradOpLite
::
InferShape
()
const
{
param_
.
x_grad
->
Resize
(
param_
.
x
->
dims
());
param_
.
y_grad
->
Resize
(
param_
.
y
->
dims
());
if
(
param_
.
x_grad
)
param_
.
x_grad
->
Resize
(
param_
.
x
->
dims
());
if
(
param_
.
y_grad
)
param_
.
y_grad
->
Resize
(
param_
.
y
->
dims
());
return
true
;
}
bool
MulGradOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
X_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
Y_name
=
op_desc
.
Input
(
"Y"
).
front
();
auto
Out_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
auto
Out_grad_name
=
op_desc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
if
(
op_desc
.
Output
(
framework
::
GradVarName
(
"X"
)).
size
())
{
auto
X_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
param_
.
x_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_grad_name
);
}
if
(
op_desc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
size
())
{
auto
Y_grad_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
param_
.
y_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Y_grad_name
);
}
param_
.
x
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
output_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_grad_name
);
param_
.
x_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_grad_name
);
param_
.
y_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Y_grad_name
);
return
true
;
}
...
...
@@ -107,3 +113,6 @@ bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
}
// namespace paddle
REGISTER_LITE_OP
(
mul
,
paddle
::
lite
::
operators
::
MulOpLite
);
#ifdef LITE_WITH_X86
REGISTER_LITE_OP
(
mul_grad
,
paddle
::
lite
::
operators
::
MulGradOpLite
);
#endif
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
900b4cdd
...
...
@@ -66,6 +66,7 @@ class MulOpLite : public OpLite {
mutable
MulParam
param_
;
};
#ifdef LITE_WITH_X86
class
MulGradOpLite
:
public
OpLite
{
public:
MulGradOpLite
()
{}
...
...
@@ -85,6 +86,7 @@ class MulGradOpLite : public OpLite {
private:
mutable
MulGradParam
param_
;
};
#endif
}
// namespace operators
}
// namespace lite
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
900b4cdd
...
...
@@ -36,7 +36,7 @@ using param_t = Any;
/// ----------------------- Functional operators ------------------------------
struct
FeedParam
{
const
std
::
vector
<
lite
::
Tensor
>*
feed_list
{};
std
::
vector
<
lite
::
Tensor
>*
feed_list
{};
lite
::
Tensor
*
out
{};
int
col
;
};
...
...
@@ -317,6 +317,16 @@ struct SGDParam {
lite
::
Tensor
*
ParamOut
{};
};
/// ----------------------- uniform_random operators ----------------------
struct
UniformRandomParam
{
std
::
vector
<
int64_t
>
shape
{};
float
min
{
-
1.0
f
};
float
max
{
1.0
f
};
int
seed
{
0
};
int
dtype
{
framework
::
proto
::
VarType
::
FP32
};
lite
::
Tensor
*
Out
{};
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/sgd_op.cc
浏览文件 @
900b4cdd
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "
/paddle/
paddle/fluid/lite/operators/sgd_op.h"
#include "paddle/fluid/lite/operators/sgd_op.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
@@ -30,13 +30,14 @@ bool SGDOpLite::CheckShape() const {
bool
SGDOpLite
::
InferShape
()
const
{
auto
lr_dims
=
param_
.
LearningRate
->
dims
().
data
();
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
CHECK_EQ_OR_FALSE
(
framework
::
product
(
lr_dims
),
1
);
#endif
param_
.
ParamOut
->
Resize
(
param_
.
Param
->
dims
());
return
true
;
}
bool
SGDOpLite
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK_EQ
(
opdesc
.
Inputs
().
size
(),
3UL
);
bool
SGDOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
Param_name
=
opdesc
.
Input
(
"Param"
).
front
();
auto
LearningRate_name
=
opdesc
.
Input
(
"LearningRate"
).
front
();
auto
Grad_name
=
opdesc
.
Input
(
"Grad"
).
front
();
...
...
paddle/fluid/lite/operators/sgd_op.h
浏览文件 @
900b4cdd
...
...
@@ -37,7 +37,7 @@ class SGDOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"sgd"
;
}
...
...
paddle/fluid/lite/operators/uniform_random_op.cc
0 → 100644
浏览文件 @
900b4cdd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/uniform_random_op.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
UniformRandomOpLite
::
CheckShape
()
const
{
return
true
;
}
bool
UniformRandomOpLite
::
InferShape
()
const
{
param_
.
Out
->
Resize
(
param_
.
shape
);
return
true
;
}
bool
UniformRandomOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
shape
=
opdesc
.
GetAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
);
param_
.
min
=
opdesc
.
GetAttr
<
float
>
(
"min"
);
param_
.
max
=
opdesc
.
GetAttr
<
float
>
(
"max"
);
param_
.
seed
=
opdesc
.
GetAttr
<
int
>
(
"seed"
);
param_
.
dtype
=
opdesc
.
GetAttr
<
int
>
(
"dtype"
);
param_
.
Out
=
GetMutableVar
<
Tensor
>
(
scope
,
opdesc
.
Output
(
"Out"
).
front
());
return
true
;
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
uniform_random
,
paddle
::
lite
::
operators
::
UniformRandomOpLite
);
paddle/fluid/lite/operators/uniform_random_op.h
0 → 100644
浏览文件 @
900b4cdd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
UniformRandomOpLite
:
public
OpLite
{
public:
UniformRandomOpLite
()
{}
explicit
UniformRandomOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"uniform_random"
;
}
private:
mutable
UniformRandomParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/python/lite_test.py
0 → 100644
浏览文件 @
900b4cdd
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.compiler
as
compiler
import
paddle.fluid.core
as
core
import
paddle.fluid.core.lite
as
lite
import
paddle.fluid.layers
as
layers
import
numpy
as
np
import
unittest
from
paddle.fluid.cxx_trainer
import
add_feed_fetch_op
def
_as_lodtensor
(
data
,
place
):
# single tensor case
tensor
=
core
.
LoDTensor
()
tensor
.
set
(
data
,
place
)
return
tensor
data_label
=
[[
0.753544
,
0.772977
,
0.646915
,
0.747543
,
0.528923
,
0.0517749
,
0.248678
,
0.75932
,
0.960376
,
0.606618
]]
data_a
=
[[
0.874445
,
0.21623
,
0.713262
,
0.702672
,
0.396977
,
0.828285
,
0.932995
,
0.442674
,
0.0321735
,
0.484833
,
0.045935
,
0.21276
,
0.556421
,
0.131825
,
0.285626
,
0.741409
,
0.257467
,
0.975958
,
0.444006
,
0.114553
]]
data_loss
=
[
0.9876687
]
class
NaiveModelTest
(
unittest
.
TestCase
):
def
test_model
(
self
):
start_prog
=
fluid
.
Program
()
main_prog
=
fluid
.
Program
()
start_prog
.
random_seed
=
100
main_prog
.
random_seed
=
100
with
fluid
.
program_guard
(
main_prog
,
start_prog
):
a
=
fluid
.
layers
.
data
(
name
=
"a"
,
shape
=
[
1
,
20
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
10
],
dtype
=
'float32'
)
a1
=
fluid
.
layers
.
fc
(
input
=
a
,
size
=
10
,
act
=
None
,
bias_attr
=
False
)
cost
=
fluid
.
layers
.
square_error_cost
(
a1
,
label
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
optimizer
.
minimize
(
avg_cost
)
x86_place
=
lite
.
Place
(
lite
.
TargetType
.
kX86
,
lite
.
PrecisionType
.
kFloat
,
lite
.
DataLayoutType
.
kNCHW
,
0
)
host_place
=
lite
.
Place
(
lite
.
TargetType
.
kHost
,
lite
.
PrecisionType
.
kFloat
,
lite
.
DataLayoutType
.
kNCHW
,
0
)
scope
=
lite
.
Scope
()
trainer
=
lite
.
CXXTrainer
(
scope
,
x86_place
,
[
x86_place
,
host_place
])
trainer
.
run_startup_program
(
start_prog
.
desc
)
cpu
=
fluid
.
core
.
CPUPlace
()
main_prog
=
add_feed_fetch_op
(
main_prog
,
feed
=
[
'a'
,
'label'
],
fetch_list
=
{
avg_cost
},
scope
=
scope
,
place
=
cpu
)
# print(main_prog)
exe
=
trainer
.
build_main_program_executor
(
main_prog
.
desc
)
feed_data
=
[
_as_lodtensor
(
np
.
array
(
data_a
,
object
),
cpu
),
_as_lodtensor
(
np
.
array
(
data_label
,
object
),
cpu
)
]
exe
.
run
(
feed_data
)
# print(np.array(exe.get_output(0).raw_tensor()))
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
data_loss
),
np
.
array
(
exe
.
get_output
(
0
).
raw_tensor
()),
atol
=
1e-8
),
"lite result not equel to offline result"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
paddle/fluid/lite/tools/build.sh
浏览文件 @
900b4cdd
...
...
@@ -112,6 +112,26 @@ function build_test_server {
test_lite
$TESTS_FILE
}
function
build_test_train
{
mkdir
-p
./build
cd
./build
export
LD_LIBRARY_PATH
=
"
$LD_LIBRARY_PATH
:/paddle/build/third_party/install/mklml/lib"
prepare_workspace
# fake an empty __generated_code__.cc to pass cmake.
cmake ..
-DWITH_LITE
=
ON
-DWITH_GPU
=
OFF
-DWITH_PYTHON
=
ON
-DLITE_WITH_X86
=
ON
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK
=
OFF
-DWITH_TESTING
=
ON
-DWITH_MKL
=
OFF
make test_gen_code_lite
-j
$NUM_CORES_FOR_COMPILE
make test_cxx_api_lite
-j
$NUM_CORES_FOR_COMPILE
ctest
-R
test_cxx_api_lite
ctest
-R
test_gen_code_lite
make test_generated_code
-j
$NUM_CORES_FOR_COMPILE
make
-j
$NUM_CORES_FOR_COMPILE
find
-name
"*.whl"
| xargs pip2
install
python ../paddle/fluid/lite/python/lite_test.py
}
# test_arm_android <some_test_name> <adb_port_number>
function
test_arm_android
{
local
test_name
=
$1
...
...
@@ -543,6 +563,10 @@ function main {
build_test_server
shift
;;
build_test_train
)
build_test_train
shift
;;
build_test_arm
)
build_test_arm
shift
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
900b4cdd
set
(
PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wrapper nccl_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer scope_pool
tracer analysis_predictor imperative_profiler nccl_context
)
message
(
STATUS
"use
${
x86_kernels
}
"
)
message
(
STATUS
"use
${
ops_lite
}
"
)
if
(
WITH_PYTHON
)
cc_library
(
bind_executor_lite SRCS executor_lite.cc DEPS pybind framework_proto
)
set
(
PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wrapper nccl_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer scope_pool bind_executor_lite cxx_api_lite scope_lite
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
mir_passes kernel_lite op_lite optimizer_lite
tracer analysis_predictor imperative_profiler nccl_context
)
endif
(
WITH_PYTHON
)
if
(
WITH_PYTHON
)
list
(
APPEND PYBIND_DEPS py_func_op
)
...
...
paddle/fluid/pybind/executor_lite.cc
0 → 100644
浏览文件 @
900b4cdd
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/executor_lite.h"
#include <pybind11/stl.h>
#include <memory>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/paddle_use_passes.h"
#include "paddle/fluid/lite/core/hvy_tensor.h"
#include "paddle/fluid/lite/core/scope.h"
#include "pybind11/pybind11.h"
namespace
lt
=
paddle
::
lite
;
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
void
BindTensor
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
lt
::
TensorHvy
>
(
*
m
,
"Tensor"
)
.
def
(
pybind11
::
init
<>
())
.
def
(
"raw_tensor"
,
[](
lt
::
TensorHvy
&
self
)
{
return
self
.
raw_tensor
();
})
.
def
(
"share_data_with"
,
[](
lt
::
TensorHvy
&
self
,
const
framework
::
Tensor
&
other
)
{
self
.
ShareDataWith
(
other
);
});
}
void
BindVariable
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
lt
::
Variable
>
(
*
m
,
"Variable"
)
.
def
(
"get_mutable_tensor"
,
[](
lt
::
Variable
&
self
)
{
return
self
.
GetMutable
<
lt
::
Tensor
>
();
})
.
def
(
"get_mutable_fetch_list"
,
[](
lt
::
Variable
&
self
)
->
paddle
::
lite
::
FeedFetchList
*
{
return
self
.
GetMutable
<
paddle
::
lite
::
FeedFetchList
>
();
},
py
::
return_value_policy
::
reference
);
}
void
BindScope
(
pybind11
::
module
*
m
)
{
py
::
class_
<
lt
::
Scope
,
std
::
shared_ptr
<
lt
::
Scope
>>
(
*
m
,
"Scope"
)
.
def
(
pybind11
::
init
<>
())
.
def
(
"new_scope"
,
[](
lt
::
Scope
&
self
)
->
lt
::
Scope
*
{
return
&
self
.
NewScope
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"var"
,
&
lt
::
Scope
::
Var
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"find_var"
,
&
lt
::
Scope
::
FindVar
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"find_local_var"
,
&
lt
::
Scope
::
FindLocalVar
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"parent"
,
&
lt
::
Scope
::
parent
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"local_var_names"
,
&
lt
::
Scope
::
LocalVarNames
,
pybind11
::
return_value_policy
::
reference
);
}
void
BindExecutorLite
(
pybind11
::
module
*
m
)
{
py
::
class_
<
lt
::
Predictor
>
(
*
m
,
"Predictor"
)
.
def
(
pybind11
::
init
<>
())
.
def
(
"__init__"
,
[](
lt
::
Predictor
&
self
,
const
std
::
shared_ptr
<
lt
::
Scope
>&
root_scope
)
{
new
(
&
self
)
lt
::
Predictor
(
root_scope
);
})
.
def
(
"get_input"
,
&
lt
::
Predictor
::
GetInput
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"get_output"
,
&
lt
::
Predictor
::
GetOutput
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"run"
,
[](
lt
::
Predictor
&
self
)
{
self
.
Run
();
})
.
def
(
"run"
,
[](
lt
::
Predictor
&
self
,
const
std
::
vector
<
framework
::
Tensor
>&
tensors
)
{
self
.
Run
(
tensors
);
});
}
void
BindEnums
(
pybind11
::
module
*
m
)
{
py
::
enum_
<
lt
::
TargetType
>
(
*
m
,
"TargetType"
,
py
::
arithmetic
(),
"TargetType enum"
)
.
value
(
"kUnk"
,
lt
::
TargetType
::
kUnk
)
.
value
(
"kHost"
,
lt
::
TargetType
::
kHost
)
.
value
(
"kX86"
,
lt
::
TargetType
::
kX86
)
.
value
(
"kCUDA"
,
lt
::
TargetType
::
kCUDA
)
.
value
(
"kARM"
,
lt
::
TargetType
::
kARM
)
.
value
(
"kAny"
,
lt
::
TargetType
::
kAny
)
.
value
(
"NUM"
,
lt
::
TargetType
::
NUM
);
py
::
enum_
<
lt
::
PrecisionType
>
(
*
m
,
"PrecisionType"
,
py
::
arithmetic
(),
"PrecisionType enum"
)
.
value
(
"kUnk"
,
lt
::
PrecisionType
::
kUnk
)
.
value
(
"kFloat"
,
lt
::
PrecisionType
::
kFloat
)
.
value
(
"kInt8"
,
lt
::
PrecisionType
::
kInt8
)
.
value
(
"kAny"
,
lt
::
PrecisionType
::
kAny
)
.
value
(
"NUM"
,
lt
::
PrecisionType
::
NUM
);
py
::
enum_
<
lt
::
DataLayoutType
>
(
*
m
,
"DataLayoutType"
,
py
::
arithmetic
(),
"DataLayoutType enum"
)
.
value
(
"kUnk"
,
lt
::
DataLayoutType
::
kUnk
)
.
value
(
"kNCHW"
,
lt
::
DataLayoutType
::
kNCHW
)
.
value
(
"kAny"
,
lt
::
DataLayoutType
::
kAny
)
.
value
(
"NUM"
,
lt
::
DataLayoutType
::
NUM
);
}
void
BindPlace
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
lt
::
Place
,
std
::
shared_ptr
<
lt
::
Place
>>
(
*
m
,
"Place"
)
.
def
(
pybind11
::
init
<>
())
.
def
(
"__init__"
,
[](
lt
::
Place
&
self
,
lt
::
TargetType
target
,
lt
::
PrecisionType
precision
,
lt
::
DataLayoutType
layout
,
int16_t
device
)
{
new
(
&
self
)
lt
::
Place
(
target
,
precision
,
layout
,
device
);
})
.
def
(
"is_valid"
,
&
lt
::
Place
::
is_valid
,
pybind11
::
return_value_policy
::
reference
);
}
void
BindCXXTrainer
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
lt
::
CXXTrainer
,
std
::
shared_ptr
<
lt
::
CXXTrainer
>>
(
*
m
,
"CXXTrainer"
)
.
def
(
"__init__"
,
[](
lt
::
CXXTrainer
&
self
,
const
std
::
shared_ptr
<
lt
::
Scope
>&
root_scope
,
const
lt
::
Place
&
preferred_place
,
const
std
::
vector
<
lt
::
Place
>&
valid_places
)
{
new
(
&
self
)
lt
::
CXXTrainer
(
root_scope
,
preferred_place
,
valid_places
);
})
.
def
(
"build_main_program_executor"
,
[](
lt
::
CXXTrainer
&
self
,
framework
::
ProgramDesc
&
desc
)
->
lt
::
Predictor
&
{
return
self
.
BuildMainProgramExecutor
(
desc
);
},
pybind11
::
return_value_policy
::
reference
)
.
def
(
"run_startup_program"
,
[](
lt
::
CXXTrainer
&
self
,
framework
::
ProgramDesc
&
desc
)
{
return
self
.
RunStartupProgram
(
desc
);
});
}
void
BindLite
(
pybind11
::
module
*
m
)
{
BindTensor
(
m
);
BindVariable
(
m
);
BindScope
(
m
);
BindExecutorLite
(
m
);
BindEnums
(
m
);
BindPlace
(
m
);
BindCXXTrainer
(
m
);
}
}
// namespace pybind
}
// namespace paddle
// USE_LITE_OP(mul);
USE_LITE_OP
(
elementwise_sub
);
USE_LITE_OP
(
uniform_random
);
USE_LITE_OP
(
feed
);
USE_LITE_OP
(
fetch
);
USE_LITE_OP
(
fill_constant
);
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
mul_grad
);
USE_LITE_OP
(
mean
);
USE_LITE_OP
(
square
);
USE_LITE_OP
(
sgd
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL
(
uniform_random
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
fill_constant
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
mul
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
mul_grad
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
square
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
mean
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
sgd
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_sub
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_sub_grad
,
kX86
,
kFloat
,
kNCHW
,
def
);
#endif
paddle/fluid/pybind/executor_lite.h
0 → 100644
浏览文件 @
900b4cdd
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <Python.h>
#include "pybind11/pybind11.h"
namespace
paddle
{
namespace
pybind
{
void
BindLite
(
pybind11
::
module
*
m
);
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
900b4cdd
...
...
@@ -54,6 +54,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/executor_lite.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
...
...
@@ -366,6 +367,7 @@ PYBIND11_MODULE(core, m) {
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
int8_t
>
)
#endif
.
def
(
"shape"
,
[](
Tensor
&
self
)
{
return
vectorize
(
self
.
dims
());
})
.
def
(
"memory_size"
,
[](
Tensor
&
self
)
{
return
self
.
memory_size
();
})
.
def
(
"_set_float_element"
,
TensorSetElement
<
float
>
)
.
def
(
"_get_float_element"
,
TensorGetElement
<
float
>
)
.
def
(
"_set_double_element"
,
TensorSetElement
<
double
>
)
...
...
@@ -1528,6 +1530,9 @@ All parameter, weight, gradient are variables in Paddle.
BindNode
(
&
m
);
BindInferenceApi
(
&
m
);
BindDataset
(
&
m
);
py
::
module
lite
=
m
.
def_submodule
(
"lite"
,
"submodule lite"
);
BindLite
(
&
lite
);
}
}
// namespace pybind
}
// namespace paddle
python/paddle/fluid/__init__.py
浏览文件 @
900b4cdd
...
...
@@ -65,6 +65,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable
from
.
import
install_check
from
.dygraph.nn
import
*
from
.dygraph.layers
import
*
from
.cxx_trainer
import
*
Tensor
=
LoDTensor
...
...
python/paddle/fluid/backward.py
浏览文件 @
900b4cdd
...
...
@@ -71,6 +71,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
op_desc
.
set_block_attr
(
name
,
val
.
desc
)
else
:
op_desc
.
_set_attr
(
name
,
val
)
op_desc
.
check_attrs
()
return
op_desc
...
...
python/paddle/fluid/cxx_trainer.py
0 → 100644
浏览文件 @
900b4cdd
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
.
import
core
from
.
import
framework
from
.
import
executor
from
.
import
compiler
import
sys
from
.framework
import
default_main_program
,
Variable
__all__
=
[
'add_feed_fetch_op'
]
def
_has_feed_operators
(
block
,
feed_targets
,
feed_holder_name
):
""" Check whether the block already has feed operators.
Return false if the block does not have any feed operators.
If some feed operators have been prepended to the block, check that
the info contained in these feed operators matches the feed_targets
and feed_holder_name. Raise exception when any mismatch is found.
Return true when the block has feed operators with matching info.
Args:
block: a block instance (typically global block of a program)
feed_targets: a dictionary of {feed_target_name: feed_target_data}
feed_holder_name: the name of the variable that holds the data of
all feed targets. The type of this feed_holder variable is
FEED_MINIBATCH, which is essentially vector<LoDTensor>.
Returns:
A boolean value that indicates whether a block has feed operators
that match the info contained in feed_targets and feed_holder_name.
"""
feed_count
=
0
for
op
in
block
.
ops
:
if
op
.
desc
.
type
()
==
'feed'
:
feed_count
+=
1
assert
op
.
desc
.
input
(
'X'
)[
0
]
==
feed_holder_name
feed_target_name
=
op
.
desc
.
output
(
'Out'
)[
0
]
if
feed_target_name
not
in
feed_targets
:
raise
Exception
(
"'feed_targets' does not have {} variable"
.
format
(
feed_target_name
))
else
:
break
if
feed_count
>
0
and
feed_count
!=
len
(
feed_targets
):
raise
Exception
(
"Feed operators in program desc do not match 'feed_targets'"
)
return
feed_count
>
0
def
_has_fetch_operators
(
block
,
fetch_targets
,
fetch_holder_name
):
""" Check whether the block already has fetch operators.
Return false if the block does not have any fetch operators.
If some fetch operators have been appended to the block, check that
the info contained in these fetch operators matches the fetch_targets
and fetch_holder_name. Raise exception when any mismatch is found.
Return true when the block has fetch operators with matching info.
Args:
block: a block instance (typically global block of a program)
fetch_targets: a dictionary of {fetch_target_name: fetch_target_data}
fetch_holder_name: the name of the variable that holds the data of
all fetch targets. The type of this fetch_holder variable is
FETCH_LIST, which is essentially vector<LoDTensor>.
Return:
A boolean value that indicates whether a block has fetch operators
that match the info contained in fetch_targets and fetch_holder_name.
"""
fetch_count
=
0
for
op
in
block
.
ops
:
if
op
.
desc
.
type
()
==
'fetch'
:
fetch_count
+=
1
assert
op
.
desc
.
output
(
'Out'
)[
0
]
==
fetch_holder_name
fetch_target_name
=
op
.
desc
.
input
(
'X'
)[
0
]
if
fetch_target_name
not
in
[
var
.
desc
.
name
()
for
var
in
fetch_targets
]:
raise
Exception
(
"'fetch_targets' does not have {} variable"
.
format
(
fetch_target_name
))
idx
=
op
.
desc
.
attr
(
'col'
)
assert
fetch_target_name
==
fetch_targets
[
idx
].
desc
.
name
()
if
fetch_count
>
0
and
fetch_count
!=
len
(
fetch_targets
):
raise
Exception
(
"Fetch operators in program desc do not match 'fetch_targets'"
)
return
fetch_count
>
0
def
_add_feed_fetch_ops
(
program
,
feed
,
fetch_list
,
feed_var_name
=
'feed'
,
fetch_var_name
=
'fetch'
):
tmp_program
=
program
.
clone
()
global_block
=
tmp_program
.
global_block
()
if
feed_var_name
in
global_block
.
vars
:
feed_var
=
global_block
.
var
(
feed_var_name
)
else
:
feed_var
=
global_block
.
create_var
(
name
=
feed_var_name
,
type
=
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
,
persistable
=
True
)
if
fetch_var_name
in
global_block
.
vars
:
fetch_var
=
global_block
.
var
(
fetch_var_name
)
else
:
fetch_var
=
global_block
.
create_var
(
name
=
fetch_var_name
,
type
=
core
.
VarDesc
.
VarType
.
FETCH_LIST
,
persistable
=
True
)
# prepend feed operators
if
not
_has_feed_operators
(
global_block
,
feed
,
feed_var_name
):
for
i
,
name
in
enumerate
(
feed
):
out
=
global_block
.
var
(
name
)
global_block
.
_prepend_op
(
type
=
'feed'
,
inputs
=
{
'X'
:
[
feed_var
]},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
{
'col'
:
i
})
# append fetch_operators
if
not
_has_fetch_operators
(
global_block
,
fetch_list
,
fetch_var_name
):
for
i
,
var
in
enumerate
(
fetch_list
):
assert
isinstance
(
var
,
Variable
)
or
isinstance
(
var
,
six
.
string_types
),
(
"Wrong type for fetch_list[%s]: %s"
%
(
i
,
type
(
var
)))
global_block
.
append_op
(
type
=
'fetch'
,
inputs
=
{
'X'
:
[
var
]},
outputs
=
{
'Out'
:
[
fetch_var
]},
attrs
=
{
'col'
:
i
})
return
tmp_program
def
add_feed_fetch_op
(
program
,
feed
,
fetch_list
,
scope
,
place
):
if
program
is
None
:
program
=
default_main_program
()
program
=
_add_feed_fetch_ops
(
program
=
program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
return
program
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录