Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6f5e64af
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6f5e64af
编写于
12月 27, 2017
作者:
Y
Yang Yu
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/check_nan_executor' into feature/rnn_gradient_check
上级
87288850
5162c41a
变更
46
展开全部
隐藏空白更改
内联
并排
Showing
46 changed file
with
882 addition
and
435 deletion
+882
-435
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+5
-2
paddle/framework/data_transform.cc
paddle/framework/data_transform.cc
+26
-0
paddle/framework/data_transform.h
paddle/framework/data_transform.h
+109
-0
paddle/framework/data_transform_test.cc
paddle/framework/data_transform_test.cc
+78
-0
paddle/framework/executor.cc
paddle/framework/executor.cc
+27
-6
paddle/framework/init.cc
paddle/framework/init.cc
+1
-1
paddle/framework/operator.cc
paddle/framework/operator.cc
+35
-3
paddle/framework/tensor_util.h
paddle/framework/tensor_util.h
+98
-0
paddle/framework/tensor_util_test.cc
paddle/framework/tensor_util_test.cc
+24
-0
paddle/gserver/layers/MKLDNNLRNLayer.cpp
paddle/gserver/layers/MKLDNNLRNLayer.cpp
+1
-1
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+59
-59
paddle/operators/activation_op.h
paddle/operators/activation_op.h
+208
-179
paddle/operators/array_operator.h
paddle/operators/array_operator.h
+2
-2
paddle/operators/array_to_lod_tensor_op.cc
paddle/operators/array_to_lod_tensor_op.cc
+3
-2
paddle/operators/assign_op.cc
paddle/operators/assign_op.cc
+2
-2
paddle/operators/beam_search_decode_op.cc
paddle/operators/beam_search_decode_op.cc
+2
-2
paddle/operators/cond_op.cc
paddle/operators/cond_op.cc
+2
-2
paddle/operators/feed_op.cc
paddle/operators/feed_op.cc
+2
-2
paddle/operators/fetch_op.cc
paddle/operators/fetch_op.cc
+2
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+2
-2
paddle/operators/fill_op.cc
paddle/operators/fill_op.cc
+3
-2
paddle/operators/load_op.cc
paddle/operators/load_op.cc
+2
-2
paddle/operators/lod_tensor_to_array_op.cc
paddle/operators/lod_tensor_to_array_op.cc
+3
-2
paddle/operators/merge_lod_tensor_op.cc
paddle/operators/merge_lod_tensor_op.cc
+2
-2
paddle/operators/nccl_op_test.cu.cc
paddle/operators/nccl_op_test.cu.cc
+1
-1
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+5
-4
paddle/operators/reorder_lod_tensor_by_rank_op.cc
paddle/operators/reorder_lod_tensor_by_rank_op.cc
+2
-2
paddle/operators/save_op.cc
paddle/operators/save_op.cc
+2
-2
paddle/operators/shrink_rnn_memory_op.cc
paddle/operators/shrink_rnn_memory_op.cc
+2
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+11
-11
paddle/operators/softmax_op.h
paddle/operators/softmax_op.h
+6
-6
paddle/operators/split_lod_tensor_op.cc
paddle/operators/split_lod_tensor_op.cc
+2
-2
paddle/operators/tensor_array_read_write_op.cc
paddle/operators/tensor_array_read_write_op.cc
+6
-4
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+1
-19
paddle/platform/device_context.h
paddle/platform/device_context.h
+22
-8
paddle/platform/device_context_test.cu
paddle/platform/device_context_test.cu
+7
-22
paddle/platform/nccl_test.cu
paddle/platform/nccl_test.cu
+1
-1
paddle/platform/place.h
paddle/platform/place.h
+27
-1
paddle/pybind/tensor_py.h
paddle/pybind/tensor_py.h
+5
-4
python/paddle/v2/fluid/__init__.py
python/paddle/v2/fluid/__init__.py
+1
-1
python/paddle/v2/fluid/io.py
python/paddle/v2/fluid/io.py
+14
-2
python/paddle/v2/fluid/layer_helper.py
python/paddle/v2/fluid/layer_helper.py
+1
-1
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+10
-9
python/paddle/v2/fluid/tests/test_activation_op.py
python/paddle/v2/fluid/tests/test_activation_op.py
+54
-54
python/paddle/v2/fluid/tests/test_net.py
python/paddle/v2/fluid/tests/test_net.py
+2
-2
python/paddle/v2/fluid/tests/test_softmax_op.py
python/paddle/v2/fluid/tests/test_softmax_op.py
+2
-2
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
6f5e64af
...
...
@@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library
(
scope SRCS scope.cc DEPS glog
)
cc_test
(
scope_test SRCS scope_test.cc DEPS scope
)
cc_library
(
data_transform SRCS data_transform.cc DEPS tensor framework_proto
)
cc_test
(
data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_test
(
program_desc_test SRCS program_desc_test.cc DEPS proto_desc
...
...
@@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto
)
cc_library
(
shape_inference SRCS shape_inference.cc DEPS ddim attribute
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry init
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog
)
...
...
@@ -64,4 +67,4 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library
(
init SRCS init.cc DEPS gflags device_context place stringpiece
)
cc_test
(
init_test SRCS init_test.cc DEPS init
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context
framework_proto
)
paddle/framework/data_transform.cc
0 → 100644
浏览文件 @
6f5e64af
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_transform.h"
namespace
paddle
{
namespace
framework
{
DataTransformFnMap
&
DataTransformFnMap
::
Instance
()
{
static
DataTransformFnMap
data_transform_map
;
return
data_transform_map
;
}
}
// namespace framework
}
// namespace paddle
paddle/framework/data_transform.h
0 → 100644
浏览文件 @
6f5e64af
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <utility>
#include <vector>
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/macros.h"
namespace
paddle
{
namespace
framework
{
using
DataTransformFN
=
std
::
function
<
void
(
const
std
::
vector
<
platform
::
DeviceContext
*>
ctx
,
const
Variable
&
in
,
Variable
*
out
)
>
;
using
KernelTypePair
=
std
::
pair
<
OpKernelType
,
OpKernelType
>
;
struct
KernelTypePairHash
{
static
void
HashCombine
(
const
OpKernelType
&
t
,
std
::
size_t
*
seed
)
{
OpKernelType
::
Hash
kernel_type_hasher
;
(
*
seed
)
^=
kernel_type_hasher
(
t
)
+
0x9e3779b9
+
(
*
seed
<<
6
)
+
(
*
seed
>>
2
);
}
size_t
operator
()(
const
KernelTypePair
&
kernel_pair
)
const
{
std
::
size_t
seed
=
0
;
HashCombine
(
kernel_pair
.
first
,
&
seed
);
HashCombine
(
kernel_pair
.
second
,
&
seed
);
return
seed
;
}
};
using
DataTransformMap
=
std
::
unordered_map
<
KernelTypePair
,
DataTransformFN
,
KernelTypePairHash
>
;
class
DataTransformFnMap
{
public:
static
DataTransformFnMap
&
Instance
();
bool
Has
(
const
KernelTypePair
&
key_pair
)
const
{
return
map_
.
find
(
key_pair
)
!=
map_
.
end
();
}
void
Insert
(
const
OpKernelType
&
left
,
const
OpKernelType
&
right
,
const
DataTransformFN
&
data_tranform_fn
)
{
Insert
(
std
::
make_pair
(
left
,
right
),
data_tranform_fn
);
}
void
Insert
(
const
KernelTypePair
&
kernel_type_pair
,
const
DataTransformFN
&
data_tranform_fn
)
{
PADDLE_ENFORCE
(
!
Has
(
kernel_type_pair
),
"KernelTypePair %s has been registered"
,
""
);
map_
.
insert
({
kernel_type_pair
,
data_tranform_fn
});
}
const
DataTransformFN
&
Get
(
const
KernelTypePair
&
key_pair
)
const
{
auto
data_transformer
=
GetNullable
(
key_pair
);
PADDLE_ENFORCE_NOT_NULL
(
data_transformer
,
"DataTransformFN should not be NULL"
);
return
*
data_transformer
;
}
const
DataTransformFN
*
GetNullable
(
const
KernelTypePair
&
key_pair
)
const
{
auto
it
=
map_
.
find
(
key_pair
);
if
(
it
==
map_
.
end
())
{
return
nullptr
;
}
else
{
return
&
(
it
->
second
);
}
}
const
DataTransformMap
&
Map
()
const
{
return
map_
;
}
private:
DataTransformFnMap
()
=
default
;
DataTransformMap
map_
;
DISABLE_COPY_AND_ASSIGN
(
DataTransformFnMap
);
};
// generate unique name with __LINE__
// refs https://stackoverflow.com/questions/1597007
#define TOKENPASTE(x, y) x##y
#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
static int TOKENPASTE2(fn_, __LINE__)() { \
::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
return 0; \
} \
static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
TOKENPASTE2(fn_, __LINE__)()
}
// namespace framework
}
// namespace paddle
paddle/framework/data_transform_test.cc
0 → 100644
浏览文件 @
6f5e64af
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_transform.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
using
namespace
platform
;
int
test_value
=
0
;
OpKernelType
kernel_type_1
(
proto
::
DataType
::
FP32
,
CPUPlace
(),
DataLayout
::
kNCHW
,
LibraryType
::
kCUDNN
);
OpKernelType
kernel_type_2
(
proto
::
DataType
::
FP32
,
CUDAPlace
(
0
),
DataLayout
::
kNCHW
,
LibraryType
::
kCUDNN
);
OpKernelType
kernel_type_3
(
proto
::
DataType
::
FP16
,
CUDAPlace
(
0
),
DataLayout
::
kNCHW
,
LibraryType
::
kCUDNN
);
void
type1_to_type2
(
std
::
vector
<
platform
::
DeviceContext
*>
ctx
,
const
Variable
&
in
,
Variable
*
out
)
{
test_value
++
;
}
void
type2_to_type3
(
std
::
vector
<
platform
::
DeviceContext
*>
ctx
,
const
Variable
&
in
,
Variable
*
out
)
{
test_value
--
;
}
void
type1_to_type3
(
std
::
vector
<
platform
::
DeviceContext
*>
ctx
,
const
Variable
&
in
,
Variable
*
out
)
{
test_value
+=
2
;
}
}
// namespace framework
}
// namespace paddle
namespace
frw
=
paddle
::
framework
;
REGISTER_DATA_TRANSFORM_FN
(
frw
::
kernel_type_1
,
frw
::
kernel_type_2
,
frw
::
type1_to_type2
);
REGISTER_DATA_TRANSFORM_FN
(
frw
::
kernel_type_2
,
frw
::
kernel_type_3
,
frw
::
type2_to_type3
);
REGISTER_DATA_TRANSFORM_FN
(
frw
::
kernel_type_1
,
frw
::
kernel_type_3
,
frw
::
type1_to_type3
);
TEST
(
DataTransform
,
Register
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
auto
&
instance
=
DataTransformFnMap
::
Instance
();
ASSERT_EQ
(
instance
.
Map
().
size
(),
3UL
);
std
::
vector
<
DeviceContext
*>
ctx
;
paddle
::
framework
::
Variable
in
;
paddle
::
framework
::
Variable
out
;
instance
.
Get
(
std
::
make_pair
(
frw
::
kernel_type_1
,
frw
::
kernel_type_2
))(
ctx
,
in
,
&
out
);
ASSERT_EQ
(
test_value
,
1
);
instance
.
Get
(
std
::
make_pair
(
frw
::
kernel_type_2
,
frw
::
kernel_type_3
))(
ctx
,
in
,
&
out
);
ASSERT_EQ
(
test_value
,
0
);
instance
.
Get
(
std
::
make_pair
(
frw
::
kernel_type_1
,
frw
::
kernel_type_3
))(
ctx
,
in
,
&
out
);
ASSERT_EQ
(
test_value
,
2
);
}
paddle/framework/executor.cc
浏览文件 @
6f5e64af
...
...
@@ -14,18 +14,17 @@ limitations under the License. */
#include "paddle/framework/executor.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
DEFINE_bool
(
check_nan_inf
,
false
,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely."
);
namespace
paddle
{
namespace
framework
{
...
...
@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
static
void
CheckTensorNANOrInf
(
const
std
::
string
&
name
,
const
framework
::
Tensor
&
tensor
)
{
if
(
tensor
.
type
().
hash_code
()
!=
typeid
(
float
).
hash_code
()
&&
tensor
.
type
().
hash_code
()
!=
typeid
(
double
).
hash_code
())
{
return
;
}
if
(
tensor
.
memory_size
()
==
0
)
{
return
;
}
PADDLE_ENFORCE
(
!
framework
::
HasInf
(
tensor
),
"Tensor %s has Inf"
,
name
);
PADDLE_ENFORCE
(
!
framework
::
HasNAN
(
tensor
),
"Tensor %s has NAN"
,
name
);
}
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
// TODO(tonyyang-svail):
...
...
@@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
*
op_desc
);
VLOG
(
3
)
<<
op
->
DebugString
();
op
->
Run
(
*
local_scope
,
place_
);
if
(
FLAGS_check_nan_inf
)
{
for
(
auto
&
vname
:
op
->
OutputVars
(
true
))
{
auto
*
var
=
local_scope
->
FindVar
(
vname
);
if
(
var
==
nullptr
)
continue
;
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
CheckTensorNANOrInf
(
vname
,
var
->
Get
<
framework
::
LoDTensor
>
());
}
}
}
}
if
(
create_local_scope
)
{
scope
->
DeleteScope
(
local_scope
);
...
...
paddle/framework/init.cc
浏览文件 @
6f5e64af
...
...
@@ -71,7 +71,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
places
.
emplace_back
(
platform
::
CPUPlace
());
LOG
(
WARNING
)
<<
"Not specified CPU device, create CPU by Default."
;
}
platform
::
DeviceContextPool
::
Create
(
places
);
platform
::
DeviceContextPool
::
Init
(
places
);
return
true
;
}
...
...
paddle/framework/operator.cc
浏览文件 @
6f5e64af
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
...
...
@@ -387,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope,
const
platform
::
Place
&
place
)
const
{
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
dev_ctx
=
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
dev_ctx
=
pool
.
Get
(
place
);
// check if op[type] has kernel registered.
auto
&
all_op_kernels
=
AllOpKernels
();
...
...
@@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key
);
}
kernel_iter
->
second
->
Compute
(
ctx
);
if
(
actual_kernel_key
==
expected_kernel_key
)
{
kernel_iter
->
second
->
Compute
(
ctx
);
}
else
{
Scope
&
op_scope
=
scope
.
NewScope
();
auto
input_vars
=
this
->
InputVars
();
for
(
auto
var_name
:
input_vars
)
{
op_scope
.
Var
(
var_name
);
}
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform
::
DeviceContext
*
trans_dev_ctx
=
nullptr
;
std
::
vector
<
platform
::
DeviceContext
*>
trans_dev_ctx_vec
{
trans_dev_ctx
};
// TODO(qijun) get appropriate DataTransformFN from global map
framework
::
DataTransformFN
trans_fun
=
nullptr
;
// Wait for transform starting
dev_ctx
->
Wait
();
for
(
auto
var_name
:
input_vars
)
{
trans_fun
(
trans_dev_ctx_vec
,
*
(
scope
.
FindVar
(
var_name
)),
op_scope
.
FindVar
(
var_name
));
}
// Wait for data transform finishing
for
(
auto
ctx
:
trans_dev_ctx_vec
)
{
ctx
->
Wait
();
}
// Create a new ExecutionContext
ExecutionContext
op_ctx
(
*
this
,
op_scope
,
*
dev_ctx
);
kernel_iter
->
second
->
Compute
(
op_ctx
);
}
}
OpKernelType
OperatorWithKernel
::
GetActualKernelType
(
...
...
paddle/framework/tensor_util.h
浏览文件 @
6f5e64af
...
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -205,5 +208,100 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr
,
size
);
}
template
<
typename
Predicate
,
typename
DevCtx
>
struct
AnyDTypeVisitor
{
Predicate
predicate_
;
const
Tensor
&
tensor_
;
const
DevCtx
&
ctx_
;
Tensor
*
out_
;
AnyDTypeVisitor
(
Predicate
predicate
,
const
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
Tensor
*
out
)
:
predicate_
(
predicate
),
tensor_
(
tensor
),
ctx_
(
ctx
),
out_
(
out
)
{}
template
<
typename
T
>
void
operator
()()
const
{
auto
t
=
EigenVector
<
T
>::
Flatten
(
tensor_
);
auto
o
=
EigenScalar
<
bool
>::
From
(
*
out_
);
o
.
device
(
*
ctx_
.
eigen_device
())
=
predicate_
(
t
).
any
();
}
};
template
<
typename
Predicate
,
typename
DevCtx
>
inline
void
AnyImpl
(
Predicate
predicate
,
const
framework
::
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
framework
::
Tensor
*
out
)
{
VisitDataType
(
ToDataType
(
tensor
.
type
()),
AnyDTypeVisitor
<
Predicate
,
DevCtx
>
(
predicate
,
tensor
,
ctx
,
out
));
}
template
<
typename
Predicate
>
struct
AnyVisitor
:
public
boost
::
static_visitor
<
bool
>
{
const
framework
::
Tensor
&
tensor_
;
Predicate
predicate_
;
AnyVisitor
(
const
framework
::
Tensor
&
tensor
,
Predicate
predicate
)
:
tensor_
(
tensor
),
predicate_
(
std
::
move
(
predicate
))
{}
template
<
typename
Place
>
bool
operator
()(
const
Place
&
place
)
const
{
framework
::
Tensor
out
;
out
.
Resize
({
1
});
out
.
mutable_data
<
bool
>
(
place
);
auto
*
ctx
=
platform
::
DeviceContextPool
::
Instance
().
GetByPlace
(
place
);
AnyImpl
(
predicate_
,
tensor_
,
*
ctx
,
&
out
);
return
this
->
GetResult
(
out
,
place
);
}
bool
GetResult
(
const
framework
::
Tensor
&
out
,
const
platform
::
CUDAPlace
&
gpu
)
const
{
platform
::
CPUPlace
cpu
;
framework
::
Tensor
tmp
;
tmp
.
Resize
({
1
});
tmp
.
mutable_data
<
bool
>
(
cpu
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
gpu
)
->
Wait
();
CopyFrom
(
out
,
cpu
,
&
tmp
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
gpu
)
->
Wait
();
return
GetResult
(
tmp
,
cpu
);
}
bool
GetResult
(
const
framework
::
Tensor
&
out
,
const
platform
::
CPUPlace
&
cpu
)
const
{
return
*
out
.
data
<
bool
>
();
}
};
template
<
typename
Predicate
>
inline
bool
Any
(
const
framework
::
Tensor
&
tensor
,
Predicate
predicate
)
{
AnyVisitor
<
Predicate
>
visitor
(
tensor
,
predicate
);
auto
place
=
tensor
.
place
();
return
platform
::
VisitPlace
(
place
,
visitor
);
}
struct
HasNANPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isnan
())
{
return
eigen_vec
.
isnan
();
}
};
inline
bool
HasNAN
(
const
framework
::
Tensor
&
tensor
)
{
HasNANPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
struct
HasInfPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isinf
())
{
return
eigen_vec
.
isinf
();
}
};
inline
bool
HasInf
(
const
framework
::
Tensor
&
tensor
)
{
HasInfPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/tensor_util_test.cc
浏览文件 @
6f5e64af
...
...
@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
namespace
paddle
{
...
...
@@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) {
#endif
}
TEST
(
IsNAN
,
CPU
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
src
;
float
*
buf
=
src
.
mutable_data
<
float
>
({
3
},
CPUPlace
());
buf
[
0
]
=
0.0
;
buf
[
1
]
=
NAN
;
buf
[
2
]
=
0.0
;
ASSERT_TRUE
(
HasNAN
(
src
));
}
TEST
(
IsInf
,
CPU
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
src
;
double
*
buf
=
src
.
mutable_data
<
double
>
({
3
},
CPUPlace
());
buf
[
0
]
=
1.0
;
buf
[
1
]
=
INFINITY
;
buf
[
2
]
=
0.0
;
ASSERT_TRUE
(
HasInf
(
src
));
}
}
// namespace framework
}
// namespace paddle
paddle/gserver/layers/MKLDNNLRNLayer.cpp
浏览文件 @
6f5e64af
...
...
@@ -29,7 +29,7 @@ bool MKLDNNLRNLayer::init(const LayerMap& layerMap,
}
/* the size of inputs for norm-layer is 1 */
CHECK_EQ
(
config_
.
inputs_size
(),
1
UL
);
CHECK_EQ
(
config_
.
inputs_size
(),
1
);
const
NormConfig
&
conf
=
config_
.
inputs
(
0
).
norm_conf
();
localSize_
=
conf
.
size
();
alpha_
=
conf
.
scale
();
...
...
paddle/operators/activation_op.cc
浏览文件 @
6f5e64af
...
...
@@ -22,8 +22,8 @@ class ActivationOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
"
Y
"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"
Y
"
);
ctx
->
SetOutputDim
(
"
Out
"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"
Out
"
);
}
};
...
...
@@ -32,7 +32,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"
Y
"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"
Out
"
));
}
};
...
...
@@ -41,11 +41,11 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Sigmoid operator"
);
AddOutput
(
"
Y
"
,
"Output of Sigmoid operator"
);
AddOutput
(
"
Out
"
,
"Output of Sigmoid operator"
);
AddComment
(
R"DOC(
Sigmoid Activation Operator
$$
y
= \frac{1}{1 + e^{-x}}$$
$$
out
= \frac{1}{1 + e^{-x}}$$
)DOC"
);
}
...
...
@@ -56,11 +56,11 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
LogSigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of LogSigmoid operator"
);
AddOutput
(
"
Y
"
,
"Output of LogSigmoid operator"
);
AddOutput
(
"
Out
"
,
"Output of LogSigmoid operator"
);
AddComment
(
R"DOC(
Logsigmoid Activation Operator
$$
y
= \log \frac{1}{1 + e^{-x}}$$
$$
out
= \log \frac{1}{1 + e^{-x}}$$
)DOC"
);
}
...
...
@@ -71,11 +71,11 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
ExpOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Exp operator"
);
AddOutput
(
"
Y
"
,
"Output of Exp operator"
);
AddOutput
(
"
Out
"
,
"Output of Exp operator"
);
AddComment
(
R"DOC(
Exp Activation Operator.
$
y
= e^x$
$
out
= e^x$
)DOC"
);
}
...
...
@@ -86,11 +86,11 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
ReluOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Relu operator"
);
AddOutput
(
"
Y
"
,
"Output of Relu operator"
);
AddOutput
(
"
Out
"
,
"Output of Relu operator"
);
AddComment
(
R"DOC(
Relu Activation Operator.
$
y
= \max(x, 0)$
$
out
= \max(x, 0)$
)DOC"
);
}
...
...
@@ -101,12 +101,12 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
LeakyReluOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of LeakyRelu operator"
);
AddOutput
(
"
Y
"
,
"Output of LeakyRelu operator"
);
AddOutput
(
"
Out
"
,
"Output of LeakyRelu operator"
);
AddAttr
<
float
>
(
"alpha"
,
"The small negative slope"
).
SetDefault
(
0.02
f
);
AddComment
(
R"DOC(
LeakyRelu Activation Operator.
$
y
= \max(x, \alpha * x)$
$
out
= \max(x, \alpha * x)$
)DOC"
);
}
...
...
@@ -117,13 +117,13 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
SoftShrinkOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Softshrink operator"
);
AddOutput
(
"
Y
"
,
"Output of Softshrink operator"
);
AddOutput
(
"
Out
"
,
"Output of Softshrink operator"
);
AddAttr
<
float
>
(
"lambda"
,
"non-negative offset"
).
SetDefault
(
0.5
f
);
AddComment
(
R"DOC(
Softshrink Activation Operator.
$$
y
= \begin{cases}
out
= \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
...
...
@@ -139,11 +139,11 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
TanhOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Tanh operator"
);
AddOutput
(
"
Y
"
,
"Output of Tanh operator"
);
AddOutput
(
"
Out
"
,
"Output of Tanh operator"
);
AddComment
(
R"DOC(
Tanh Activation Operator.
$$
y
= \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
$$
out
= \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC"
);
}
...
...
@@ -154,11 +154,11 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
TanhShrinkOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of TanhShrink operator"
);
AddOutput
(
"
Y
"
,
"Output of TanhShrink operator"
);
AddOutput
(
"
Out
"
,
"Output of TanhShrink operator"
);
AddComment
(
R"DOC(
TanhShrink Activation Operator.
$$
y
= x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
$$
out
= x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC"
);
}
...
...
@@ -169,14 +169,14 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
HardShrinkOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of HardShrink operator"
);
AddOutput
(
"
Y
"
,
"Output of HardShrink operator"
);
AddOutput
(
"
Out
"
,
"Output of HardShrink operator"
);
AddAttr
<
float
>
(
"threshold"
,
"The value of threshold for HardShrink"
)
.
SetDefault
(
0.5
f
);
AddComment
(
R"DOC(
HardShrink Activation Operator.
$$
y
= \begin{cases}
out
= \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
...
...
@@ -192,11 +192,11 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
SqrtOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Sqrt operator"
);
AddOutput
(
"
Y
"
,
"Output of Sqrt operator"
);
AddOutput
(
"
Out
"
,
"Output of Sqrt operator"
);
AddComment
(
R"DOC(
Sqrt Activation Operator.
$
y
= \sqrt{x}$
$
out
= \sqrt{x}$
)DOC"
);
}
...
...
@@ -207,11 +207,11 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
AbsOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Abs operator"
);
AddOutput
(
"
Y
"
,
"Output of Abs operator"
);
AddOutput
(
"
Out
"
,
"Output of Abs operator"
);
AddComment
(
R"DOC(
Abs Activation Operator.
$
y
= |x|$
$
out
= |x|$
)DOC"
);
}
...
...
@@ -222,11 +222,11 @@ class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
CeilOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Ceil operator"
);
AddOutput
(
"
Y
"
,
"Output of Ceil operator"
);
AddOutput
(
"
Out
"
,
"Output of Ceil operator"
);
AddComment
(
R"DOC(
Ceil Activation Operator.
$
y
= ceil(x)$
$
out
= ceil(x)$
)DOC"
);
}
...
...
@@ -237,11 +237,11 @@ class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
FloorOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Floor operator"
);
AddOutput
(
"
Y
"
,
"Output of Floor operator"
);
AddOutput
(
"
Out
"
,
"Output of Floor operator"
);
AddComment
(
R"DOC(
Floor Activation Operator.
$
y
= floor(x)$
$
out
= floor(x)$
)DOC"
);
}
...
...
@@ -252,11 +252,11 @@ class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
RoundOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Round operator"
);
AddOutput
(
"
Y
"
,
"Output of Round operator"
);
AddOutput
(
"
Out
"
,
"Output of Round operator"
);
AddComment
(
R"DOC(
Round Activation Operator.
$
y
= [x]$
$
out
= [x]$
)DOC"
);
}
...
...
@@ -267,11 +267,11 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
ReciprocalOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Reciprocal operator"
);
AddOutput
(
"
Y
"
,
"Output of Reciprocal operator"
);
AddOutput
(
"
Out
"
,
"Output of Reciprocal operator"
);
AddComment
(
R"DOC(
Reciprocal Activation Operator.
$$
y
= \frac{1}{x}$$
$$
out
= \frac{1}{x}$$
)DOC"
);
}
...
...
@@ -282,11 +282,11 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker {
LogOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Log operator"
);
AddOutput
(
"
Y
"
,
"Output of Log operator"
);
AddOutput
(
"
Out
"
,
"Output of Log operator"
);
AddComment
(
R"DOC(
Log Activation Operator.
$
y
= \ln(x)$
$
out
= \ln(x)$
Natural logarithm of x.
...
...
@@ -299,11 +299,11 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
SquareOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Square operator"
);
AddOutput
(
"
Y
"
,
"Output of Square operator"
);
AddOutput
(
"
Out
"
,
"Output of Square operator"
);
AddComment
(
R"DOC(
Square Activation Operator.
$
y
= x^2$
$
out
= x^2$
)DOC"
);
}
...
...
@@ -314,11 +314,11 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
SoftplusOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Softplus operator"
);
AddOutput
(
"
Y
"
,
"Output of Softplus operator"
);
AddOutput
(
"
Out
"
,
"Output of Softplus operator"
);
AddComment
(
R"DOC(
Softplus Activation Operator.
$
y
= \ln(1 + e^{x})$
$
out
= \ln(1 + e^{x})$
)DOC"
);
}
...
...
@@ -329,11 +329,11 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
SoftsignOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Softsign operator"
);
AddOutput
(
"
Y
"
,
"Output of Softsign operator"
);
AddOutput
(
"
Out
"
,
"Output of Softsign operator"
);
AddComment
(
R"DOC(
Softsign Activation Operator.
$$
y
= \frac{x}{1 + |x|}$$
$$
out
= \frac{x}{1 + |x|}$$
)DOC"
);
}
...
...
@@ -344,7 +344,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
BReluOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of BRelu operator"
);
AddOutput
(
"
Y
"
,
"Output of BRelu operator"
);
AddOutput
(
"
Out
"
,
"Output of BRelu operator"
);
AddAttr
<
float
>
(
"t_min"
,
"The min marginal value of BRelu"
)
.
SetDefault
(
static_cast
<
float
>
(
0
));
AddAttr
<
float
>
(
"t_max"
,
"The max marginal value of BRelu"
)
...
...
@@ -352,7 +352,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
BRelu Activation Operator.
$
y
= \max(\min(x, t_{min}), t_{max})$
$
out
= \max(\min(x, t_{min}), t_{max})$
)DOC"
);
}
...
...
@@ -363,13 +363,13 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
SoftReluOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of SoftRelu operator"
);
AddOutput
(
"
Y
"
,
"Output of SoftRelu operator"
);
AddOutput
(
"
Out
"
,
"Output of SoftRelu operator"
);
AddAttr
<
float
>
(
"threshold"
,
"The threshold value of SoftRelu"
)
.
SetDefault
(
40.0
f
);
AddComment
(
R"DOC(
SoftRelu Activation Operator.
$
y
= \ln(1 + \exp(\max(\min(x, threshold), threshold))$
$
out
= \ln(1 + \exp(\max(\min(x, threshold), threshold))$
)DOC"
);
}
...
...
@@ -380,7 +380,7 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
ELUOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of ELU operator"
);
AddOutput
(
"
Y
"
,
"Output of ELU operator"
);
AddOutput
(
"
Out
"
,
"Output of ELU operator"
);
AddAttr
<
float
>
(
"alpha"
,
"The alpha value of ELU"
).
SetDefault
(
1.0
f
);
AddComment
(
R"DOC(
ELU Activation Operator.
...
...
@@ -388,7 +388,7 @@ ELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.
$
y
= \max(0, x) + \min(0, \alpha * (e^x - 1))$
$
out
= \max(0, x) + \min(0, \alpha * (e^x - 1))$
)DOC"
);
}
...
...
@@ -399,13 +399,13 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
Relu6OpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Relu6 operator"
);
AddOutput
(
"
Y
"
,
"Output of Relu6 operator"
);
AddOutput
(
"
Out
"
,
"Output of Relu6 operator"
);
AddAttr
<
float
>
(
"threshold"
,
"The threshold value of Relu6"
)
.
SetDefault
(
6.0
f
);
AddComment
(
R"DOC(
Relu6 Activation Operator.
$
y
= \min(\max(0, x), 6)$
$
out
= \min(\max(0, x), 6)$
)DOC"
);
}
...
...
@@ -416,12 +416,12 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
PowOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Pow operator"
);
AddOutput
(
"
Y
"
,
"Output of Pow operator"
);
AddOutput
(
"
Out
"
,
"Output of Pow operator"
);
AddAttr
<
float
>
(
"factor"
,
"The exponential factor of Pow"
).
SetDefault
(
1.0
f
);
AddComment
(
R"DOC(
Pow Activation Operator.
$
y
= x^{factor}$
$
out
= x^{factor}$
)DOC"
);
}
...
...
@@ -432,7 +432,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
STanhOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of STanh operator"
);
AddOutput
(
"
Y
"
,
"Output of STanh operator"
);
AddOutput
(
"
Out
"
,
"Output of STanh operator"
);
AddAttr
<
float
>
(
"scale_a"
,
"The scale parameter of a for the input"
)
.
SetDefault
(
2.0
f
/
3.0
f
);
AddAttr
<
float
>
(
"scale_b"
,
"The scale parameter of b for the input"
)
...
...
@@ -440,7 +440,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
STanh Activation Operator.
$$
y
= b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
$$
out
= b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
)DOC"
);
}
...
...
@@ -451,14 +451,14 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
ThresholdedReluOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of ThresholdedRelu operator"
);
AddOutput
(
"
Y
"
,
"Output of ThresholdedRelu operator"
);
AddOutput
(
"
Out
"
,
"Output of ThresholdedRelu operator"
);
AddAttr
<
float
>
(
"threshold"
,
"The threshold location of activation"
)
.
SetDefault
(
1.0
f
);
AddComment
(
R"DOC(
ThresholdedRelu Activation Operator.
$$
y
= \begin{cases}
out
= \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
...
...
@@ -473,7 +473,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
HardSigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of HardSigmoid operator"
);
AddOutput
(
"
Y
"
,
"Output of HardSigmoid operator"
);
AddOutput
(
"
Out
"
,
"Output of HardSigmoid operator"
);
AddAttr
<
float
>
(
"slope"
,
"Slope for linear approximation of sigmoid"
)
.
SetDefault
(
0.2
f
);
AddAttr
<
float
>
(
"offset"
,
"Offset for linear approximation of sigmoid"
)
...
...
@@ -484,7 +484,7 @@ HardSigmoid Activation Operator.
Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
$
y
= \max(0, \min(1, slope * x + shift))$
$
out
= \max(0, \min(1, slope * x + shift))$
The slope should be positive. The offset can be either positive or negative.
The default slope and shift are set according to the above reference.
...
...
@@ -499,12 +499,12 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
SwishOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input of Swish operator"
);
AddOutput
(
"
Y
"
,
"Output of Swish operator"
);
AddOutput
(
"
Out
"
,
"Output of Swish operator"
);
AddAttr
<
float
>
(
"beta"
,
"Constant beta of swish operator"
).
SetDefault
(
1.0
f
);
AddComment
(
R"DOC(
Swish Activation Operator.
$$
y
= \frac{x}{1 + e^{- \beta x}}$$
$$
out
= \frac{x}{1 + e^{- \beta x}}$$
)DOC"
);
}
...
...
paddle/operators/activation_op.h
浏览文件 @
6f5e64af
此差异已折叠。
点击以展开。
paddle/operators/array_operator.h
浏览文件 @
6f5e64af
...
...
@@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ
(
i_tensor
.
numel
(),
1
);
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
size_t
offset
;
if
(
platform
::
is_gpu_place
(
i_tensor
.
place
()))
{
...
...
paddle/operators/array_to_lod_tensor_op.cc
浏览文件 @
6f5e64af
...
...
@@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
}
auto
slice
=
out
->
Slice
(
out_offset
,
out_offset
+
len
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
CopyFrom
(
x
[
x_idx
].
Slice
(
start_offset
,
end_offset
),
place
,
dev_ctx
,
&
slice
);
...
...
paddle/operators/assign_op.cc
浏览文件 @
6f5e64af
...
...
@@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase {
out
!=
nullptr
,
"The Output(Out) should not be null if the Input(X) is set."
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
VisitVarType
(
*
x
,
AssignFunctor
(
out
,
dev_ctx
));
}
...
...
paddle/operators/beam_search_decode_op.cc
浏览文件 @
6f5e64af
...
...
@@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
dev_place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
ExecutionContext
ctx
(
*
this
,
scope
,
dev_ctx
);
...
...
paddle/operators/cond_op.cc
浏览文件 @
6f5e64af
...
...
@@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
void
CondOp
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
PrepareDataForSubnet
(
scope
,
dev_ctx
);
std
::
vector
<
framework
::
Scope
*>&
sub_scopes
=
GetSubScopes
(
scope
);
...
...
paddle/operators/feed_op.cc
浏览文件 @
6f5e64af
...
...
@@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase {
auto
*
out_item
=
out_var
->
GetMutable
<
framework
::
FeedFetchType
>
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
CopyFrom
(
feed_item
,
place
,
dev_ctx
,
out_item
);
out_item
->
set_lod
(
feed_item
.
lod
());
...
...
paddle/operators/fetch_op.cc
浏览文件 @
6f5e64af
...
...
@@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
CopyFrom
(
src_item
,
platform
::
CPUPlace
(),
dev_ctx
,
&
dst_item
);
dev_ctx
.
Wait
();
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
6f5e64af
...
...
@@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase {
out
.
mutable_data
(
dev_place
,
framework
::
ToTypeIndex
(
data_type
));
}
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
dev_place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
math
::
set_constant
(
dev_ctx
,
&
out
,
value
);
}
};
...
...
paddle/operators/fill_op.cc
浏览文件 @
6f5e64af
...
...
@@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase {
if
(
!
force_cpu
&&
platform
::
is_gpu_place
(
place
))
{
// Copy tensor to out
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
CopyFrom
(
tensor
,
place
,
dev_ctx
,
&
out
);
}
}
...
...
paddle/operators/load_op.cc
浏览文件 @
6f5e64af
...
...
@@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase {
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
DeserializeFromStream
(
fin
,
tensor
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
if
(
platform
::
is_gpu_place
(
place
))
{
// copy CPU to GPU
...
...
paddle/operators/lod_tensor_to_array_op.cc
浏览文件 @
6f5e64af
...
...
@@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
auto
slice
=
out
[
i
].
Slice
(
static_cast
<
int
>
(
offset
),
static_cast
<
int
>
(
offset
+
len
));
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
CopyFrom
(
x
.
Slice
(
static_cast
<
int
>
(
each_range
.
begin
),
static_cast
<
int
>
(
each_range
.
end
)),
...
...
paddle/operators/merge_lod_tensor_op.cc
浏览文件 @
6f5e64af
...
...
@@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
dev_place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
mask
=
scope
.
FindVar
(
Input
(
"Mask"
))
->
Get
<
framework
::
LoDTensor
>
();
...
...
paddle/operators/nccl_op_test.cu.cc
浏览文件 @
6f5e64af
...
...
@@ -305,7 +305,7 @@ int main(int argc, char **argv) {
}
VLOG
(
0
)
<<
" DeviceCount "
<<
count
;
paddle
::
platform
::
DeviceContextPool
::
Create
(
places
);
paddle
::
platform
::
DeviceContextPool
::
Init
(
places
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
6f5e64af
...
...
@@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase {
false
/*create_local_scope*/
);
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
// Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
...
...
@@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase {
auto
*
program
=
block
->
Program
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
for
(
size_t
step_id
=
0
;
step_id
<
seq_len
;
++
step_id
)
{
size_t
seq_offset
=
reverse
?
step_id
:
seq_len
-
step_id
-
1
;
...
...
paddle/operators/reorder_lod_tensor_by_rank_op.cc
浏览文件 @
6f5e64af
...
...
@@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
auto
x_sliced
=
x
.
Slice
(
x_offset
,
x_offset
+
len
);
auto
out_sliced
=
out
->
Slice
(
out_offset
,
out_offset
+
len
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
CopyFrom
(
x_sliced
,
out_sliced
.
place
(),
dev_ctx
,
&
out_sliced
);
out_offset
+=
len
;
return
out_offset
;
...
...
paddle/operators/save_op.cc
浏览文件 @
6f5e64af
...
...
@@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase {
auto
&
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
SerializeToStream
(
fout
,
tensor
,
dev_ctx
);
}
...
...
paddle/operators/shrink_rnn_memory_op.cc
浏览文件 @
6f5e64af
...
...
@@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
dx_tensor
.
mutable_data
(
x_tensor
.
place
(),
x_tensor
.
type
());
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
if
(
dout_var
==
nullptr
)
{
// dx_tensor fill zero
math
::
set_constant
(
dev_ctx
,
&
dx_tensor
,
0.0
f
);
...
...
paddle/operators/softmax_op.cc
浏览文件 @
6f5e64af
...
...
@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"
Y
"
),
"Output(
Y
) of SoftmaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"
Out
"
),
"Output(
Out
) of SoftmaxOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE
(
x_dims
.
size
()
==
2UL
,
"The input of softmax op must be a matrix."
);
ctx
->
SetOutputDim
(
"
Y
"
,
x_dims
);
ctx
->
SetOutputDim
(
"
Out
"
,
x_dims
);
}
};
...
...
@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions]."
);
AddOutput
(
"
Y
"
,
"The normalized values with the same shape as X."
);
AddOutput
(
"
Out
"
,
"The normalized values with the same shape as X."
);
AddComment
(
R"DOC(
Softmax Operator.
...
...
@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
operator.
For each row $i$ and each column $j$ in Input(X), we have:
$$
Y
[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
$$
Out
[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
)DOC"
);
}
...
...
@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"
Y"
),
"Input(Y
) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"
Y
"
)),
"Input(
Y
@GRAD) should be not null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"
Y
"
),
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"
Y
"
)),
"Input(
Y
) and its gradients should have a same shape."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"
Out"
),
"Input(Out
) should be not null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"
Out
"
)),
"Input(
Out
@GRAD) should be not null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"
Out
"
),
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"
Out
"
)),
"Input(
Out
) and its gradients should have a same shape."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
...
...
paddle/operators/softmax_op.h
浏览文件 @
6f5e64af
...
...
@@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
X
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
Y
=
context
.
Output
<
Tensor
>
(
"Y
"
);
auto
*
Out
=
context
.
Output
<
Tensor
>
(
"Out
"
);
// allocate memory on device.
Y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
Out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SoftmaxFunctor
<
DeviceContext
,
T
>
()(
context
.
template
device_context
<
DeviceContext
>(),
X
,
Y
);
context
.
template
device_context
<
DeviceContext
>(),
X
,
Out
);
}
};
...
...
@@ -40,15 +40,15 @@ template <typename DeviceContext, typename T>
class
SoftmaxGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
Y
=
context
.
Input
<
Tensor
>
(
"Y
"
);
auto
*
d
Y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y
"
));
auto
*
Out
=
context
.
Input
<
Tensor
>
(
"Out
"
);
auto
*
d
Out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out
"
));
auto
*
dX
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
// allocate memory on device.
dX
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SoftmaxGradFunctor
<
DeviceContext
,
T
>
()(
context
.
template
device_context
<
DeviceContext
>(),
Y
,
dY
,
dX
);
context
.
template
device_context
<
DeviceContext
>(),
Out
,
dOut
,
dX
);
}
};
...
...
paddle/operators/split_lod_tensor_op.cc
浏览文件 @
6f5e64af
...
...
@@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase {
auto
&
x_lod
=
x
.
lod
();
auto
&
mask_dim
=
mask
.
dims
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
dev_place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
std
::
unique_ptr
<
framework
::
LoDTensor
>
cpu_mask
{
new
framework
::
LoDTensor
()};
if
(
platform
::
is_cpu_place
(
mask
.
place
()))
{
...
...
paddle/operators/tensor_array_read_write_op.cc
浏览文件 @
6f5e64af
...
...
@@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp {
if
(
x_tensor
.
memory_size
()
>
0
)
{
auto
*
out_tensor
=
&
out
->
at
(
offset
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
CopyFrom
(
x_tensor
,
place
,
dev_ctx
,
out_tensor
);
out_tensor
->
set_lod
(
x_tensor
.
lod
());
...
...
@@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp {
auto
*
out_tensor
=
out
->
GetMutable
<
framework
::
LoDTensor
>
();
size_t
offset
=
GetOffset
(
scope
,
place
);
if
(
offset
<
x_array
.
size
())
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
auto
&
dev_ctx
=
*
pool
.
Borrow
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
framework
::
CopyFrom
(
x_array
[
offset
],
place
,
dev_ctx
,
out_tensor
);
out_tensor
->
set_lod
(
x_array
[
offset
].
lod
());
}
else
{
...
...
paddle/platform/device_context.cc
浏览文件 @
6f5e64af
...
...
@@ -17,7 +17,7 @@ namespace platform {
DeviceContextPool
*
DeviceContextPool
::
pool
=
nullptr
;
const
platform
::
DeviceContext
*
DeviceContextPool
::
Borrow
(
const
platform
::
DeviceContext
*
DeviceContextPool
::
Get
(
const
platform
::
Place
&
place
)
{
auto
it
=
device_contexts_
.
find
(
place
);
if
(
it
==
device_contexts_
.
end
())
{
...
...
@@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow(
return
it
->
second
;
}
std
::
vector
<
const
platform
::
DeviceContext
*>
DeviceContextPool
::
Borrow
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
PADDLE_ENFORCE_GT
(
places
.
size
(),
0
);
PADDLE_ENFORCE_LE
(
places
.
size
(),
device_contexts_
.
size
());
std
::
vector
<
const
platform
::
DeviceContext
*>
borrowed_contexts
;
for
(
auto
&
place
:
places
)
{
auto
it
=
device_contexts_
.
find
(
place
);
if
(
it
!=
device_contexts_
.
end
())
{
borrowed_contexts
.
emplace_back
(
it
->
second
);
}
else
{
PADDLE_THROW
(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option"
);
}
}
return
borrowed_contexts
;
}
DeviceContextPool
::
DeviceContextPool
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
PADDLE_ENFORCE_GT
(
places
.
size
(),
0
);
...
...
paddle/platform/device_context.h
浏览文件 @
6f5e64af
...
...
@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
};
template
<
typename
Place
>
struct
DefaultDeviceContextType
;
template
<
>
struct
DefaultDeviceContextType
<
platform
::
CPUPlace
>
{
using
TYPE
=
CPUDeviceContext
;
};
#ifdef PADDLE_WITH_CUDA
class
EigenCudaStreamDevice
;
...
...
@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t
cublas_handle_
;
};
template
<
>
struct
DefaultDeviceContextType
<
platform
::
CUDAPlace
>
{
using
T
=
CUDADeviceContext
;
};
class
CUDNNDeviceContext
:
public
CUDADeviceContext
{
public:
explicit
CUDNNDeviceContext
(
CUDAPlace
place
);
...
...
@@ -109,13 +122,13 @@ class DeviceContextPool {
public:
explicit
DeviceContextPool
(
const
std
::
vector
<
platform
::
Place
>&
places
);
static
DeviceContextPool
&
Get
()
{
static
DeviceContextPool
&
Instance
()
{
PADDLE_ENFORCE_NOT_NULL
(
pool
,
"Need to Create DeviceContextPool first!"
);
return
*
pool
;
}
/*! \brief Create should only called by Init function */
static
DeviceContextPool
&
Create
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
static
DeviceContextPool
&
Init
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
if
(
pool
==
nullptr
)
{
pool
=
new
DeviceContextPool
(
places
);
}
...
...
@@ -123,13 +136,14 @@ class DeviceContextPool {
}
/*! \brief Return handle of single device context. */
const
platform
::
DeviceContext
*
Borrow
(
const
platform
::
Place
&
place
);
/*! \brief Return handle of multi-device context. */
std
::
vector
<
const
platform
::
DeviceContext
*>
Borrow
(
const
std
::
vector
<
platform
::
Place
>&
places
);
const
platform
::
DeviceContext
*
Get
(
const
platform
::
Place
&
place
);
~
DeviceContextPool
()
{}
template
<
typename
Place
>
const
typename
DefaultDeviceContextType
<
Place
>::
TYPE
*
GetByPlace
(
const
Place
&
place
)
{
return
reinterpret_cast
<
const
typename
DefaultDeviceContextType
<
Place
>::
TYPE
*>
(
Get
(
place
));
}
private:
static
DeviceContextPool
*
pool
;
...
...
paddle/platform/device_context_test.cu
浏览文件 @
6f5e64af
...
...
@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) {
using
paddle
::
platform
::
CPUPlace
;
using
paddle
::
platform
::
CUDAPlace
;
DeviceContextPool
&
pool
=
DeviceContextPool
::
Get
();
auto
cpu_dev_ctx1
=
pool
.
Borrow
(
CPUPlace
());
auto
cpu_dev_ctx2
=
pool
.
Borrow
(
CPUPlace
());
EXPECT_TRUE
(
cpu_dev_ctx2
==
cpu_dev_ctx1
);
DeviceContextPool
&
pool
=
DeviceContextPool
::
Instance
();
auto
cpu_dev_ctx1
=
pool
.
Get
(
CPUPlace
());
auto
cpu_dev_ctx2
=
pool
.
Get
(
CPUPlace
());
ASSERT_EQ
(
cpu_dev_ctx2
,
cpu_dev_ctx1
);
std
::
vector
<
Place
>
gpu_places
;
int
count
=
paddle
::
platform
::
GetCUDADeviceCount
();
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
gpu_places
.
emplace_back
(
CUDAPlace
(
i
));
}
auto
dev_ctxs
=
pool
.
Borrow
(
gpu_places
);
for
(
size_t
i
=
0
;
i
<
dev_ctxs
.
size
();
++
i
)
{
auto
*
dev_ctx
=
static_cast
<
const
CUDADeviceContext
*>
(
dev_ctxs
[
i
]);
// check same as CUDAPlace(i)
CUDAPlace
place
=
boost
::
get
<
CUDAPlace
>
(
dev_ctx
->
GetPlace
());
EXPECT_EQ
(
place
.
GetDeviceId
(),
static_cast
<
int
>
(
i
));
auto
dev_ctx
=
pool
.
Get
(
CUDAPlace
(
i
));
ASSERT_NE
(
dev_ctx
,
nullptr
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
int
dev_count
=
paddle
::
platform
::
GetCUDADeviceCount
();
if
(
dev_count
<=
1
)
{
LOG
(
WARNING
)
<<
"Cannot test multi-gpu DeviceContextPool, because the CUDA "
"device count is "
<<
dev_count
;
return
0
;
}
std
::
vector
<
paddle
::
platform
::
Place
>
places
;
places
.
emplace_back
(
paddle
::
platform
::
CPUPlace
());
...
...
@@ -109,7 +94,7 @@ int main(int argc, char** argv) {
}
VLOG
(
0
)
<<
" DeviceCount "
<<
count
;
paddle
::
platform
::
DeviceContextPool
::
Create
(
places
);
paddle
::
platform
::
DeviceContextPool
::
Init
(
places
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
return
RUN_ALL_TESTS
();
...
...
paddle/platform/nccl_test.cu
浏览文件 @
6f5e64af
...
...
@@ -144,7 +144,7 @@ int main(int argc, char** argv) {
}
VLOG
(
0
)
<<
" DeviceCount "
<<
count
;
paddle
::
platform
::
DeviceContextPool
::
Create
(
places
);
paddle
::
platform
::
DeviceContextPool
::
Init
(
places
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
return
RUN_ALL_TESTS
();
...
...
paddle/platform/place.h
浏览文件 @
6f5e64af
...
...
@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
namespace
paddle
{
...
...
@@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Place
&
);
template
<
typename
Visitor
>
struct
PlaceVisitorWrapper
:
public
boost
::
static_visitor
<
typename
Visitor
::
result_type
>
{
const
Visitor
&
visitor_
;
explicit
PlaceVisitorWrapper
(
const
Visitor
&
visitor
)
:
visitor_
(
visitor
)
{}
typename
Visitor
::
result_type
operator
()(
const
CPUPlace
&
cpu
)
const
{
return
visitor_
(
cpu
);
}
typename
Visitor
::
result_type
operator
()(
const
CUDAPlace
&
cuda
)
const
{
#ifdef PADDLE_WITH_CUDA
return
visitor_
(
cuda
);
#else
PADDLE_THROW
(
"Paddle is not compiled with CUDA. Cannot visit cuda device"
);
return
typename
Visitor
::
result_type
();
#endif
}
};
template
<
typename
Visitor
>
typename
Visitor
::
result_type
VisitPlace
(
const
Place
&
place
,
const
Visitor
&
visitor
)
{
return
boost
::
apply_visitor
(
PlaceVisitorWrapper
<
Visitor
>
(
visitor
),
place
);
}
}
// namespace platform
}
// namespace paddle
paddle/pybind/tensor_py.h
浏览文件 @
6f5e64af
...
...
@@ -63,9 +63,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
auto
*
dst_ptr
=
static_cast
<
void
*>
(
dst_tensor
.
mutable_data
<
CUR_TYPE
>
(
tensor
.
dims
(),
platform
::
CPUPlace
()));
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
pool
.
Borrow
(
tensor
.
place
()));
pool
.
Get
(
tensor
.
place
()));
paddle
::
platform
::
GpuMemcpyAsync
(
dst_ptr
,
src_ptr
,
sizeof
(
CUR_TYPE
)
*
tensor
.
numel
(),
...
...
@@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray(
self
.
Resize
(
framework
::
make_ddim
(
dims
));
auto
*
dst
=
self
.
mutable_data
<
T
>
(
place
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Get
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
pool
.
Borrow
(
place
));
static_cast
<
const
platform
::
CUDADeviceContext
*>
(
pool
.
Get
(
place
));
paddle
::
platform
::
GpuMemcpyAsync
(
dst
,
array
.
data
(),
sizeof
(
T
)
*
array
.
size
(),
cudaMemcpyHostToDevice
,
dev_ctx
->
stream
());
}
...
...
python/paddle/v2/fluid/__init__.py
浏览文件 @
6f5e64af
...
...
@@ -36,7 +36,7 @@ def __read_gflags_from_env__():
"""
import
sys
import
core
read_env_flags
=
[
'use_pinned_memory'
]
read_env_flags
=
[
'use_pinned_memory'
,
'check_nan_inf'
]
if
core
.
is_compile_gpu
():
read_env_flags
.
append
(
'fraction_of_gpu_memory_to_use'
)
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
...
...
python/paddle/v2/fluid/io.py
浏览文件 @
6f5e64af
...
...
@@ -180,10 +180,22 @@ def save_inference_model(dirname,
:return: None
"""
if
isinstance
(
feeded_var_names
,
basestring
):
feeded_var_names
=
[
feeded_var_names
]
else
:
if
not
(
bool
(
feeded_var_names
)
and
all
(
isinstance
(
name
,
basestring
)
for
name
in
feeded_var_names
)):
raise
ValueError
(
"'feed_var_names' should be a list of str."
)
if
isinstance
(
target_vars
,
Variable
):
feeded_var_names
=
[
feeded_var_names
]
else
:
if
not
(
bool
(
target_vars
)
and
all
(
isinstance
(
var
,
Variable
)
for
var
in
target_vars
)):
raise
ValueError
(
"'target_vars' should be a list of Variable."
)
if
main_program
is
None
:
main_program
=
default_main_program
()
if
not
isinstance
(
target_vars
,
list
):
target_vars
=
[
target_vars
]
if
not
os
.
path
.
isdir
(
dirname
):
os
.
makedirs
(
dirname
)
...
...
python/paddle/v2/fluid/layer_helper.py
浏览文件 @
6f5e64af
...
...
@@ -184,7 +184,7 @@ class LayerHelper(object):
self
.
append_op
(
type
=
act_type
,
inputs
=
{
"X"
:
[
input_var
]},
outputs
=
{
"
Y
"
:
[
tmp
]},
outputs
=
{
"
Out
"
:
[
tmp
]},
attrs
=
act
)
return
tmp
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
6f5e64af
...
...
@@ -386,7 +386,8 @@ def square_error_cost(input, label, **kwargs):
square_out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
type
=
'square'
,
inputs
=
{
'X'
:
[
minus_out
]},
outputs
=
{
'Y'
:
[
square_out
]})
type
=
'square'
,
inputs
=
{
'X'
:
[
minus_out
]},
outputs
=
{
'Out'
:
[
square_out
]})
return
square_out
...
...
@@ -604,7 +605,7 @@ def sequence_pool(input, pool_type, **kwargs):
sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
pool_type (string): The pooling type of sequence_pool.
...
...
@@ -616,7 +617,7 @@ def sequence_pool(input, pool_type, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
...
...
@@ -654,7 +655,7 @@ def sequence_first_step(input, **kwargs):
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
...
...
@@ -664,7 +665,7 @@ def sequence_first_step(input, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_first_step = fluid.layers.sequence_first_step(input=x)
...
...
@@ -687,7 +688,7 @@ def sequence_last_step(input, **kwargs):
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
...
...
@@ -697,7 +698,7 @@ def sequence_last_step(input, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_last_step = fluid.layers.sequence_last_step(input=x)
...
...
@@ -1132,7 +1133,7 @@ def reduce_sum(input, dim=None, keep_dim=False):
Returns:
Variable: The reduced Tensor variable.
Examples:
.. code-block:: python
...
...
@@ -1176,7 +1177,7 @@ def reduce_mean(input, dim=None, keep_dim=False):
Returns:
Variable: The reduced Tensor variable.
Examples:
.. code-block:: python
...
...
python/paddle/v2/fluid/tests/test_activation_op.py
浏览文件 @
6f5e64af
...
...
@@ -10,13 +10,13 @@ class TestExp(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
exp
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
exp
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestSigmoid
(
OpTest
):
...
...
@@ -25,13 +25,13 @@ class TestSigmoid(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
1
/
(
1
+
np
.
exp
(
-
self
.
inputs
[
'X'
]))}
self
.
outputs
=
{
'
Out
'
:
1
/
(
1
+
np
.
exp
(
-
self
.
inputs
[
'X'
]))}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.008
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.008
)
class
TestLogSigmoid
(
OpTest
):
...
...
@@ -40,13 +40,13 @@ class TestLogSigmoid(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
log
(
1
/
(
1
+
np
.
exp
(
-
self
.
inputs
[
'X'
])))}
self
.
outputs
=
{
'
Out
'
:
np
.
log
(
1
/
(
1
+
np
.
exp
(
-
self
.
inputs
[
'X'
])))}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.008
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.008
)
class
TestTanh
(
OpTest
):
...
...
@@ -55,13 +55,13 @@ class TestTanh(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
tanh
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
tanh
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestTanhShrink
(
OpTest
):
...
...
@@ -70,13 +70,13 @@ class TestTanhShrink(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
10
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
self
.
inputs
[
'X'
]
-
np
.
tanh
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
self
.
inputs
[
'X'
]
-
np
.
tanh
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.008
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.008
)
class
TestHardShrink
(
OpTest
):
...
...
@@ -90,13 +90,13 @@ class TestHardShrink(OpTest):
t
=
np
.
copy
(
x
)
t
[(
t
>=
-
threshold
)
&
(
t
<=
threshold
)]
=
0
self
.
outputs
=
{
'
Y
'
:
t
}
self
.
outputs
=
{
'
Out
'
:
t
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.005
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.005
)
class
TestSoftShrink
(
OpTest
):
...
...
@@ -110,13 +110,13 @@ class TestSoftShrink(OpTest):
y
=
np
.
copy
(
self
.
inputs
[
'X'
])
y
=
(
y
<
-
lambda_val
)
*
(
y
+
lambda_val
)
+
(
y
>
lambda_val
)
*
(
y
-
lambda_val
)
self
.
outputs
=
{
'
Y
'
:
y
}
self
.
outputs
=
{
'
Out
'
:
y
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestSqrt
(
OpTest
):
...
...
@@ -125,13 +125,13 @@ class TestSqrt(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
sqrt
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
sqrt
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestAbs
(
OpTest
):
...
...
@@ -144,13 +144,13 @@ class TestAbs(OpTest):
# we should avoid this
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'
Y
'
:
np
.
abs
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
abs
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestCeil
(
OpTest
):
...
...
@@ -158,13 +158,13 @@ class TestCeil(OpTest):
self
.
op_type
=
"ceil"
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
4
]).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'
Y
'
:
np
.
ceil
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
ceil
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestFloor
(
OpTest
):
...
...
@@ -173,13 +173,13 @@ class TestFloor(OpTest):
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
4
]).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
x
}
# numpy floor need +1
self
.
outputs
=
{
'
Y
'
:
np
.
floor
(
self
.
inputs
[
'X'
])
+
1.0
}
self
.
outputs
=
{
'
Out
'
:
np
.
floor
(
self
.
inputs
[
'X'
])
+
1.0
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestRound
(
OpTest
):
...
...
@@ -187,13 +187,13 @@ class TestRound(OpTest):
self
.
op_type
=
"round"
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
4
]).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'
Y
'
:
np
.
round
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
round
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestRelu
(
OpTest
):
...
...
@@ -203,13 +203,13 @@ class TestRelu(OpTest):
# The same reason with TestAbs
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'
Y
'
:
np
.
maximum
(
self
.
inputs
[
'X'
],
0
)}
self
.
outputs
=
{
'
Out
'
:
np
.
maximum
(
self
.
inputs
[
'X'
],
0
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestBRelu
(
OpTest
):
...
...
@@ -227,13 +227,13 @@ class TestBRelu(OpTest):
t
=
np
.
copy
(
x
)
t
[
t
<
t_min
]
=
t_min
t
[
t
>
t_max
]
=
t_max
self
.
outputs
=
{
'
Y
'
:
t
}
self
.
outputs
=
{
'
Out
'
:
t
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.02
)
class
TestRelu6
(
OpTest
):
...
...
@@ -248,14 +248,14 @@ class TestRelu6(OpTest):
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
outputs
=
{
'
Y
'
:
np
.
minimum
(
np
.
maximum
(
self
.
inputs
[
'X'
],
0
),
threshold
)
'
Out
'
:
np
.
minimum
(
np
.
maximum
(
self
.
inputs
[
'X'
],
0
),
threshold
)
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.02
)
class
TestSoftRelu
(
OpTest
):
...
...
@@ -271,13 +271,13 @@ class TestSoftRelu(OpTest):
t
=
np
.
copy
(
x
)
t
[
t
<
-
threshold
]
=
-
threshold
t
[
t
>
threshold
]
=
threshold
self
.
outputs
=
{
'
Y
'
:
np
.
log
((
np
.
exp
(
t
)
+
1
))}
self
.
outputs
=
{
'
Out
'
:
np
.
log
((
np
.
exp
(
t
)
+
1
))}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.02
)
class
TestELU
(
OpTest
):
...
...
@@ -290,27 +290,27 @@ class TestELU(OpTest):
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'alpha'
:
alpha
}
self
.
outputs
=
{
'
Y
'
:
np
.
maximum
(
0
,
x
)
+
np
.
minimum
(
0
,
alpha
*
(
np
.
exp
(
x
)
-
1
))
'
Out
'
:
np
.
maximum
(
0
,
x
)
+
np
.
minimum
(
0
,
alpha
*
(
np
.
exp
(
x
)
-
1
))
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.02
)
class
TestReciprocal
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"reciprocal"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
1
,
2
,
[
11
,
17
]).
astype
(
"float32"
)}
self
.
outputs
=
{
'
Y
'
:
np
.
reciprocal
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
reciprocal
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.01
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.01
)
class
TestLog
(
OpTest
):
...
...
@@ -319,13 +319,13 @@ class TestLog(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
log
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
log
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestSquare
(
OpTest
):
...
...
@@ -334,13 +334,13 @@ class TestSquare(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
square
(
self
.
inputs
[
'X'
])}
self
.
outputs
=
{
'
Out
'
:
np
.
square
(
self
.
inputs
[
'X'
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestPow
(
OpTest
):
...
...
@@ -348,13 +348,13 @@ class TestPow(OpTest):
self
.
op_type
=
"pow"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
1
,
2
,
[
11
,
17
]).
astype
(
"float32"
)}
self
.
attrs
=
{
'factor'
:
3.0
}
self
.
outputs
=
{
'
Y
'
:
np
.
power
(
self
.
inputs
[
'X'
],
3
)}
self
.
outputs
=
{
'
Out
'
:
np
.
power
(
self
.
inputs
[
'X'
],
3
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.02
)
class
TestSTanh
(
OpTest
):
...
...
@@ -366,13 +366,13 @@ class TestSTanh(OpTest):
scale_a
=
2.0
/
3.0
scale_b
=
1.7159
self
.
attrs
=
{
'scale_a'
:
scale_a
,
'scale_b'
:
scale_b
}
self
.
outputs
=
{
'
Y
'
:
scale_b
*
np
.
tanh
(
self
.
inputs
[
'X'
]
*
scale_a
)}
self
.
outputs
=
{
'
Out
'
:
scale_b
*
np
.
tanh
(
self
.
inputs
[
'X'
]
*
scale_a
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestSoftplus
(
OpTest
):
...
...
@@ -381,13 +381,13 @@ class TestSoftplus(OpTest):
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
"float64"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
log
(
1
+
np
.
exp
(
self
.
inputs
[
'X'
]))}
self
.
outputs
=
{
'
Out
'
:
np
.
log
(
1
+
np
.
exp
(
self
.
inputs
[
'X'
]))}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestSoftsign
(
OpTest
):
...
...
@@ -397,14 +397,14 @@ class TestSoftsign(OpTest):
'X'
:
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
divide
(
self
.
inputs
[
'X'
],
1
+
np
.
abs
(
self
.
inputs
[
'X'
]))
'
Out
'
:
np
.
divide
(
self
.
inputs
[
'X'
],
1
+
np
.
abs
(
self
.
inputs
[
'X'
]))
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.007
)
class
TestThresholdedRelu
(
OpTest
):
...
...
@@ -419,13 +419,13 @@ class TestThresholdedRelu(OpTest):
self
.
inputs
=
{
'X'
:
X
}
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
outputs
=
{
'
Y
'
:
(
X
>
threshold
)
*
X
}
self
.
outputs
=
{
'
Out
'
:
(
X
>
threshold
)
*
X
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
self
.
relative_error
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
self
.
relative_error
)
class
TestHardSigmoid
(
OpTest
):
...
...
@@ -447,13 +447,13 @@ class TestHardSigmoid(OpTest):
upper_threshold
-
0.2
temp
=
X
*
slope
+
offset
self
.
outputs
=
{
'
Y
'
:
np
.
maximum
(
0.0
,
np
.
minimum
(
1.0
,
temp
))}
self
.
outputs
=
{
'
Out
'
:
np
.
maximum
(
0.0
,
np
.
minimum
(
1.0
,
temp
))}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.002
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.002
)
class
TestSwish
(
OpTest
):
...
...
@@ -462,13 +462,13 @@ class TestSwish(OpTest):
X
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
}
self
.
attrs
=
{
'beta'
:
2.3
}
self
.
outputs
=
{
'
Y
'
:
X
*
expit
(
self
.
attrs
[
'beta'
]
*
X
)}
self
.
outputs
=
{
'
Out
'
:
X
*
expit
(
self
.
attrs
[
'beta'
]
*
X
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
,
max_relative_error
=
0.008
)
self
.
check_grad
([
'X'
],
'
Out
'
,
max_relative_error
=
0.008
)
if
__name__
==
"__main__"
:
...
...
python/paddle/v2/fluid/tests/test_net.py
浏览文件 @
6f5e64af
...
...
@@ -7,7 +7,7 @@ def fc(X, W, Y):
ret_v
=
core
.
Net
.
create
()
ret_v
.
append_op
(
Operator
(
"mul"
,
X
=
"X"
,
Y
=
"W"
,
Out
=
"pre_activation"
))
ret_v
.
append_op
(
Operator
(
"sigmoid"
,
X
=
"pre_activation"
,
Y
=
Y
))
ret_v
.
append_op
(
Operator
(
"sigmoid"
,
X
=
"pre_activation"
,
Out
=
Y
))
ret_v
.
complete_add_op
(
True
)
return
ret_v
...
...
@@ -30,7 +30,7 @@ Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}.
Op(sigmoid), inputs:{X[pre_activation]}, outputs:{
Y
[fc.out]}.
Op(sigmoid), inputs:{X[pre_activation]}, outputs:{
Out
[fc.out]}.
'''
self
.
assertEqual
(
expected
,
"
\n
"
+
str
(
net
))
...
...
python/paddle/v2/fluid/tests/test_softmax_op.py
浏览文件 @
6f5e64af
...
...
@@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest):
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
10
,
10
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'
Y
'
:
np
.
apply_along_axis
(
stable_softmax
,
1
,
self
.
inputs
[
'X'
])
'
Out
'
:
np
.
apply_along_axis
(
stable_softmax
,
1
,
self
.
inputs
[
'X'
])
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'
Y
'
)
self
.
check_grad
([
'X'
],
'
Out
'
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录