Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0ce558f1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0ce558f1
编写于
3月 28, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
kernels of increment op
上级
1b67bc02
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
90 addition
and
59 deletion
+90
-59
paddle/fluid/operators/increment_op.cc
paddle/fluid/operators/increment_op.cc
+30
-59
paddle/fluid/operators/increment_op.cu
paddle/fluid/operators/increment_op.cu
+21
-0
paddle/fluid/operators/increment_op.h
paddle/fluid/operators/increment_op.h
+39
-0
未找到文件。
paddle/fluid/operators/increment_op.cc
浏览文件 @
0ce558f1
/
* Copyright (c) 2016
PaddlePaddle Authors. All Rights Reserved.
/
/ Copyright (c) 2018
PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License");
//
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
//
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
//
You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0
//
http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software
//
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
//
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
//
See the License for the specific language governing permissions and
limitations under the License. */
// limitations under the License.
#include "paddle/fluid/
framework/op_registry
.h"
#include "paddle/fluid/
operators/increment_op
.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
Increment
InferShape
:
public
framework
::
InferShapeBase
{
class
Increment
Op
:
public
framework
::
OperatorWithKernel
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
IncrementOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of IncrementOp should not be null."
);
"Input(X) of IncrementOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of IncrementOp should not be null."
);
"Output(Out) of IncrementOp should not be null."
);
PADDLE_ENFORCE_EQ
(
1
,
framework
::
product
(
ctx
->
GetInputDim
(
"X"
)));
PADDLE_ENFORCE_EQ
(
1
,
framework
::
product
(
ctx
->
GetInputDim
(
"X"
)));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
}
ctx
->
ShareLoD
(
"X"
,
"Out"
);
};
struct
IncrementFunctor
{
IncrementFunctor
(
const
framework
::
LoDTensor
&
x
,
framework
::
LoDTensor
*
out
,
float
value
)
:
x_
(
x
),
out_
(
out
),
value_
(
value
)
{}
template
<
typename
T
>
void
operator
()()
const
{
*
out_
->
data
<
T
>
()
=
*
x_
.
data
<
T
>
()
+
static_cast
<
T
>
(
value_
);
}
const
framework
::
LoDTensor
&
x_
;
framework
::
LoDTensor
*
out_
;
float
value_
;
};
class
IncrementOp
:
public
framework
::
OperatorBase
{
public:
IncrementOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
x
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
();
auto
&
out
=
*
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
x
.
place
()));
out
.
Resize
(
x
.
dims
());
out
.
mutable_data
(
x
.
place
(),
x
.
type
());
float
value
=
Attr
<
float
>
(
"step"
);
VLOG
(
10
)
<<
Output
(
"Out"
)
<<
" increase "
<<
Input
(
"X"
)
<<
" with "
<<
value
;
framework
::
VisitDataType
(
framework
::
ToDataType
(
out
.
type
()),
IncrementFunctor
(
x
,
&
out
,
value
));
}
}
};
};
...
@@ -108,5 +74,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
...
@@ -108,5 +74,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
increment
,
ops
::
IncrementOp
,
ops
::
IncrementInferShape
,
REGISTER_OPERATOR
(
increment
,
ops
::
IncrementOp
,
ops
::
IncrementOpMaker
,
ops
::
IncrementOpMaker
,
ops
::
IncrementGradOpMaker
);
ops
::
IncrementGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
increment
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
)
paddle/fluid/operators/increment_op.cu
0 → 100644
浏览文件 @
0ce558f1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/minus_op.h"
REGISTER_OP_CUDA_KERNEL
(
increment
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
IncrementKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
)
\ No newline at end of file
paddle/fluid/operators/increment_op.h
0 → 100644
浏览文件 @
0ce558f1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
IncrementKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x_tensor
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
float
step
=
context
.
Attr
<
float
>
(
"step"
);
out_tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
framework
::
EigenScalar
<
T
>::
From
(
*
out_tensor
).
device
(
dev
)
=
framework
::
EigenScalar
<
T
>::
From
(
*
x_tensor
)
+
static_cast
<
T
>
(
step
);
}
};
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录