Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1fb93746
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1fb93746
编写于
4月 23, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
correct the running logic of host model
- make the target wrapper for host works - code clean
上级
ca629eb4
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
193 addition
and
54 deletion
+193
-54
paddle/fluid/lite/CMakeLists.txt
paddle/fluid/lite/CMakeLists.txt
+1
-0
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+1
-1
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+1
-1
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+15
-1
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+7
-0
paddle/fluid/lite/core/memory.h
paddle/fluid/lite/core/memory.h
+15
-17
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+0
-3
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-0
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+6
-0
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+41
-3
paddle/fluid/lite/core/tensor.h
paddle/fluid/lite/core/tensor.h
+3
-1
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+33
-3
paddle/fluid/lite/cuda/target_wrapper.cc
paddle/fluid/lite/cuda/target_wrapper.cc
+4
-2
paddle/fluid/lite/host/CMakeLists.txt
paddle/fluid/lite/host/CMakeLists.txt
+1
-0
paddle/fluid/lite/host/target_wrapper.cc
paddle/fluid/lite/host/target_wrapper.cc
+33
-0
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+3
-1
paddle/fluid/lite/kernels/host/fetch_compute.cc
paddle/fluid/lite/kernels/host/fetch_compute.cc
+1
-1
paddle/fluid/lite/kernels/host/mul_compute.cc
paddle/fluid/lite/kernels/host/mul_compute.cc
+15
-14
paddle/fluid/lite/kernels/host/scale_compute.cc
paddle/fluid/lite/kernels/host/scale_compute.cc
+4
-4
paddle/fluid/lite/model_parser/model_parser.cc
paddle/fluid/lite/model_parser/model_parser.cc
+1
-1
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+0
-1
paddle/fluid/lite/utils/macros.h
paddle/fluid/lite/utils/macros.h
+7
-0
未找到文件。
paddle/fluid/lite/CMakeLists.txt
浏览文件 @
1fb93746
add_subdirectory
(
core
)
add_subdirectory
(
x86
)
add_subdirectory
(
host
)
add_subdirectory
(
cuda
)
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
...
...
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
1fb93746
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite
target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
1fb93746
...
...
@@ -58,7 +58,7 @@ class Predictor {
const
Tensor
*
GetOutput
(
size_t
offset
)
{
auto
*
_fetch_list
=
program_
->
exec_scope
()
->
FindVar
(
"fetch"
);
CHECK
(
_fetch_list
)
<<
"no fatch variable in exec_scope"
;
auto
fetch_list
=
_fetch_list
->
Get
<
std
::
vector
<
Tensor
>>
();
auto
&
fetch_list
=
*
_fetch_list
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
CHECK_LT
(
offset
,
fetch_list
.
size
())
<<
"offset "
<<
offset
<<
" overflow"
;
return
&
fetch_list
.
at
(
offset
);
}
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
1fb93746
...
...
@@ -28,8 +28,22 @@ TEST(CXXApi, test) {
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
input_tensor
->
Resize
({
100
,
100
});
input_tensor
->
mutable_data
<
float
>
();
auto
*
data
=
input_tensor
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
100
*
100
;
i
++
)
{
data
[
i
]
=
i
;
}
LOG
(
INFO
)
<<
"input "
<<
input_tensor
;
LOG
(
INFO
)
<<
"input "
<<
*
input_tensor
;
predictor
.
Run
();
auto
*
out
=
predictor
.
GetOutput
(
0
);
LOG
(
INFO
)
<<
out
<<
" memory size "
<<
out
->
memory_size
();
LOG
(
INFO
)
<<
"out "
<<
out
->
data
<
float
>
()[
0
];
LOG
(
INFO
)
<<
"out "
<<
out
->
data
<
float
>
()[
1
];
LOG
(
INFO
)
<<
"dims "
<<
out
->
dims
();
LOG
(
INFO
)
<<
"out "
<<
*
out
;
}
}
// namespace lite
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
1fb93746
...
...
@@ -65,6 +65,13 @@ class KernelBase {
virtual
~
KernelBase
()
=
default
;
std
::
string
DebugString
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
":"
<<
TargetToStr
(
target
())
<<
"/"
<<
PrecisionToStr
(
precision
())
<<
"/"
<<
DataLayoutToStr
(
layout
());
return
ss
.
str
();
}
protected:
std
::
unique_ptr
<
KernelContext
>
context_
;
mutable
operators
::
param_t
param_
;
...
...
paddle/fluid/lite/core/memory.h
浏览文件 @
1fb93746
...
...
@@ -21,18 +21,16 @@ namespace lite {
static
void
*
TargetMalloc
(
TargetType
target
,
size_t
size
)
{
void
*
data
{
nullptr
};
switch
(
static_cast
<
int
>
(
target
))
{
case
static_cast
<
int
>
(
TargetType
::
kX86
):
data
=
TargetWrapper
<
TARGET
(
kX86
)
>::
Malloc
(
size
);
switch
(
target
)
{
case
TargetType
::
kHost
:
case
TargetType
::
kX86
:
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
break
;
case
static_cast
<
int
>
(
TargetType
::
kCUDA
)
:
case
TargetType
::
kCUDA
:
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
break
;
case
static_cast
<
int
>
(
TargetType
::
kHost
):
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
break
;
default:
LOG
(
FATAL
)
<<
"Unknown
type"
;
LOG
(
FATAL
)
<<
"Unknown
supported target "
<<
TargetToStr
(
target
)
;
}
return
data
;
}
...
...
@@ -52,17 +50,19 @@ static void TargetFree(TargetType target, void* data) {
static
void
TargetCopy
(
TargetType
target
,
void
*
dst
,
const
void
*
src
,
size_t
size
)
{
switch
(
static_cast
<
int
>
(
target
)
)
{
case
static_cast
<
int
>
(
TargetType
::
kX86
)
:
case
static_cast
<
int
>
(
TargetType
::
kHost
)
:
switch
(
target
)
{
case
TargetType
::
kX86
:
case
TargetType
::
kHost
:
TargetWrapper
<
TARGET
(
kHost
)
>::
MemcpySync
(
dst
,
src
,
size
,
IoDirection
::
DtoD
);
break
;
case
static_cast
<
int
>
(
TargetType
::
kCUDA
)
:
case
TargetType
::
kCUDA
:
TargetWrapper
<
TARGET
(
kCUDA
)
>::
MemcpySync
(
dst
,
src
,
size
,
IoDirection
::
DtoD
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
...
...
@@ -79,12 +79,10 @@ class Buffer {
void
ResetLazy
(
TargetType
target
,
size_t
size
)
{
if
(
target
!=
target_
||
space_
<
size
)
{
Free
();
data_
=
TargetMalloc
(
target
,
size
);
target_
=
target
;
space_
=
size
;
}
if
(
size
<
space_
)
return
;
target_
=
target
;
data_
=
TargetMalloc
(
target
,
size
);
space_
=
size
;
}
void
ResizeLazy
(
size_t
size
)
{
ResetLazy
(
target_
,
size
);
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
1fb93746
...
...
@@ -60,9 +60,6 @@ class SSAGraph : GraphBase {
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
for
(
auto
&
kernel
:
kernels
)
{
op
->
AttachKernel
(
kernel
.
get
());
}
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
,
op
->
op_info
());
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
1fb93746
...
...
@@ -29,6 +29,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
(
kernel_type
.
empty
()
?
op_type_
:
kernel_type
),
place
.
target
,
place
.
precision
);
for
(
auto
&&
it
:
ks
)
{
AttachKernel
(
it
.
get
());
kernels
.
emplace_back
(
std
::
move
(
it
));
}
}
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
1fb93746
...
...
@@ -105,6 +105,11 @@ struct Instruction {
void
Run
()
{
CHECK
(
op_
);
CHECK
(
kernel_
);
LOG
(
INFO
)
<<
"running kernel> "
<<
kernel_
->
DebugString
();
if
(
UNLIKELY
(
first_epoch_
))
{
first_epoch_
=
false
;
op_
->
CheckShape
();
}
op_
->
InferShape
();
kernel_
->
Run
();
}
...
...
@@ -112,6 +117,7 @@ struct Instruction {
private:
std
::
shared_ptr
<
OpLite
>
op_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
bool
first_epoch_
{
true
};
};
/*
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
1fb93746
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <iostream>
#include <sstream>
...
...
@@ -138,11 +139,48 @@ class TargetWrapper {
static
void
StreamSync
(
const
stream_t
&
stream
)
{}
static
void
*
Malloc
(
size_t
size
)
{
return
new
char
[
size
];
}
static
void
Free
(
void
*
ptr
)
{
delete
[]
static_cast
<
char
*>
(
ptr
);
}
static
void
*
Malloc
(
size_t
size
)
{
LOG
(
FATAL
)
<<
"Unimplemented malloc for "
<<
TargetToStr
(
Target
);
return
nullptr
;
}
static
void
Free
(
void
*
ptr
)
{
LOG
(
FATAL
)
<<
"Unimplemented"
;
}
static
void
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
)
{
LOG
(
FATAL
)
<<
"Unimplemented"
;
}
static
void
MemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
,
const
stream_t
&
stream
)
{
MemcpySync
(
dst
,
src
,
size
,
dir
);
}
};
// This interface should be specified by each kind of target.
template
<
>
class
TargetWrapper
<
TARGET
(
kHost
)
>
{
public:
using
stream_t
=
int
;
using
event_t
=
int
;
static
size_t
num_devices
()
{
return
0
;
}
static
size_t
maximum_stream
()
{
return
0
;
}
static
void
CreateStream
(
stream_t
*
stream
)
{}
static
void
DestroyStream
(
const
stream_t
&
stream
)
{}
static
void
CreateEvent
(
event_t
*
event
)
{}
static
void
DestroyEvent
(
const
event_t
&
event
)
{}
static
void
RecordEvent
(
const
event_t
&
event
)
{}
static
void
SyncEvent
(
const
event_t
&
event
)
{}
static
void
StreamSync
(
const
stream_t
&
stream
)
{}
static
void
*
Malloc
(
size_t
size
);
static
void
Free
(
void
*
ptr
);
static
void
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
)
{}
IoDirection
dir
)
;
static
void
MemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
,
const
stream_t
&
stream
)
{
MemcpySync
(
dst
,
src
,
size
,
dir
);
...
...
paddle/fluid/lite/core/tensor.h
浏览文件 @
1fb93746
...
...
@@ -95,13 +95,15 @@ class Tensor {
dims_
=
other
.
dims_
;
target_
=
other
.
target_
;
lod_
=
other
.
lod_
;
memory_size_
=
other
.
memory_size_
;
}
void
CopyDataFrom
(
const
Tensor
&
other
)
{
dims_
=
other
.
dims_
;
target_
=
other
.
target_
;
lod_
=
other
.
lod_
;
*
buffer_
=
*
other
.
buffer_
;
memory_size_
=
other
.
memory_size_
;
buffer_
->
CopyDataFrom
(
*
other
.
buffer_
,
memory_size_
);
}
TargetType
target
()
const
{
return
target_
;
}
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
1fb93746
...
...
@@ -38,9 +38,39 @@ namespace lite {
// The DNN system is simple, and the architecture can not process that many data
// types as a compiler, or that will turn out to a chaos.
//
// We should make sure that supported data types should be registered here, and
// keep the quantity small. And avoid using some special data types as op's IO,
// such as some runtime cache, that need to be avoided.
// We should make sure that the supported data types be registered here, and
// keep the quantity small and avoid using some special data types as op's
// inputs or outputs, such as some runtime cache, those types can't be processed
// by the MIR.
//
// A tensor with different places(target, precision, data layout or device)
// should be treated as different types. Different types might be compatible
// with each other, for example, the `VoidTy` means any type, so any other types
// can be treated as a `VoidTy`.
//
// The Different Types can transform to others by adding some special
// transforming operators, for example, a DataLayoutTransformOp can convert a
// `TensorFp32NCHWTy` to a `TensorFp32NHWCTy`; a IoCopyOp can convert a
// `TensorFp32NCHWTy(kHost)` to `TensorFp32NCHWTy(kCUDA)`. There are many other
// convertions between different Types, but there are some unsupportted type
// convertions, for example, there is noway to convert a `UnsupportedTy` to a
// `TensorAnyTy`.
//
// We use Types to declare the definition of a kernel, each inputs' and outputs'
// arguments have a specific Types.
//
// REGISTER_LITE_KERNEL(mul, kHost, kFloat,
// paddle::lite::kernels::host::MulCompute, def)
// .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
// TARGET(kHost))})
// .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
// TARGET(kHost))})
// .BindOutput("Out",
// {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(TARGET(kHost))})
// .Finalize();
//
// The above definition will be used in MIR by Type inference and uncompatible
// types check.
//
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
// type mixed in the system.
...
...
paddle/fluid/lite/cuda/target_wrapper.cc
浏览文件 @
1fb93746
...
...
@@ -26,12 +26,14 @@ using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
template
<
>
void
*
TargetW
::
Malloc
(
size_t
size
)
{
return
new
char
[
size
];
void
*
ptr
{};
CHECK_EQ
(
cudaSuccess
,
cudaMalloc
(
&
ptr
,
size
));
return
ptr
;
}
template
<
>
void
TargetW
::
Free
(
void
*
ptr
)
{
delete
[]
static_cast
<
char
*>
(
ptr
);
CHECK_EQ
(
cudaSuccess
,
cudaFree
(
ptr
)
);
}
template
<
>
...
...
paddle/fluid/lite/host/CMakeLists.txt
0 → 100644
浏览文件 @
1fb93746
cc_library
(
target_wrapper_host SRCS target_wrapper.cc
)
paddle/fluid/lite/host/target_wrapper.cc
0 → 100644
浏览文件 @
1fb93746
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/target_wrapper.h"
#include <cstring>
namespace
paddle
{
namespace
lite
{
void
*
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size_t
size
)
{
return
new
char
[
size
];
}
void
TargetWrapper
<
TARGET
(
kHost
)
>::
Free
(
void
*
ptr
)
{
delete
[]
static_cast
<
char
*>
(
ptr
);
}
void
TargetWrapper
<
TARGET
(
kHost
)
>::
MemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
IoDirection
dir
)
{
memcpy
(
dst
,
src
,
size
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
1fb93746
...
...
@@ -28,6 +28,8 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
auto
&
param
=
Param
<
operators
::
FeedParam
>
();
const
Tensor
&
feed_item
=
param
.
feed_list
->
at
(
param
.
col
);
param
.
out
->
CopyDataFrom
(
feed_item
);
LOG
(
INFO
)
<<
"FEED input "
<<
feed_item
<<
" col "
<<
param
.
col
;
LOG
(
INFO
)
<<
"FEED output "
<<
*
param
.
out
;
}
};
...
...
@@ -40,6 +42,6 @@ REGISTER_LITE_KERNEL(feed, kHost, kFloat,
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
Tensor
Fp32NCHW
Ty
>
(
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
Tensor
Any
Ty
>
(
TARGET
(
kHost
))})
.
Finalize
();
paddle/fluid/lite/kernels/host/fetch_compute.cc
浏览文件 @
1fb93746
...
...
@@ -32,7 +32,7 @@ class FetchCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
auto
&
dst
=
fetch_list
->
at
(
param
.
col
);
dst
.
CopyDataFrom
(
*
param
.
input
);
dst
.
ShareDataWith
(
*
param
.
input
);
}
};
...
...
paddle/fluid/lite/kernels/host/mul_compute.cc
浏览文件 @
1fb93746
...
...
@@ -40,22 +40,23 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
MulParam
;
void
Run
()
override
{
auto
&
theparam
=
Param
<
operators
::
MulParam
>
();
core
::
dim2
x_shape
(
{
product
(
theparam
.
x
->
dims
().
begin
(),
theparam
.
x
->
dims
().
begin
()
+
theparam
.
x_num_col_dims
),
product
(
theparam
.
x
->
dims
().
begin
()
+
theparam
.
x_num_col_dims
,
theparam
.
x
->
dims
().
end
())});
auto
&
param
=
Param
<
operators
::
MulParam
>
();
core
::
dim2
x_shape
({
product
(
param
.
x
->
dims
().
begin
(),
param
.
x
->
dims
().
begin
()
+
param
.
x_num_col_dims
),
product
(
param
.
x
->
dims
().
begin
()
+
param
.
x_num_col_dims
,
param
.
x
->
dims
().
end
())});
core
::
dim2
y_shape
(
{
product
(
theparam
.
y
->
dims
().
begin
(),
theparam
.
y
->
dims
().
begin
()
+
theparam
.
x_num_col_dims
),
product
(
theparam
.
y
->
dims
().
begin
()
+
theparam
.
x_num_col_dims
,
theparam
.
y
->
dims
().
end
())});
core
::
dim2
y_shape
({
product
(
param
.
y
->
dims
().
begin
(),
param
.
y
->
dims
().
begin
()
+
param
.
x_num_col_dims
),
product
(
param
.
y
->
dims
().
begin
()
+
param
.
x_num_col_dims
,
param
.
y
->
dims
().
end
())});
mul_compute_eigen
(
theparam
.
x
->
data
<
float
>
(),
x_shape
.
x
,
x_shape
.
y
,
//
theparam
.
y
->
data
<
float
>
(),
y_shape
.
x
,
y_shape
.
y
,
//
theparam
.
output
->
mutable_data
<
float
>
());
mul_compute_eigen
(
param
.
x
->
data
<
float
>
(),
x_shape
.
x
,
x_shape
.
y
,
//
param
.
y
->
data
<
float
>
(),
y_shape
.
x
,
y_shape
.
y
,
//
param
.
output
->
mutable_data
<
float
>
());
LOG
(
INFO
)
<<
"MUL x "
<<
*
param
.
x
;
LOG
(
INFO
)
<<
"MUL W "
<<
*
param
.
y
;
LOG
(
INFO
)
<<
"MUL out "
<<
*
param
.
output
;
}
virtual
~
MulCompute
()
=
default
;
...
...
paddle/fluid/lite/kernels/host/scale_compute.cc
浏览文件 @
1fb93746
...
...
@@ -36,10 +36,10 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
MulParam
;
void
Run
()
override
{
auto
&
the
param
=
Param
<
operators
::
ScaleParam
>
();
scale_compute
(
theparam
.
x
->
data
<
float
>
(),
theparam
.
x
->
mutable_data
<
float
>
(),
product
(
theparam
.
x
->
dims
()),
theparam
.
scale
,
the
param
.
bias
,
the
param
.
bias_after_scale
);
auto
&
param
=
Param
<
operators
::
ScaleParam
>
();
scale_compute
(
param
.
x
->
data
<
float
>
(),
param
.
output
->
mutable_data
<
float
>
(),
product
(
param
.
x
->
dims
()),
param
.
scale
,
param
.
bias
,
param
.
bias_after_scale
);
}
virtual
~
ScaleCompute
()
=
default
;
...
...
paddle/fluid/lite/model_parser/model_parser.cc
浏览文件 @
1fb93746
...
...
@@ -77,7 +77,7 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
DO
(
INT64
,
int64_t
);
#undef DO
default:
LOG
(
FATAL
)
<<
"unknown type
"
;
LOG
(
FATAL
)
<<
"unknown type
"
<<
desc
.
data_type
()
;
}
is
.
read
(
static_cast
<
char
*>
(
buf
),
size
);
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
1fb93746
...
...
@@ -51,7 +51,6 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
kernel_
->
SetParam
(
param_
);
return
true
;
}
...
...
paddle/fluid/lite/utils/macros.h
浏览文件 @
1fb93746
...
...
@@ -21,3 +21,10 @@
#endif
#define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
#ifndef LIKELY
#define LIKELY(x) __builtin_expect(!!(x), 1)
#endif
#ifndef UNLIKELY
#define UNLIKELY(x) __built_expect(!!(x), 0)
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录