Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
bbe441fc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bbe441fc
编写于
2月 24, 2022
作者:
zhouweiwei2014
提交者:
GitHub
2月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Phi】Migrate poisson op into phi (#39814)
* Migrate poisson op into phi * fix CI * fix comment
上级
23bbd912
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
270 addition
and
42 deletion
+270
-42
paddle/fluid/operators/poisson_op.cc
paddle/fluid/operators/poisson_op.cc
+9
-42
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+1
-0
paddle/phi/kernels/cpu/poisson_grad_kernel.cc
paddle/phi/kernels/cpu/poisson_grad_kernel.cc
+19
-0
paddle/phi/kernels/cpu/poisson_kernel.cc
paddle/phi/kernels/cpu/poisson_kernel.cc
+41
-0
paddle/phi/kernels/gpu/poisson_grad_kernel.cu
paddle/phi/kernels/gpu/poisson_grad_kernel.cu
+19
-0
paddle/phi/kernels/gpu/poisson_kernel.cu
paddle/phi/kernels/gpu/poisson_kernel.cu
+77
-0
paddle/phi/kernels/impl/poisson_grad_kernel_impl.h
paddle/phi/kernels/impl/poisson_grad_kernel_impl.h
+29
-0
paddle/phi/kernels/poisson_grad_kernel.h
paddle/phi/kernels/poisson_grad_kernel.h
+25
-0
paddle/phi/kernels/poisson_kernel.h
paddle/phi/kernels/poisson_kernel.h
+24
-0
paddle/phi/ops/compat/poisson_sig.cc
paddle/phi/ops/compat/poisson_sig.cc
+26
-0
未找到文件。
paddle/fluid/operators/poisson_op.cc
浏览文件 @
bbe441fc
...
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
...
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <string>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/operators/poisson_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -23,14 +25,6 @@ class PoissonOp : public framework::OperatorWithKernel {
...
@@ -23,14 +25,6 @@ class PoissonOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"PoissonOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"PoissonOp"
);
auto
dim
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
"Out"
,
dim
);
}
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -61,29 +55,6 @@ class PoissonOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
...
@@ -61,29 +55,6 @@ class PoissonOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
}
}
};
};
template
<
typename
T
>
class
PoissonKernel
<
platform
::
CPUDeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
size
=
x
->
numel
();
auto
gen
=
framework
::
DefaultCPUGenerator
();
auto
engine
=
gen
->
GetCPUEngine
();
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
std
::
poisson_distribution
<>
dist
(
x_data
[
i
]);
out_data
[
i
]
=
static_cast
<
T
>
(
dist
(
*
engine
));
}
}
};
class
PoissonGradOp
:
public
framework
::
OperatorWithKernel
{
class
PoissonGradOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -116,17 +87,13 @@ class PoissonGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -116,17 +87,13 @@ class PoissonGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
plat
=
paddle
::
platform
;
DELCARE_INFER_SHAPE_FUNCTOR
(
poisson
,
PoissonInferShapeFunctor
,
PT_INFER_META
(
phi
::
UnchangedInferMeta
));
REGISTER_OPERATOR
(
poisson
,
ops
::
PoissonOp
,
ops
::
PoissonOpMaker
,
REGISTER_OPERATOR
(
poisson
,
ops
::
PoissonOp
,
ops
::
PoissonOpMaker
,
ops
::
PoissonOpInferVarType
,
ops
::
PoissonOpInferVarType
,
ops
::
PoissonGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
PoissonGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
PoissonGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
PoissonGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
PoissonInferShapeFunctor
);
REGISTER_OPERATOR
(
poisson_grad
,
ops
::
PoissonGradOp
);
REGISTER_OPERATOR
(
poisson_grad
,
ops
::
PoissonGradOp
);
REGISTER_OP_CPU_KERNEL
(
poisson
,
ops
::
PoissonKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
PoissonKernel
<
plat
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
poisson_grad
,
ops
::
PoissonGradKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
PoissonGradKernel
<
plat
::
CPUDeviceContext
,
double
>
);
paddle/phi/infermeta/unary.h
浏览文件 @
bbe441fc
...
@@ -103,4 +103,5 @@ void UnfoldInferMeta(const MetaTensor& x,
...
@@ -103,4 +103,5 @@ void UnfoldInferMeta(const MetaTensor& x,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
MetaTensor
*
out
,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
MetaConfig
config
=
MetaConfig
());
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/cpu/poisson_grad_kernel.cc
0 → 100644
浏览文件 @
bbe441fc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/poisson_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
poisson_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
PoissonGradKernel
,
float
,
double
)
{}
paddle/phi/kernels/cpu/poisson_kernel.cc
0 → 100644
浏览文件 @
bbe441fc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/poisson_kernel.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
PoissonKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
)
{
const
T
*
x_data
=
x
.
data
<
T
>
();
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
int64_t
size
=
x
.
numel
();
auto
gen
=
ctx
.
GetGenerator
();
auto
engine
=
gen
->
GetCPUEngine
();
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
std
::
poisson_distribution
<>
dist
(
x_data
[
i
]);
out_data
[
i
]
=
static_cast
<
T
>
(
dist
(
*
engine
));
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
poisson
,
CPU
,
ALL_LAYOUT
,
phi
::
PoissonKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/poisson_grad_kernel.cu
0 → 100644
浏览文件 @
bbe441fc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/poisson_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
poisson_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
PoissonGradKernel
,
float
,
double
)
{}
paddle/
fluid/operators/poisson_op
.cu
→
paddle/
phi/kernels/gpu/poisson_kernel
.cu
浏览文件 @
bbe441fc
/* Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -18,16 +18,20 @@ limitations under the License. */
...
@@ -18,16 +18,20 @@ limitations under the License. */
#ifdef __HIPCC__
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#include <hiprand_kernel.h>
#endif
#endif
#include "paddle/fluid/operators/poisson_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/poisson_kernel.h"
namespace
paddle
{
namespace
phi
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
struct
PoissonCudaFunctor
{
struct
PoissonCudaFunctor
{
public:
public:
PoissonCudaFunctor
(
const
T
*
in
,
T
*
out
,
unsigned
int
seed
,
PoissonCudaFunctor
(
const
T
*
in
,
T
*
out
,
unsigned
int
seed
,
unsigned
int
offset
)
unsigned
int
offset
)
:
in_
(
in
),
out_
(
out
),
seed_
(
seed
),
offset_
(
offset
)
{}
:
in_
(
in
),
out_
(
out
),
seed_
(
seed
),
offset_
(
offset
)
{}
...
@@ -50,42 +54,24 @@ struct PoissonCudaFunctor {
...
@@ -50,42 +54,24 @@ struct PoissonCudaFunctor {
const
unsigned
int
offset_
;
const
unsigned
int
offset_
;
};
};
template
<
typename
T
>
template
<
typename
T
,
typename
Context
>
class
PoissonKernel
<
platform
::
CUDADeviceContext
,
T
>
void
PoissonKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
)
{
:
public
framework
::
OpKernel
<
T
>
{
const
T
*
x_data
=
x
.
data
<
T
>
();
public:
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
size
=
x
.
numel
();
const
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
size
=
x
->
numel
();
int64_t
device_id
=
ctx
.
GetPlace
().
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
auto
gen_cuda
=
ctx
.
GetGenerator
();
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
20
);
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
20
);
uint64_t
seed
=
seed_offset
.
first
;
uint64_t
seed
=
seed_offset
.
first
;
uint64_t
offset
=
seed_offset
.
second
;
uint64_t
offset
=
seed_offset
.
second
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
size
);
PoissonCudaFunctor
<
T
>
functor
(
x_data
,
out_data
,
seed
,
offset
);
for_range
(
functor
);
}
};
}
// namespace operators
paddle
::
platform
::
ForRange
<
Context
>
for_range
(
ctx
,
size
);
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
PoissonCudaFunctor
<
T
>
functor
(
x_data
,
out_data
,
seed
,
offset
);
namespace
plat
=
paddle
::
platform
;
for_range
(
functor
);
}
REGISTER_OP_CUDA_KERNEL
(
poisson
,
}
// namespace phi
ops
::
PoissonKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
PoissonKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
PD_REGISTER_KERNEL
(
poisson_grad
,
ops
::
PoissonGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
poisson
,
GPU
,
ALL_LAYOUT
,
phi
::
PoissonKernel
,
float
,
double
)
{}
ops
::
PoissonGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/
fluid/operators/poisson_op
.h
→
paddle/
phi/kernels/impl/poisson_grad_kernel_impl
.h
浏览文件 @
bbe441fc
// Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -14,28 +14,16 @@
...
@@ -14,28 +14,16 @@
#pragma once
#pragma once
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/poisson_grad_kernel.h"
namespace
paddle
{
namespace
phi
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
Context
>
class
PoissonKernel
;
void
PoissonGradKernel
(
const
Context
&
ctx
,
DenseTensor
*
x_grad
)
{
ctx
.
template
Alloc
<
T
>(
x_grad
);
phi
::
funcs
::
SetConstant
<
Context
,
T
>
functor
;
functor
(
ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
}
template
<
typename
DeviceContext
,
typename
T
>
}
// namespace phi
class
PoissonGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
dx
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
phi
::
funcs
::
SetConstant
<
DeviceContext
,
T
>
functor
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
functor
(
dev_ctx
,
dx
,
static_cast
<
T
>
(
0
));
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/kernels/poisson_grad_kernel.h
0 → 100644
浏览文件 @
bbe441fc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
PoissonGradKernel
(
const
Context
&
ctx
,
DenseTensor
*
x_grad
);
}
// namespace phi
paddle/phi/kernels/poisson_kernel.h
0 → 100644
浏览文件 @
bbe441fc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
PoissonKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/ops/compat/poisson_sig.cc
0 → 100644
浏览文件 @
bbe441fc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
PoissonGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"poisson_grad"
,
{},
{},
{
GradVarName
(
"X"
)});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
poisson_grad
,
phi
::
PoissonGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录