Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
85a41df3
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
85a41df3
编写于
5月 16, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Init commit
上级
56744092
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
226 addition
and
0 deletion
+226
-0
paddle/fluid/operators/random_crop_op.cc
paddle/fluid/operators/random_crop_op.cc
+59
-0
paddle/fluid/operators/random_crop_op.h
paddle/fluid/operators/random_crop_op.h
+167
-0
未找到文件。
paddle/fluid/operators/random_crop_op.cc
0 → 100644
浏览文件 @
85a41df3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/random_crop_op.h"
#include <vector>
namespace
paddle
{
namespace
operators
{
class
RandomCropOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
""
);
AddOutput
(
"Y"
,
""
);
AddInput
(
"Seed"
,
""
);
AddOutput
(
"SeedOut"
,
""
).
AsDispensable
();
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
""
);
}
};
class
RandomCropOpInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
auto
shape
=
context
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
auto
x_dim
=
context
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dim
.
size
(),
static_cast
<
int64_t
>
(
shape
.
size
()));
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
-
1
)
{
shape
[
i
]
=
static_cast
<
int
>
(
x_dim
[
i
]);
}
else
{
PADDLE_ENFORCE_GE
(
x_dim
[
i
],
shape
[
i
]);
}
}
context
->
SetOutputDim
(
"Y"
,
framework
::
make_ddim
(
shape
));
context
->
SetOutputDim
(
"SeedOut"
,
framework
::
make_ddim
({
1
}));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
f
=
paddle
::
framework
;
REGISTER_OPERATOR
(
random_crop
,
f
::
OperatorWithKernel
,
ops
::
RandomCropOpMaker
,
ops
::
RandomCropOpInferShape
);
template
<
typename
T
>
using
Kernel
=
ops
::
RandomCropKernel
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
;
REGISTER_OP_CPU_KERNEL
(
random_crop
,
Kernel
<
float
>
,
Kernel
<
int
>
,
Kernel
<
double
>
,
Kernel
<
uint8_t
>
,
Kernel
<
int16_t
>
);
paddle/fluid/operators/random_crop_op.h
0 → 100644
浏览文件 @
85a41df3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "thrust/random.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
>
struct
Random
;
template
<
>
struct
Random
<
platform
::
CPUDeviceContext
>
{
using
Engine
=
std
::
minstd_rand
;
template
<
typename
T
>
using
UniformIntDist
=
std
::
uniform_int_distribution
<
T
>
;
};
template
<
>
struct
Random
<
platform
::
CUDADeviceContext
>
{
using
Engine
=
thrust
::
minstd_rand
;
template
<
typename
T
>
using
UniformIntDist
=
thrust
::
uniform_int_distribution
<
T
>
;
};
template
<
typename
T
>
HOSTDEVICE
inline
void
RandomCropImpl
(
const
T
*
x
,
size_t
*
x_dim
,
T
*
out
,
size_t
*
out_dim
,
int
i
,
int
rank
,
int64_t
prod_x_remain
,
int64_t
prod_out_remain
,
size_t
*
offset
)
{
size_t
x_length
=
x_dim
[
rank
];
size_t
out_length
=
out_dim
[
rank
];
int64_t
x_stride
=
prod_x_remain
/
x_length
;
int64_t
out_stride
=
prod_out_remain
/
out_length
;
size_t
offset_i
=
offset
[
i
];
if
(
x_stride
==
1
&&
out_stride
==
1
)
{
// In the final stage, copy from offset.
x
+=
offset_i
;
for
(
size_t
i
=
0
;
i
<
out_length
;
++
i
)
{
*
out
++
=
*
x
++
;
}
}
else
{
x
+=
offset_i
*
x_stride
;
for
(
size_t
i
=
0
;
i
<
out_length
;
++
i
)
{
RandomCropImpl
<
T
>
(
x
,
x_dim
,
out
,
out_dim
,
i
+
1
,
rank
,
x_stride
,
out_stride
,
offset
);
x
+=
x_stride
;
out
+=
out_stride
;
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
struct
RandomCropFunctor
{
const
T
*
x_
;
T
*
out_
;
size_t
x_dim_
[
9
];
size_t
out_dim_
[
9
];
size_t
prod_same_dim_
;
size_t
prod_x_dim_
;
size_t
prod_out_dim_
;
int
num_same_dim_
;
int
rank_
;
int64_t
seed_
;
RandomCropFunctor
(
const
T
*
x
,
T
*
out
,
int64_t
seed
)
:
x_
(
x
),
out_
(
out
),
prod_same_dim_
(
1
),
prod_x_dim_
(
1
),
prod_out_dim_
(
1
),
seed_
(
seed
)
{
std
::
fill
(
x_dim_
,
x_dim_
+
sizeof
(
x_dim_
)
/
sizeof
(
size_t
),
0
);
std
::
fill
(
out_dim_
,
out_dim_
+
sizeof
(
out_dim_
)
/
sizeof
(
size_t
),
0
);
}
HOSTDEVICE
void
operator
()(
size_t
i
)
{
typename
Random
<
DeviceContext
>::
Engine
engine
(
seed_
);
engine
.
discard
(
i
*
(
rank_
-
num_same_dim_
));
int64_t
prod_x_unsame
=
(
prod_x_dim_
/
prod_same_dim_
);
int64_t
prod_out_unsame
=
(
prod_out_dim_
/
prod_same_dim_
);
const
T
*
x
=
x_
+
i
*
prod_x_unsame
;
T
*
out
=
out_
+
i
*
prod_out_unsame
;
size_t
offset
[
9
];
for
(
int
i
=
num_same_dim_
;
i
<
rank_
;
++
i
)
{
typename
Random
<
DeviceContext
>::
template
UniformIntDist
<
size_t
>
dist
(
0
,
x_dim_
[
i
]
-
out_dim_
[
i
]);
offset
[
i
]
=
dist
(
engine
);
}
RandomCropImpl
<
T
>
(
x
,
x_dim_
,
out
,
out_dim_
,
num_same_dim_
,
rank_
,
prod_x_unsame
,
prod_out_unsame
,
offset
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
RandomCropKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
int64_t
seed
=
*
context
.
Input
<
framework
::
LoDTensor
>
(
"Seed"
)
->
data
<
int64_t
>
();
auto
&
x
=
detail
::
Ref
(
context
.
Input
<
framework
::
LoDTensor
>
(
"X"
));
auto
&
out
=
detail
::
Ref
(
context
.
Output
<
framework
::
LoDTensor
>
(
"Out"
));
RandomCropFunctor
<
DeviceContext
,
T
>
functor
{
x
.
data
<
T
>
(),
out
.
mutable_data
<
T
>
(
context
.
GetPlace
()),
seed
};
auto
&
out_dim
=
out
.
dims
();
auto
&
x_dim
=
x
.
dims
();
auto
rank
=
x_dim
.
size
();
while
(
rank
--
>
0
)
{
functor
.
x_dim_
[
rank
]
=
x_dim
[
rank
];
functor
.
out_dim_
[
rank
]
=
out_dim
[
rank
];
functor
.
prod_x_dim_
*=
x_dim
[
rank
];
functor
.
prod_out_dim_
*=
out_dim
[
rank
];
if
(
x_dim
[
rank
]
!=
out_dim
[
rank
])
{
PADDLE_ENFORCE_EQ
(
functor
.
prod_same_dim_
,
1
);
functor
.
num_same_dim_
=
rank
;
}
else
{
functor
.
prod_same_dim_
*=
out_dim
[
rank
];
}
}
functor
.
rank_
=
x_dim
.
size
();
platform
::
ForRange
<
DeviceContext
>
for_range
(
context
.
template
device_context
<
DeviceContext
>(),
functor
.
prod_same_dim_
);
for_range
(
functor
);
Random
<
platform
::
CPUDeviceContext
>::
Engine
engine
(
seed
);
engine
.
discard
(
functor
.
prod_same_dim_
*
(
functor
.
rank_
-
functor
.
num_same_dim_
));
*
context
.
Output
<
framework
::
LoDTensor
>
(
"SeedOut"
)
->
mutable_data
<
int64_t
>
(
platform
::
CPUPlace
())
=
engine
();
}
};
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录