提交 649bcd5f 编写于 作者: S silingtong123 提交者: Tao Luo

Modify the style of function names (#20071)

上级 62573d89
......@@ -33,9 +33,9 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ShapeTensor")) {
auto *shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
new_shape = get_new_data_from_shape_tensor(shape_tensor);
new_shape = GetNewDataFromShapeTensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = get_new_shape_from_shape_tensorlist(list_new_shape_tensor);
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
}
}
......@@ -169,14 +169,14 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("ShapeTensor",
"(Tensor<int64_t>, optional). If provided, uniform_ranodom "
"according to "
"this given shape. That is to say it has a higher priority than "
"this given shape. It means that it has a higher priority than "
"the shape attribute, while the shape attribute still should be "
"set correctly to gurantee shape inference in compile time.")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int64_t>>, optional). If provided, uniform_random "
"will use this"
"The shape of the tensor in vector MUST BE [1]"
"use this."
"The shape of the tensor in vector MUST BE [1],"
"it has the highest priority compare with Input(Shape) and "
"attr(shape).")
.AsDuplicable()
......
......@@ -64,9 +64,9 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) {
if (context.HasInput("ShapeTensor")) {
auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor");
new_shape = get_new_data_from_shape_tensor(shape_tensor);
new_shape = GetNewDataFromShapeTensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = get_new_shape_from_shape_tensorlist(list_new_shape_tensor);
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
}
}
......
......@@ -22,7 +22,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int64_t> get_new_data_from_shape_tensor(
inline std::vector<int64_t> GetNewDataFromShapeTensor(
const Tensor *new_data_tensor) {
auto *new_data = new_data_tensor->data<int64_t>();
if (platform::is_gpu_place(new_data_tensor->place())) {
......@@ -35,7 +35,7 @@ inline std::vector<int64_t> get_new_data_from_shape_tensor(
return vec_new_data;
}
inline std::vector<int64_t> get_new_shape_from_shape_tensorlist(
inline std::vector<int64_t> GetNewDataFromShapeTensorList(
const std::vector<const Tensor *> &list_new_shape_tensor) {
std::vector<int64_t> vec_new_shape;
vec_new_shape.reserve(list_new_shape_tensor.size());
......@@ -46,10 +46,9 @@ inline std::vector<int64_t> get_new_shape_from_shape_tensorlist(
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int64_t>(*temp.data<int64_t>()));
vec_new_shape.push_back(*temp.data<int64_t>());
} else {
vec_new_shape.push_back(static_cast<int64_t>(*tensor->data<int64_t>()));
vec_new_shape.push_back(*tensor->data<int64_t>());
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册