Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
b8f7edc3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b8f7edc3
编写于
2月 14, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix format error; test=develop
上级
5076ad87
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
175 addition
and
145 deletion
+175
-145
paddle/fluid/imperative/prepared_operator.h
paddle/fluid/imperative/prepared_operator.h
+1
-1
paddle/fluid/operators/histogram_op.cc
paddle/fluid/operators/histogram_op.cc
+2
-7
paddle/pten/kernels/cpu/histogram_kernel.cc
paddle/pten/kernels/cpu/histogram_kernel.cc
+62
-51
paddle/pten/kernels/gpu/histogram_kernel.cu
paddle/pten/kernels/gpu/histogram_kernel.cu
+95
-81
paddle/pten/ops/compat/histogram_sig.cc
paddle/pten/ops/compat/histogram_sig.cc
+15
-5
未找到文件。
paddle/fluid/imperative/prepared_operator.h
浏览文件 @
b8f7edc3
...
@@ -412,7 +412,7 @@ void BuildDygraphPtenKernelContext(
...
@@ -412,7 +412,7 @@ void BuildDygraphPtenKernelContext(
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
float
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
float
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
bool
)))
{
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
bool
)))
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
bool
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
bool
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
int64_t
)))
{
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
int64_t
)))
{
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
int64_t
,
attr
));
kernel_ctx
->
EmplaceBackAttr
(
BOOST_GET_CONST
(
int64_t
,
attr
));
}
else
if
(
attr_defs
[
i
].
type_index
==
}
else
if
(
attr_defs
[
i
].
type_index
==
std
::
type_index
(
typeid
(
std
::
string
)))
{
std
::
type_index
(
typeid
(
std
::
string
)))
{
...
...
paddle/fluid/operators/histogram_op.cc
浏览文件 @
b8f7edc3
...
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -85,8 +85,3 @@ REGISTER_OPERATOR(
...
@@ -85,8 +85,3 @@ REGISTER_OPERATOR(
histogram
,
ops
::
HistogramOp
,
ops
::
HistogramOpMaker
,
histogram
,
ops
::
HistogramOp
,
ops
::
HistogramOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
// REGISTER_OP_CPU_KERNEL(
// histogram, ops::HistogramKernel<paddle::platform::CPUDeviceContext, float>,
// ops::HistogramKernel<paddle::platform::CPUDeviceContext, double>,
// ops::HistogramKernel<paddle::platform::CPUDeviceContext, int>,
// ops::HistogramKernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle/pten/kernels/cpu/histogram_kernel.cc
浏览文件 @
b8f7edc3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/kernels/funcs/math_function.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/math_function.h"
namespace
pten
{
namespace
pten
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
HistogramKernel
(
const
Context
&
dev_ctx
,
void
HistogramKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
input
,
int64_t
bins
,
int64_t
bins
,
int
min
,
int
min
,
int
max
,
int
max
,
DenseTensor
*
output
)
DenseTensor
*
output
)
{
{
auto
&
nbins
=
bins
;
auto
&
nbins
=
bins
;
auto
&
minval
=
min
;
auto
&
minval
=
min
;
auto
&
maxval
=
max
;
auto
&
maxval
=
max
;
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
input_data
=
input
.
data
<
T
>
();
auto
input_numel
=
input
.
numel
();
auto
input_numel
=
input
.
numel
();
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
dev_ctx
.
GetPlace
());
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
dev_ctx
.
GetPlace
());
pten
::
funcs
::
SetConstant
<
Context
,
int64_t
>
()(
pten
::
funcs
::
SetConstant
<
Context
,
int64_t
>
()(
dev_ctx
,
output
,
dev_ctx
,
output
,
static_cast
<
int64_t
>
(
0
));
static_cast
<
int64_t
>
(
0
));
if
(
input_data
==
nullptr
)
return
;
if
(
input_data
==
nullptr
)
return
;
T
output_min
=
static_cast
<
T
>
(
minval
);
T
output_min
=
static_cast
<
T
>
(
minval
);
T
output_max
=
static_cast
<
T
>
(
maxval
);
T
output_max
=
static_cast
<
T
>
(
maxval
);
if
(
output_min
==
output_max
)
{
if
(
output_min
==
output_max
)
{
output_min
=
*
std
::
min_element
(
input_data
,
input_data
+
input_numel
);
output_min
=
*
std
::
min_element
(
input_data
,
input_data
+
input_numel
);
output_max
=
*
std
::
max_element
(
input_data
,
input_data
+
input_numel
);
output_max
=
*
std
::
max_element
(
input_data
,
input_data
+
input_numel
);
}
}
if
(
output_min
==
output_max
)
{
if
(
output_min
==
output_max
)
{
output_min
=
output_min
-
1
;
output_min
=
output_min
-
1
;
output_max
=
output_max
+
1
;
output_max
=
output_max
+
1
;
}
}
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
(
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))
||
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))),
std
::
isnan
(
static_cast
<
float
>
(
output_max
))),
false
,
pten
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
false
,
PADDLE_ENFORCE_GE
(
pten
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
output_max
,
output_min
,
PADDLE_ENFORCE_GE
(
pten
::
errors
::
InvalidArgument
(
output_max
,
"max must be larger or equal to min. If min and max are both zero, "
output_min
,
"the minimum and maximum values of the data are used. "
pten
::
errors
::
InvalidArgument
(
"But received max is %d, min is %d"
,
"max must be larger or equal to min. If min and max are both zero, "
maxval
,
minval
));
"the minimum and maximum values of the data are used. "
"But received max is %d, min is %d"
,
maxval
,
minval
));
for
(
int64_t
i
=
0
;
i
<
input_numel
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
input_numel
;
i
++
)
{
if
(
input_data
[
i
]
>=
output_min
&&
input_data
[
i
]
<=
output_max
)
{
if
(
input_data
[
i
]
>=
output_min
&&
input_data
[
i
]
<=
output_max
)
{
const
int64_t
bin
=
(
int64_t
)((
input_data
[
i
]
-
output_min
)
*
nbins
/
const
int64_t
bin
=
(
int64_t
)((
input_data
[
i
]
-
output_min
)
*
nbins
/
(
output_max
-
output_min
));
(
output_max
-
output_min
));
out_data
[
std
::
min
(
bin
,
nbins
-
1
)]
+=
1
;
out_data
[
std
::
min
(
bin
,
nbins
-
1
)]
+=
1
;
}
}
}
}
}
}
}
// namspace pten
}
// namespace pten
PT_REGISTER_KERNEL
(
histogram
,
PT_REGISTER_KERNEL
(
histogram
,
CPU
,
CPU
,
...
@@ -74,4 +85,4 @@ PT_REGISTER_KERNEL(histogram,
...
@@ -74,4 +85,4 @@ PT_REGISTER_KERNEL(histogram,
float
,
float
,
double
,
double
,
int
,
int
,
int64_t
)
{}
int64_t
)
{}
\ No newline at end of file
paddle/pten/kernels/gpu/histogram_kernel.cu
浏览文件 @
b8f7edc3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/kernels/funcs/math_function.h"
#include "paddle/pten/kernels/funcs/math_function.h"
#include "paddle/pten/kernels/histogram_kernel.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/kernel_registry.h"
...
@@ -9,8 +21,8 @@
...
@@ -9,8 +21,8 @@
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
namespace
pten
{
namespace
pten
{
...
@@ -22,7 +34,9 @@ inline int GET_BLOCKS(const int N) {
...
@@ -22,7 +34,9 @@ inline int GET_BLOCKS(const int N) {
}
}
template
<
typename
T
,
typename
IndexType
>
template
<
typename
T
,
typename
IndexType
>
__device__
static
IndexType
GetBin
(
T
input_value
,
T
min_value
,
T
max_value
,
__device__
static
IndexType
GetBin
(
T
input_value
,
T
min_value
,
T
max_value
,
int64_t
nbins
)
{
int64_t
nbins
)
{
IndexType
bin
=
static_cast
<
int
>
((
input_value
-
min_value
)
*
nbins
/
IndexType
bin
=
static_cast
<
int
>
((
input_value
-
min_value
)
*
nbins
/
(
max_value
-
min_value
));
(
max_value
-
min_value
));
...
@@ -31,9 +45,12 @@ __device__ static IndexType GetBin(T input_value, T min_value, T max_value,
...
@@ -31,9 +45,12 @@ __device__ static IndexType GetBin(T input_value, T min_value, T max_value,
}
}
template
<
typename
T
,
typename
IndexType
>
template
<
typename
T
,
typename
IndexType
>
__global__
void
KernelHistogram
(
const
T
*
input
,
const
int
total_elements
,
__global__
void
KernelHistogram
(
const
T
*
input
,
const
int64_t
nbins
,
const
T
min_value
,
const
int
total_elements
,
const
T
max_value
,
int64_t
*
output
)
{
const
int64_t
nbins
,
const
T
min_value
,
const
T
max_value
,
int64_t
*
output
)
{
extern
__shared__
int64_t
buf_hist
[];
extern
__shared__
int64_t
buf_hist
[];
for
(
int
i
=
threadIdx
.
x
;
i
<
nbins
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nbins
;
i
+=
blockDim
.
x
)
{
buf_hist
[
i
]
=
0
;
buf_hist
[
i
]
=
0
;
...
@@ -58,83 +75,80 @@ __global__ void KernelHistogram(const T* input, const int total_elements,
...
@@ -58,83 +75,80 @@ __global__ void KernelHistogram(const T* input, const int total_elements,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
HistogramKernel
(
const
Context
&
dev_ctx
,
void
HistogramKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
input
,
int64_t
bins
,
int64_t
bins
,
int
min
,
int
min
,
int
max
,
int
max
,
DenseTensor
*
output
)
DenseTensor
*
output
)
{
{
auto
&
nbins
=
bins
;
auto
&
nbins
=
bins
;
auto
&
minval
=
min
;
auto
&
minval
=
min
;
auto
&
maxval
=
max
;
auto
&
maxval
=
max
;
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
input_data
=
input
.
data
<
T
>
();
const
int
input_numel
=
input
.
numel
();
const
int
input_numel
=
input
.
numel
();
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
dev_ctx
.
GetPlace
());
int64_t
*
out_data
=
output
->
mutable_data
<
int64_t
>
(
dev_ctx
.
GetPlace
());
pten
::
funcs
::
SetConstant
<
Context
,
int64_t
>
()(
pten
::
funcs
::
SetConstant
<
Context
,
int64_t
>
()(
dev_ctx
,
output
,
static_cast
<
int64_t
>
(
0
));
dev_ctx
,
output
,
static_cast
<
int64_t
>
(
0
));
if
(
input_data
==
nullptr
)
return
;
if
(
input_data
==
nullptr
)
return
;
T
output_min
=
static_cast
<
T
>
(
minval
);
T
output_max
=
static_cast
<
T
>
(
maxval
);
T
output_min
=
static_cast
<
T
>
(
minval
);
T
output_max
=
static_cast
<
T
>
(
maxval
);
if
(
output_min
==
output_max
)
{
auto
input_x
=
pten
::
EigenVector
<
T
>::
Flatten
(
input
);
if
(
output_min
==
output_max
)
{
auto
input_x
=
pten
::
EigenVector
<
T
>::
Flatten
(
input
);
DenseTensor
input_min_t
,
input_max_t
;
auto
*
input_min_data
=
input_min_t
.
mutable_data
<
T
>
({
1
},
dev_ctx
.
GetPlace
());
DenseTensor
input_min_t
,
input_max_t
;
auto
*
input_max_data
=
input_max_t
.
mutable_data
<
T
>
({
1
},
dev_ctx
.
GetPlace
());
auto
*
input_min_data
=
auto
input_min_scala
=
pten
::
EigenScalar
<
T
>::
From
(
input_min_t
);
input_min_t
.
mutable_data
<
T
>
({
1
},
dev_ctx
.
GetPlace
());
auto
input_max_scala
=
pten
::
EigenScalar
<
T
>::
From
(
input_max_t
);
auto
*
input_max_data
=
input_max_t
.
mutable_data
<
T
>
({
1
},
dev_ctx
.
GetPlace
());
auto
*
place
=
dev_ctx
.
eigen_device
();
auto
input_min_scala
=
pten
::
EigenScalar
<
T
>::
From
(
input_min_t
);
input_min_scala
.
device
(
*
place
)
=
input_x
.
minimum
();
auto
input_max_scala
=
pten
::
EigenScalar
<
T
>::
From
(
input_max_t
);
input_max_scala
.
device
(
*
place
)
=
input_x
.
maximum
();
auto
*
place
=
DenseTensor
input_min_cpu
,
input_max_cpu
;
dev_ctx
.
eigen_device
();
paddle
::
framework
::
TensorCopySync
(
input_min_scala
.
device
(
*
place
)
=
input_x
.
minimum
();
input_min_t
,
paddle
::
platform
::
CPUPlace
(),
&
input_min_cpu
);
input_max_scala
.
device
(
*
place
)
=
input_x
.
maximum
();
paddle
::
framework
::
TensorCopySync
(
input_max_t
,
paddle
::
platform
::
CPUPlace
(),
&
input_max_cpu
);
DenseTensor
input_min_cpu
,
input_max_cpu
;
paddle
::
framework
::
TensorCopySync
(
input_min_t
,
paddle
::
platform
::
CPUPlace
(),
output_min
=
input_min_cpu
.
data
<
T
>
()[
0
];
&
input_min_cpu
);
output_max
=
input_max_cpu
.
data
<
T
>
()[
0
];
paddle
::
framework
::
TensorCopySync
(
input_max_t
,
paddle
::
platform
::
CPUPlace
(),
}
&
input_max_cpu
);
if
(
output_min
==
output_max
)
{
output_min
=
output_min
-
1
;
output_min
=
input_min_cpu
.
data
<
T
>
()[
0
];
output_max
=
output_max
+
1
;
output_max
=
input_max_cpu
.
data
<
T
>
()[
0
];
}
}
if
(
output_min
==
output_max
)
{
output_min
=
output_min
-
1
;
output_max
=
output_max
+
1
;
}
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
(
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
(
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))
||
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isinf
(
static_cast
<
float
>
(
output_min
))
||
std
::
isnan
(
static_cast
<
float
>
(
output_max
))),
std
::
isnan
(
static_cast
<
float
>
(
output_max
))),
false
,
pten
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
false
,
PADDLE_ENFORCE_GE
(
pten
::
errors
::
OutOfRange
(
"range of min, max is not finite"
));
output_max
,
output_min
,
PADDLE_ENFORCE_GE
(
pten
::
errors
::
InvalidArgument
(
output_max
,
"max must be larger or equal to min. If min and max are both zero, "
output_min
,
"the minimum and maximum values of the data are used. "
pten
::
errors
::
InvalidArgument
(
"But received max is %d, min is %d"
,
"max must be larger or equal to min. If min and max are both zero, "
maxval
,
minval
));
"the minimum and maximum values of the data are used. "
"But received max is %d, min is %d"
,
auto
stream
=
maxval
,
dev_ctx
.
stream
();
minval
));
KernelHistogram
<
T
,
IndexType
><<<
GET_BLOCKS
(
input_numel
),
PADDLE_CUDA_NUM_THREADS
,
auto
stream
=
dev_ctx
.
stream
();
nbins
*
sizeof
(
int64_t
),
stream
>>>
(
KernelHistogram
<
T
,
IndexType
><<<
GET_BLOCKS
(
input_numel
),
input_data
,
input_numel
,
nbins
,
output_min
,
output_max
,
out_data
);
PADDLE_CUDA_NUM_THREADS
,
nbins
*
sizeof
(
int64_t
),
stream
>>>
(
input_data
,
input_numel
,
nbins
,
output_min
,
output_max
,
out_data
);
}
}
}
//namespace pten
}
// namespace pten
PT_REGISTER_KERNEL
(
histogram
,
PT_REGISTER_KERNEL
(
histogram
,
GPU
,
GPU
,
...
...
paddle/pten/ops/compat/histogram_sig.cc
浏览文件 @
b8f7edc3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/compat/op_utils.h"
namespace
pten
{
namespace
pten
{
KernelSignature
HistogramOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
KernelSignature
HistogramOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
return
KernelSignature
(
"histogram"
,
{
"X"
},
{
"bins"
,
"min"
,
"max"
},
{
"Out"
});
"histogram"
,
{
"X"
},
{
"bins"
,
"min"
,
"max"
},
{
"Out"
});
}
}
}
// namespace pten
}
// namespace pten
PT_REGISTER_ARG_MAPPING_FN
(
histogram
,
pten
::
HistogramOpArgumentMapping
);
PT_REGISTER_ARG_MAPPING_FN
(
histogram
,
pten
::
HistogramOpArgumentMapping
);
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录