Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3b32eb9e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3b32eb9e
编写于
11月 09, 2017
作者:
Y
Yang yaming
提交者:
GitHub
11月 09, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5488 from pkuyym/fix-5417
Add ScaleSubRegion Layer.
上级
91855659
930d2e89
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
714 addition
and
2 deletion
+714
-2
paddle/function/CMakeLists.txt
paddle/function/CMakeLists.txt
+1
-0
paddle/function/FunctionTest.h
paddle/function/FunctionTest.h
+10
-0
paddle/function/ScaleSubRegionOp.cpp
paddle/function/ScaleSubRegionOp.cpp
+155
-0
paddle/function/ScaleSubRegionOp.h
paddle/function/ScaleSubRegionOp.h
+55
-0
paddle/function/ScaleSubRegionOpGpu.cu
paddle/function/ScaleSubRegionOpGpu.cu
+116
-0
paddle/function/ScaleSubRegionOpTest.cpp
paddle/function/ScaleSubRegionOpTest.cpp
+72
-0
paddle/gserver/layers/ScaleSubRegionLayer.cpp
paddle/gserver/layers/ScaleSubRegionLayer.cpp
+78
-0
paddle/gserver/layers/ScaleSubRegionLayer.h
paddle/gserver/layers/ScaleSubRegionLayer.h
+52
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+32
-0
paddle/math/tests/TensorCheck.h
paddle/math/tests/TensorCheck.h
+1
-1
proto/ModelConfig.proto
proto/ModelConfig.proto
+6
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+19
-0
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+54
-0
python/paddle/trainer_config_helpers/tests/configs/file_list.sh
.../paddle/trainer_config_helpers/tests/configs/file_list.sh
+1
-1
python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr
...sts/configs/protostr/test_scale_sub_region_layer.protostr
+51
-0
python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py
...nfig_helpers/tests/configs/test_scale_sub_region_layer.py
+11
-0
未找到文件。
paddle/function/CMakeLists.txt
浏览文件 @
3b32eb9e
...
@@ -45,6 +45,7 @@ if(WITH_GPU)
...
@@ -45,6 +45,7 @@ if(WITH_GPU)
add_simple_unittest
(
BlockExpandOpTest
)
add_simple_unittest
(
BlockExpandOpTest
)
add_simple_unittest
(
CropOpTest
)
add_simple_unittest
(
CropOpTest
)
add_simple_unittest
(
SwitchOpTest
)
add_simple_unittest
(
SwitchOpTest
)
add_simple_unittest
(
ScaleSubRegionOpTest
)
endif
()
endif
()
add_simple_unittest
(
Im2ColTest
)
add_simple_unittest
(
Im2ColTest
)
...
...
paddle/function/FunctionTest.h
浏览文件 @
3b32eb9e
...
@@ -110,6 +110,7 @@ public:
...
@@ -110,6 +110,7 @@ public:
function2_
(
FunctionBase
::
funcRegistrar_
.
createByType
(
name2
))
{
function2_
(
FunctionBase
::
funcRegistrar_
.
createByType
(
name2
))
{
function1_
->
init
(
config
);
function1_
->
init
(
config
);
function2_
->
init
(
config
);
function2_
->
init
(
config
);
initArgsCallback_
=
nullptr
;
}
}
~
Compare2Function
()
{}
~
Compare2Function
()
{}
...
@@ -170,6 +171,10 @@ public:
...
@@ -170,6 +171,10 @@ public:
*
seq2_
));
*
seq2_
));
}
}
void
registerInitCallback
(
std
::
function
<
void
(
BufferArg
&
,
size_t
)
>
callback
)
{
initArgsCallback_
=
callback
;
}
// output need only contains shape, do not contains data.
// output need only contains shape, do not contains data.
void
addOutputs
(
const
BufferArg
&
output
,
ArgType
argType
=
ASSIGN_TO
)
{
void
addOutputs
(
const
BufferArg
&
output
,
ArgType
argType
=
ASSIGN_TO
)
{
size_t
size
=
size_t
size
=
...
@@ -340,6 +345,10 @@ protected:
...
@@ -340,6 +345,10 @@ protected:
initArg
(
*
func1Inputs_
[
i
]);
initArg
(
*
func1Inputs_
[
i
]);
}
}
if
(
initArgsCallback_
!=
nullptr
)
{
initArgsCallback_
(
*
func1Inputs_
[
i
],
i
);
}
copyArg_
(
*
func1Inputs_
[
i
],
*
func2Inputs_
[
i
]);
copyArg_
(
*
func1Inputs_
[
i
],
*
func2Inputs_
[
i
]);
}
}
}
}
...
@@ -386,6 +395,7 @@ protected:
...
@@ -386,6 +395,7 @@ protected:
std
::
shared_ptr
<
SequenceIdArg
>
seq1_
;
std
::
shared_ptr
<
SequenceIdArg
>
seq1_
;
std
::
shared_ptr
<
SequenceIdArg
>
seq2_
;
std
::
shared_ptr
<
SequenceIdArg
>
seq2_
;
test
::
CopyArgument
<
DType1
,
DType2
>
copyArg_
;
test
::
CopyArgument
<
DType1
,
DType2
>
copyArg_
;
std
::
function
<
void
(
BufferArg
&
,
size_t
)
>
initArgsCallback_
;
};
};
class
CpuGpuFuncCompare
class
CpuGpuFuncCompare
...
...
paddle/function/ScaleSubRegionOp.cpp
0 → 100644
浏览文件 @
3b32eb9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ScaleSubRegionOp.h"
#include "paddle/function/TensorShape.h"
namespace
paddle
{
template
<
>
void
ScaleSubRegion
<
DEVICE_TYPE_CPU
>
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
memcpy
(
outputs
,
inputs
,
number
*
channel
*
height
*
width
*
sizeof
(
real
));
for
(
int
n
=
0
;
n
<
number
;
++
n
)
{
// indices start from 1
int
offset
=
n
*
6
;
for
(
int
c
=
indices
[
offset
]
-
1
;
c
<
indices
[
offset
+
1
];
++
c
)
{
for
(
int
h
=
indices
[
offset
+
2
]
-
1
;
h
<
indices
[
offset
+
3
];
++
h
)
{
for
(
int
w
=
indices
[
offset
+
4
]
-
1
;
w
<
indices
[
offset
+
5
];
++
w
)
{
int
idx
=
((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
;
outputs
[
idx
]
*=
value
;
}
}
}
}
}
template
<
>
void
ScaleSubRegionGrad
<
DEVICE_TYPE_CPU
>
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
for
(
int
n
=
0
;
n
<
number
;
++
n
)
{
for
(
int
c
=
0
;
c
<
channel
;
++
c
)
{
for
(
int
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
width
;
++
w
)
{
int
idx
=
((
n
*
channel
+
c
)
*
height
+
h
)
*
width
+
w
;
int
offset
=
n
*
6
;
if
(
c
>=
(
indices
[
offset
]
-
1
)
&&
c
<=
(
indices
[
offset
+
1
]
-
1
)
&&
h
>=
(
indices
[
offset
+
2
]
-
1
)
&&
h
<=
(
indices
[
offset
+
3
]
-
1
)
&&
w
>=
(
indices
[
offset
+
4
]
-
1
)
&&
w
<=
(
indices
[
offset
+
5
]
-
1
))
{
outGrad
[
idx
]
+=
inGrad
[
idx
]
*
value
;
}
else
{
outGrad
[
idx
]
+=
inGrad
[
idx
];
}
}
}
}
}
}
/**
* \brief For each instance, ScaleSubRegion can be used to multiply a value to
* a specified sub continuous region. By providing start index and end
* index for C/H/W, you can specify the location and shape of the region.
*
* Argument in this Function:
* \param inputs A 4-D tensor with shape [N, C, H, W], only one input.
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
* \param outputs A 4-D tensor with same shape as inputs, output value.
*/
template
<
DeviceType
Device
>
class
ScaleSubRegionFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
conf_
=
config
;
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
2UL
,
inputs
.
size
());
CHECK_EQ
(
1UL
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ASSIGN_TO
);
TensorShape
shape
=
inputs
[
0
].
shape
();
ScaleSubRegion
<
Device
>
(
outputs
[
0
].
data
<
real
>
(),
inputs
[
0
].
data
<
real
>
(),
inputs
[
1
].
data
<
real
>
(),
shape
,
conf_
);
}
private:
FuncConfig
conf_
;
};
/**
* \brief The backward propagation of ScaleSubRegion Function.
*
* Argument in this Function:
* \param inputs A 4-D tensor with shape [N, C, H, W], output gradient.
* \param indices A 2-D tensor with shape [N, 6], indicates the sub region.
* \param outputs A 4-D tensor with shape [N, C, H, W], gradient of input value.
*/
template
<
DeviceType
Device
>
class
ScaleSubRegionGradFunc
:
public
FunctionBase
{
public:
void
init
(
const
FuncConfig
&
config
)
override
{
conf_
=
config
;
}
void
calc
(
const
BufferArgs
&
inputs
,
const
BufferArgs
&
outputs
)
override
{
CHECK_EQ
(
2UL
,
inputs
.
size
());
CHECK_EQ
(
1UL
,
outputs
.
size
());
CHECK_EQ
(
outputs
[
0
].
getArgType
(),
ADD_TO
);
TensorShape
shape
=
inputs
[
0
].
shape
();
ScaleSubRegionGrad
<
Device
>
(
inputs
[
0
].
data
<
real
>
(),
outputs
[
0
].
data
<
real
>
(),
inputs
[
1
].
data
<
real
>
(),
shape
,
conf_
);
}
private:
FuncConfig
conf_
;
};
REGISTER_TYPED_FUNC
(
ScaleSubRegion
,
CPU
,
ScaleSubRegionFunc
);
REGISTER_TYPED_FUNC
(
ScaleSubRegionGrad
,
CPU
,
ScaleSubRegionGradFunc
);
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC
(
ScaleSubRegion
,
GPU
,
ScaleSubRegionFunc
);
REGISTER_TYPED_FUNC
(
ScaleSubRegionGrad
,
GPU
,
ScaleSubRegionGradFunc
);
#endif
}
// namespace paddle
paddle/function/ScaleSubRegionOp.h
0 → 100644
浏览文件 @
3b32eb9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace
paddle
{
/**
* \brief Function to multiply a value to values in specified sub continuous
* region. Indices must be provided to indcate the location and shape of
* the region and the multiplied value is passed by configure variable.
*
*
* \param[out] outputs Output value.
* \param[in] inputs Input data which contains NCHW information.
* \param[in] indices Indices data to indcate the sub region.
* \param[in] shape Tensor shape of input value.
* \param[in] conf Configure variable which contains the multiplied value.
*/
template
<
DeviceType
Device
>
void
ScaleSubRegion
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
);
/**
* \brief Backward propagation function of ScaleSubRegion.
*
* \param[out] inGrad Gradients of previous layer.
* \param[in] outGrad Output gradient.
* \param[in] indices Indices data.
* \param[in] shape The Shape of input tensor.
* \param[in] conf Configure variable.
*/
template
<
DeviceType
Device
>
void
ScaleSubRegionGrad
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
);
}
// namespace paddle
paddle/function/ScaleSubRegionOpGpu.cu
0 → 100644
浏览文件 @
3b32eb9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ScaleSubRegionOp.h"
#include "hl_base.h"
namespace
paddle
{
__global__
void
KeScaleSubRegion
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
real
value
,
int
channel
,
int
height
,
int
width
,
int
nthreads
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
nthreads
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
c
=
(
idx
/
width
/
height
)
%
channel
;
const
int
n
=
idx
/
width
/
height
/
channel
;
const
int
offset
=
n
*
6
;
if
(
c
>=
(
indices
[
offset
]
-
1
)
&&
c
<=
(
indices
[
offset
+
1
]
-
1
)
&&
h
>=
(
indices
[
offset
+
2
]
-
1
)
&&
h
<=
(
indices
[
offset
+
3
]
-
1
)
&&
w
>=
(
indices
[
offset
+
4
]
-
1
)
&&
w
<=
(
indices
[
offset
+
5
]
-
1
))
{
outputs
[
idx
]
=
inputs
[
idx
]
*
value
;
}
else
{
outputs
[
idx
]
=
inputs
[
idx
];
}
}
}
template
<
>
void
ScaleSubRegion
<
DEVICE_TYPE_GPU
>
(
real
*
outputs
,
const
real
*
inputs
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
size_t
nth
=
number
*
channel
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
nth
+
blockSize
-
1
)
/
blockSize
;
KeScaleSubRegion
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
outputs
,
inputs
,
indices
,
value
,
channel
,
height
,
width
,
nth
);
CHECK_SYNC
(
"ScaleSubRegion"
);
}
__global__
void
KeScaleSubRegionDiff
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
real
value
,
int
channel
,
int
height
,
int
width
,
int
nthreads
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
nthreads
)
{
const
int
w
=
idx
%
width
;
const
int
h
=
(
idx
/
width
)
%
height
;
const
int
c
=
(
idx
/
width
/
height
)
%
channel
;
const
int
n
=
idx
/
width
/
height
/
channel
;
const
int
offset
=
n
*
6
;
if
(
c
>=
(
indices
[
offset
]
-
1
)
&&
c
<=
(
indices
[
offset
+
1
]
-
1
)
&&
h
>=
(
indices
[
offset
+
2
]
-
1
)
&&
h
<=
(
indices
[
offset
+
3
]
-
1
)
&&
w
>=
(
indices
[
offset
+
4
]
-
1
)
&&
w
<=
(
indices
[
offset
+
5
]
-
1
))
{
outGrad
[
idx
]
+=
inGrad
[
idx
]
*
value
;
}
else
{
outGrad
[
idx
]
+=
inGrad
[
idx
];
}
}
}
template
<
>
void
ScaleSubRegionGrad
<
DEVICE_TYPE_GPU
>
(
const
real
*
inGrad
,
real
*
outGrad
,
const
real
*
indices
,
const
TensorShape
shape
,
const
FuncConfig
&
conf
)
{
real
value
=
conf
.
get
<
real
>
(
"value"
);
int
number
=
shape
[
0
];
int
channel
=
shape
[
1
];
int
height
=
shape
[
2
];
int
width
=
shape
[
3
];
size_t
nth
=
number
*
channel
*
height
*
width
;
int
blockSize
=
1024
;
int
gridSize
=
(
nth
+
blockSize
-
1
)
/
blockSize
;
KeScaleSubRegionDiff
<<<
gridSize
,
blockSize
,
0
,
STREAM_DEFAULT
>>>
(
inGrad
,
outGrad
,
indices
,
value
,
channel
,
height
,
width
,
nth
);
CHECK_SYNC
(
"ScaleSubRegionGrad"
);
}
}
// namespace paddle
paddle/function/ScaleSubRegionOpTest.cpp
0 → 100644
浏览文件 @
3b32eb9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace
paddle
{
TEST
(
ScaleSubRegion
,
real
)
{
for
(
size_t
numSamples
:
{
5
,
32
})
{
for
(
size_t
channels
:
{
5
,
32
})
{
for
(
size_t
imgSizeH
:
{
5
,
33
})
{
for
(
size_t
imgSizeW
:
{
5
,
32
})
{
for
(
real
value
:
{
-
0.5
,
0.0
,
0.5
})
{
for
(
bool
firstHalf
:
{
false
,
true
})
{
VLOG
(
3
)
<<
" numSamples="
<<
numSamples
<<
" channels="
<<
channels
<<
" imgSizeH="
<<
imgSizeH
<<
" imgSizeW="
<<
imgSizeW
;
for
(
bool
testGrad
:
{
false
,
true
})
{
CpuGpuFuncCompare
compare
(
testGrad
?
"ScaleSubRegionGrad"
:
"ScaleSubRegion"
,
FuncConfig
().
set
<
real
>
(
"value"
,
value
));
TensorShape
shape
{
numSamples
,
channels
,
imgSizeH
,
imgSizeW
};
TensorShape
indicesShape
{
numSamples
,
6
};
compare
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
shape
));
compare
.
addInputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
indicesShape
));
compare
.
registerInitCallback
([
=
](
BufferArg
&
arg
,
size_t
index
)
{
if
(
index
==
1
)
{
real
*
data
=
(
real
*
)
arg
.
data
();
for
(
size_t
i
=
0
;
i
<
numSamples
;
++
i
)
{
size_t
offset
=
i
*
6
;
data
[
offset
]
=
firstHalf
?
1
:
channels
/
2
;
data
[
offset
+
1
]
=
firstHalf
?
channels
/
2
:
channels
;
data
[
offset
+
2
]
=
firstHalf
?
1
:
imgSizeH
/
2
;
data
[
offset
+
3
]
=
firstHalf
?
imgSizeH
/
2
:
imgSizeH
;
data
[
offset
+
4
]
=
firstHalf
?
1
:
imgSizeW
/
2
;
data
[
offset
+
5
]
=
firstHalf
?
imgSizeW
/
2
:
imgSizeW
;
}
}
});
compare
.
addOutputs
(
BufferArg
(
VALUE_TYPE_FLOAT
,
shape
,
testGrad
?
ADD_TO
:
ASSIGN_TO
),
testGrad
?
ADD_TO
:
ASSIGN_TO
);
compare
.
run
();
}
}
}
}
}
}
}
}
}
// namespace paddle
paddle/gserver/layers/ScaleSubRegionLayer.cpp
0 → 100644
浏览文件 @
3b32eb9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ScaleSubRegionLayer.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
REGISTER_LAYER
(
scale_sub_region
,
ScaleSubRegionLayer
);
bool
ScaleSubRegionLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
Layer
::
init
(
layerMap
,
parameterMap
);
CHECK_EQ
(
static_cast
<
int
>
(
inputLayers_
.
size
()),
2
);
auto
&
conf
=
config_
.
inputs
(
0
).
scale_sub_region_conf
();
value_
=
conf
.
value
();
createFunction
(
forward_
,
"ScaleSubRegion"
,
FuncConfig
().
set
(
"value"
,
value_
));
createFunction
(
backward_
,
"ScaleSubRegionGrad"
,
FuncConfig
().
set
(
"value"
,
value_
));
return
true
;
}
void
ScaleSubRegionLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
auto
in0
=
getInput
(
0
);
imgH_
=
in0
.
getFrameHeight
();
imgW_
=
in0
.
getFrameWidth
();
if
(
imgH_
==
0
||
imgW_
==
0
)
{
auto
&
conf
=
config_
.
inputs
(
0
).
scale_sub_region_conf
();
imgH_
=
conf
.
image_conf
().
img_size_y
();
imgW_
=
conf
.
image_conf
().
img_size
();
}
MatrixPtr
imgV
=
in0
.
value
;
size_t
batchSize
=
imgV
->
getHeight
();
size_t
spatialSize
=
imgH_
*
imgW_
;
channelsNum_
=
imgV
->
getWidth
()
/
spatialSize
;
shape_
=
TensorShape
({
batchSize
,
channelsNum_
,
imgH_
,
imgW_
});
resetOutput
(
batchSize
,
imgV
->
getWidth
());
auto
out
=
getOutput
();
out
.
setFrameHeight
(
imgH_
);
out
.
setFrameWidth
(
imgW_
);
MatrixPtr
indicesV
=
getInputValue
(
1
);
indicesShape_
=
TensorShape
({
batchSize
,
6
});
REGISTER_TIMER_INFO
(
"ScaleSubRegionForward"
,
getName
().
c_str
());
BufferArgs
inArgs
;
BufferArgs
outArgs
;
inArgs
.
addArg
(
*
imgV
,
shape_
);
inArgs
.
addArg
(
*
indicesV
,
indicesShape_
);
outArgs
.
addArg
(
*
out
.
value
,
shape_
,
ASSIGN_TO
);
forward_
[
0
]
->
calc
(
inArgs
,
outArgs
);
}
void
ScaleSubRegionLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
REGISTER_TIMER_INFO
(
"ScaleSubRegionBackward"
,
getName
().
c_str
());
BufferArgs
inArgs
;
BufferArgs
outArgs
;
inArgs
.
addArg
(
*
getOutputGrad
(),
shape_
);
inArgs
.
addArg
(
*
getInputValue
(
1
),
indicesShape_
);
outArgs
.
addArg
(
*
getInputGrad
(
0
),
shape_
,
ADD_TO
);
backward_
[
0
]
->
calc
(
inArgs
,
outArgs
);
}
}
// namespace paddle
paddle/gserver/layers/ScaleSubRegionLayer.h
0 → 100644
浏览文件 @
3b32eb9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace
paddle
{
/**
* \brief For each instance, this layer can be used to multiply a value to a
* specified sub continuous region. By providing start index and end
* index for C/H/W, you can specify the location and shape of the
* region.
*
* input_0: Input value.
* input_1: Indices value to specify the location an shape of the
* region.
*/
class
ScaleSubRegionLayer
:
public
Layer
{
public:
explicit
ScaleSubRegionLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
~
ScaleSubRegionLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
void
forward
(
PassType
passType
);
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
protected:
TensorShape
shape_
;
TensorShape
indicesShape_
;
size_t
imgH_
;
size_t
imgW_
;
size_t
channelsNum_
;
real
value_
;
};
}
// namespace paddle
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
3b32eb9e
...
@@ -2358,6 +2358,38 @@ TEST(Layer, ScaleShiftLayer) {
...
@@ -2358,6 +2358,38 @@ TEST(Layer, ScaleShiftLayer) {
}
}
}
}
TEST
(
Layer
,
ScaleSubRegionLayer
)
{
const
size_t
batchSize
=
64
;
const
size_t
size
=
4096
;
TestConfig
config
;
config
.
layerConfig
.
set_type
(
"scale_sub_region"
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"input"
,
size
,
0
});
MatrixPtr
indicesV
=
Matrix
::
create
(
batchSize
,
6
,
false
,
false
);
auto
*
data
=
indicesV
->
getData
();
for
(
size_t
i
=
0
;
i
<
batchSize
;
++
i
)
{
data
[
i
*
2
]
=
2
;
data
[
i
*
2
+
1
]
=
4
;
data
[
i
*
2
+
2
]
=
16
;
data
[
i
*
2
+
3
]
=
32
;
data
[
i
*
2
+
4
]
=
16
;
data
[
i
*
2
+
5
]
=
32
;
}
config
.
inputDefs
.
push_back
({
INPUT_SELF_DEFINE_DATA
,
"indices"
,
indicesV
,
{}});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
ScaleSubRegionConfig
*
scaleSubRegionConf
=
input
->
mutable_scale_sub_region_conf
();
ImageConfig
*
imgConf
=
scaleSubRegionConf
->
mutable_image_conf
();
imgConf
->
set_img_size
(
32
);
imgConf
->
set_img_size_y
(
32
);
imgConf
->
set_channels
(
4
);
scaleSubRegionConf
->
set_value
(
2.0
);
config
.
layerConfig
.
add_inputs
();
for
(
auto
useGpu
:
{
false
,
true
})
{
testLayerGrad
(
config
,
"scale_sub_region"
,
batchSize
,
false
,
useGpu
,
false
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
initMain
(
argc
,
argv
);
initMain
(
argc
,
argv
);
...
...
paddle/math/tests/TensorCheck.h
浏览文件 @
3b32eb9e
...
@@ -169,7 +169,7 @@ void TensorCheck(AssertEq compare,
...
@@ -169,7 +169,7 @@ void TensorCheck(AssertEq compare,
count
++
;
count
++
;
}
}
}
}
EXPECT_EQ
(
count
,
0
)
<<
"There are "
<<
count
<<
" different element."
;
EXPECT_EQ
(
count
,
0
)
<<
"There are "
<<
count
<<
" different element
s
."
;
}
}
template
<
typename
AssertEq
,
typename
Tensor1
,
typename
Tensor2
>
template
<
typename
AssertEq
,
typename
Tensor1
,
typename
Tensor2
>
...
...
proto/ModelConfig.proto
浏览文件 @
3b32eb9e
...
@@ -321,6 +321,11 @@ message ClipConfig {
...
@@ -321,6 +321,11 @@ message ClipConfig {
required
double
max
=
2
;
required
double
max
=
2
;
}
}
message
ScaleSubRegionConfig
{
required
ImageConfig
image_conf
=
1
;
required
float
value
=
2
;
}
message
LayerInputConfig
{
message
LayerInputConfig
{
required
string
input_layer_name
=
1
;
required
string
input_layer_name
=
1
;
optional
string
input_parameter_name
=
2
;
optional
string
input_parameter_name
=
2
;
...
@@ -342,6 +347,7 @@ message LayerInputConfig {
...
@@ -342,6 +347,7 @@ message LayerInputConfig {
optional
MultiBoxLossConfig
multibox_loss_conf
=
16
;
optional
MultiBoxLossConfig
multibox_loss_conf
=
16
;
optional
DetectionOutputConfig
detection_output_conf
=
17
;
optional
DetectionOutputConfig
detection_output_conf
=
17
;
optional
ClipConfig
clip_conf
=
18
;
optional
ClipConfig
clip_conf
=
18
;
optional
ScaleSubRegionConfig
scale_sub_region_conf
=
19
;
}
}
message
LayerConfig
{
message
LayerConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
3b32eb9e
...
@@ -3801,6 +3801,25 @@ class SwitchOrderLayer(LayerBase):
...
@@ -3801,6 +3801,25 @@ class SwitchOrderLayer(LayerBase):
self
.
config
.
reshape_conf
.
width_axis
.
extend
(
reshape
[
'width'
])
self
.
config
.
reshape_conf
.
width_axis
.
extend
(
reshape
[
'width'
])
@
config_layer
(
'scale_sub_region'
)
class
ScaleSubRegionLayer
(
LayerBase
):
def
__init__
(
self
,
name
,
inputs
,
value
,
**
xargs
):
super
(
ScaleSubRegionLayer
,
self
).
__init__
(
name
,
'scale_sub_region'
,
0
,
inputs
=
inputs
,
**
xargs
)
scale_sub_region_conf
=
self
.
config
.
inputs
[
0
].
scale_sub_region_conf
scale_sub_region_conf
.
value
=
value
# get channel, width and height from input_0 layer
input_layer
=
self
.
get_input_layer
(
0
)
image_conf
=
scale_sub_region_conf
.
image_conf
image_conf
.
img_size
=
input_layer
.
width
image_conf
.
img_size_y
=
input_layer
.
height
image_conf
.
channels
=
input_layer
.
size
/
(
input_layer
.
width
*
input_layer
.
height
)
self
.
set_cnn_layer
(
name
,
image_conf
.
img_size_y
,
image_conf
.
img_size
,
image_conf
.
channels
)
# Deprecated, use a new layer specific class instead
# Deprecated, use a new layer specific class instead
@
config_func
@
config_func
def
Layer
(
name
,
type
,
**
xargs
):
def
Layer
(
name
,
type
,
**
xargs
):
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
3b32eb9e
...
@@ -144,6 +144,7 @@ __all__ = [
...
@@ -144,6 +144,7 @@ __all__ = [
'img_conv3d_layer'
,
'img_conv3d_layer'
,
'resize_layer'
,
'resize_layer'
,
'sub_seq_layer'
,
'sub_seq_layer'
,
'scale_sub_region_layer'
,
]
]
...
@@ -255,6 +256,8 @@ class LayerType(object):
...
@@ -255,6 +256,8 @@ class LayerType(object):
RESIZE
=
'resize'
RESIZE
=
'resize'
SUB_SEQ_LAYER
=
'subseq'
SUB_SEQ_LAYER
=
'subseq'
SCALE_SUB_REGION_LAYER
=
'scale_sub_region'
@
staticmethod
@
staticmethod
def
is_layer_type
(
type_name
):
def
is_layer_type
(
type_name
):
"""
"""
...
@@ -7042,3 +7045,54 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None):
...
@@ -7042,3 +7045,54 @@ def sub_seq_layer(input, offsets, sizes, act=None, bias_attr=None, name=None):
LayerType
.
SUB_SEQ_LAYER
,
LayerType
.
SUB_SEQ_LAYER
,
parents
=
[
input
,
offsets
,
sizes
],
parents
=
[
input
,
offsets
,
sizes
],
size
=
input
.
size
)
size
=
input
.
size
)
@
wrap_name_default
(
'scale_sub_region'
)
def
scale_sub_region_layer
(
input
,
indices
,
value
,
name
=
None
):
"""
Given an image or feature map with CHW information, scale_sub_region_layer
can be used to multiply a real value to values of a sub continuous region.
You can provide start and end indices of CHW for each instance.
Please notice that all start indices are counting from 1.
The shape of indices should be [batch_size, 6] and the layout for each row
is [C_Start, C_End, H_Start, H_End, W_Start, W_End].
.. code-block:: python
scale_sub_region = scale_sub_region_layer(input=input,
indices=indices,
value=value)
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: The input of this layer which should contains CHW information.
:type input: LayerOutput
:param indices: Start index and end index for C H W, the input value should
be a 2-D matrix with shape [batch_size, 6].
:type indices: LayerOutput.
:param value: value to multiply.
:type value: float
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert
isinstance
(
input
,
LayerOutput
),
(
'The first input of scale_sub_region_layer, '
'must be a PaddlePaddle layer.'
)
assert
isinstance
(
indices
,
LayerOutput
),
(
'The start and end indices for CHW, must be a PaddlePaddle layer.'
)
assert
isinstance
(
value
,
float
),
(
'The value to multiply, must be a real value.'
)
Layer
(
name
=
name
,
type
=
LayerType
.
SCALE_SUB_REGION_LAYER
,
inputs
=
[
input
.
name
,
indices
.
name
],
value
=
value
)
return
LayerOutput
(
name
,
LayerType
.
SCALE_SUB_REGION_LAYER
,
parents
=
[
input
,
indices
],
num_filters
=
input
.
num_filters
,
size
=
input
.
size
)
python/paddle/trainer_config_helpers/tests/configs/file_list.sh
浏览文件 @
3b32eb9e
...
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
...
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
)
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
test_scale_sub_region_layer
)
export
whole_configs
=(
test_split_datasource
)
export
whole_configs
=(
test_split_datasource
)
python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_sub_region_layer.protostr
0 → 100644
浏览文件 @
3b32eb9e
type: "nn"
layers {
name: "data"
type: "data"
size: 2016
active_type: ""
height: 48
width: 42
}
layers {
name: "indices"
type: "data"
size: 6
active_type: ""
}
layers {
name: "__scale_sub_region_0__"
type: "scale_sub_region"
size: 2016
active_type: ""
inputs {
input_layer_name: "data"
scale_sub_region_conf {
image_conf {
channels: 1
img_size: 42
img_size_y: 48
}
value: 0.0
}
}
inputs {
input_layer_name: "indices"
}
height: 48
width: 42
}
input_layer_names: "data"
input_layer_names: "indices"
output_layer_names: "__scale_sub_region_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "indices"
layer_names: "__scale_sub_region_0__"
input_layer_names: "data"
input_layer_names: "indices"
output_layer_names: "__scale_sub_region_0__"
is_recurrent_layer_group: false
}
python/paddle/trainer_config_helpers/tests/configs/test_scale_sub_region_layer.py
0 → 100644
浏览文件 @
3b32eb9e
from
paddle.trainer_config_helpers
import
*
settings
(
batch_size
=
1000
,
learning_rate
=
1e-5
)
data
=
data_layer
(
name
=
'data'
,
size
=
2016
,
height
=
48
,
width
=
42
)
indices
=
data_layer
(
name
=
'indices'
,
size
=
6
)
scale_sub_region
=
scale_sub_region_layer
(
input
=
data
,
indices
=
indices
,
value
=
0.0
)
outputs
(
scale_sub_region
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录