Commit 75d0c790 authored by Yu Yang, committed by GitHub

Change Name convention of operator attributes (#4807)

* Change dataType to data_type

Follow PEP8

* Change name_convention to fit PEP8
Parent 790b9ce4
@@ -433,7 +433,7 @@ ParamGradInfoMap AppendBackward(
 new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
 {{"shape", std::vector<int>{1}},
 {"value", static_cast<float>(1.0)},
-{"dataType", framework::DataType::FP32}}));
+{"data_type", framework::DataType::FP32}}));
 all_ops.push_back(std::move(fill_one_op));
 size_t forward_op_num = all_ops.size();
 size_t forward_block_num = program_desc.Size();
......
@@ -34,13 +34,13 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
 PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
 "The 1st dimension of Input(X) and Input(Label) should "
 "be equal.");
-if (ctx->Attrs().Get<bool>("softLabel")) {
+if (ctx->Attrs().Get<bool>("soft_label")) {
 PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-"If Attr(softLabel) == true, the 2nd dimension of "
+"If Attr(soft_label) == true, the 2nd dimension of "
 "Input(X) and Input(Label) should be equal.");
 } else {
 PADDLE_ENFORCE_EQ(label_dims[1], 1,
-"If Attr(softLabel) == false, the 2nd dimension of "
+"If Attr(soft_label) == false, the 2nd dimension of "
 "Input(Label) should be 1.");
 }
@@ -82,13 +82,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
 "be equal.");
 PADDLE_ENFORCE_EQ(dy_dims[1], 1,
 "The 2nd dimension of Input(Y@Grad) should be 1.");
-if (ctx->Attrs().Get<bool>("softLabel")) {
+if (ctx->Attrs().Get<bool>("soft_label")) {
 PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
-"When Attr(softLabel) == true, the 2nd dimension of "
+"When Attr(soft_label) == true, the 2nd dimension of "
 "Input(X) and Input(Label) should be equal.");
 } else {
 PADDLE_ENFORCE_EQ(label_dims[1], 1,
-"When Attr(softLabel) == false, the 2nd dimension of "
+"When Attr(soft_label) == false, the 2nd dimension of "
 "Input(Label) should be 1.");
 }
 ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
@@ -115,15 +115,15 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
 "Label",
 "(Tensor, default Tensor<int>), the ground truth which is "
 "a 2-D tensor. "
-"When softLabel is set to false, `Label` is a Tensor<int> with shape "
+"When soft_label is set to false, `Label` is a Tensor<int> with shape "
 "[N x 1]. "
-"When softLabel is set to true, `Label` is a Tensor<float/double> "
+"When soft_label is set to true, `Label` is a Tensor<float/double> "
 "with shape [N x K].");
 AddOutput("Y",
 "(Tensor, default Tensor<float>), a 2-D tensor "
 "with shape [N x 1]. The cross entropy loss.");
 AddAttr<bool>(
-"softLabel",
+"soft_label",
 "(bool, default false), a flag to indicate whether to interpretate "
 "the given labels as soft labels.")
 .SetDefault(false);
@@ -133,12 +133,12 @@ CrossEntropy Operator.
 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
-softLabel = false, Label[i, 0] indicates the class index for sample i:
+soft_label = false, Label[i, 0] indicates the class index for sample i:
 Y[i] = -log(X[i, Label[i]])
 2) Soft-label cross-entropy:
-softLabel = true, Label[i, j] indicates the soft label of class j
+soft_label = true, Label[i, j] indicates the soft label of class j
 for sample i:
 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
......
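As a reference for the two formulas in the operator comment above, here is a minimal NumPy sketch of both loss variants (an illustrative helper, not code from this commit):

```python
import numpy as np

def cross_entropy(X, label, soft_label=False):
    # X: [N x K] per-sample class probabilities.
    # label: [N x 1] int class indices (soft_label=False)
    #        or [N x K] float distributions (soft_label=True).
    if soft_label:
        # Y[i] = \sum_j -Label[i, j] * log(X[i, j])
        return -(label * np.log(X)).sum(axis=1, keepdims=True)
    # Y[i] = -log(X[i, Label[i]])
    n = X.shape[0]
    return -np.log(X[np.arange(n), label.reshape(n)]).reshape(n, 1)
```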
@@ -56,7 +56,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
 y->mutable_data<T>(ctx.GetPlace());
 math::CrossEntropyFunctor<platform::GPUPlace, T>()(
-ctx.device_context(), y, x, label, ctx.Attr<bool>("softLabel"));
+ctx.device_context(), y, x, label, ctx.Attr<bool>("soft_label"));
 }
 };
@@ -83,7 +83,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
 int block = 512;
 int grid = (batch_size * class_num + block - 1) / block;
-if (ctx.Attr<bool>("softLabel")) {
+if (ctx.Attr<bool>("soft_label")) {
 auto* label_data = label->data<T>();
 SoftCrossEntropyGradientKernel<T><<<
 grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
......
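The `grid` computed above is the usual ceiling division, sizing the launch so that `grid * block >= batch_size * class_num`. A quick arithmetic check (illustrative only):

```python
def launch_grid(batch_size, class_num, block=512):
    # Smallest grid with grid * block >= batch_size * class_num.
    return (batch_size * class_num + block - 1) // block

assert launch_grid(32, 10) == 1      # 320 elements fit in one block
assert launch_grid(100, 512) == 100  # 51200 / 512 exactly
```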
@@ -38,7 +38,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
 y->mutable_data<T>(ctx.GetPlace());
 math::CrossEntropyFunctor<platform::CPUPlace, T>()(
-ctx.device_context(), y, x, labels, ctx.Attr<bool>("softLabel"));
+ctx.device_context(), y, x, labels, ctx.Attr<bool>("soft_label"));
 }
 };
@@ -55,7 +55,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
 T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
 int class_num = x->dims()[1];
-if (ctx.Attr<bool>("softLabel")) {
+if (ctx.Attr<bool>("soft_label")) {
 auto x_mat = EigenMatrix<T>::From(*x);
 auto dy_mat = EigenMatrix<T>::From(*dy);
 auto lbl_mat = EigenMatrix<T>::From(*label);
......
@@ -35,7 +35,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
 framework::DataType IndicateDataType(
 const framework::ExecutionContext &ctx) const override {
-return static_cast<framework::DataType>(ctx.Attr<int>("dataType"));
+return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
 }
 };
@@ -44,7 +44,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
 FillConstantOpMaker(framework::OpProto *proto,
 framework::OpAttrChecker *op_checker)
 : framework::OpProtoAndCheckerMaker(proto, op_checker) {
-AddAttr<int>("dataType",
+AddAttr<int>("data_type",
 "(int, default 5 (FP32)) "
 "Output data type")
 .SetDefault(framework::DataType::FP32);
......
@@ -11,7 +11,7 @@ When defining an operator in Paddle, a corresponding [OpProtoMaker](https://gith
 - If an operator's Input/Output are tensors in math, not match to any meaningful words, input name should starts from `X`. e.g. `X`, `Y`, and output name should starts from `Out`. e.g. `Out`. This rule intends making operators which have few inputs/outputs unified.
 - Attribute.
-- Attribute name follows the **camelCase**. e.g. `x`, `y`, `axis`, `rowwiseMatrix`. Also, attribute name prefers to meaningful English words.
+- Attribute name follows the **snake_case**. e.g. `x`, `y`, `axis`, `rowwise_matrix`. Also, attribute name prefers to meaningful English words.
 - Comments.
 - Input/Output/Attr comment follow the format of **(type,default value) usage**, corresponding to which type it can be and how it will be used in the operator. e.g. Attribute in Accumulator`"gamma" `,`(float, default 1.0) Accumulation multiplier`.
......
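With the convention above, an operator's attributes are spelled in snake_case at every call site. For example, the attribute dict a Python test passes to pool2d now reads (attribute names taken from this commit; the values are illustrative):

```python
# snake_case attribute names replace the old camelCase ones
# ("poolingType" -> "pooling_type", "globalPooling" -> "global_pooling").
attrs = {
    "pooling_type": "max",    # 'max' or 'avg'
    "global_pooling": False,  # if True, ksize is ignored
    "ksize": [2, 2],
    "strides": [1, 1],
    "paddings": [0, 0],
}
```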
@@ -29,7 +29,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
 auto in_x_dims = ctx->GetInputDim("X");
-std::string pooling_type = ctx->Attrs().Get<std::string>("poolingType");
+std::string pooling_type = ctx->Attrs().Get<std::string>("pooling_type");
 std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
 std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
 std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
@@ -37,7 +37,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
 PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
 "Pooling intput should be 4-D or 5-D tensor.");
-if (ctx->Attrs().Get<bool>("globalPooling")) {
+if (ctx->Attrs().Get<bool>("global_pooling")) {
 ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
 for (size_t i = 0; i < ksize.size(); ++i)
 ksize[i] = static_cast<int>(in_x_dims[i + 2]);
@@ -80,23 +80,23 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
 "the number of channels, H and W is the height and "
 "width of feature.");
-AddAttr<std::string>("poolingType",
-"PoolingType of pooling operator."
+AddAttr<std::string>("pooling_type",
+"Pooling_type of pooling operator."
 "Str constant equal to 'max' or 'avg'.")
 .InEnum({"max", "avg"});
 AddAttr<std::vector<int>>(
 "ksize",
 "The pooling window size(height, width) of pooling operator."
-"If globalPooling = true, ksize is ignored and need not be "
+"If global_pooling = true, ksize is ignored and need not be "
 "specified."); // TODO(Chengduo): Add checker. (Currently,
 // TypedAttrChecker don't support vector type.)
 AddAttr<bool>(
-"globalPooling",
-"Whether to use the globalPooling."
+"global_pooling",
+"Whether to use the global_pooling."
 "Bool constant equal to false or true."
 "Default false."
-"If globalPooling = true, ksize is ignored and need not be specified.")
+"If global_pooling = true, ksize is ignored and need not be specified.")
 .SetDefault(false);
 AddAttr<std::vector<int>>("strides",
 "The strides(height, width) of pooling window."
@@ -146,7 +146,7 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
 "the number of channels, D, H and W is the depth, height and "
 "width of feature.");
-AddAttr<std::string>("poolingType",
+AddAttr<std::string>("pooling_type",
 "PoolingType of pooling operator."
 "Str constant equal to 'max' or 'avg'.")
 .InEnum({"max", "avg"});
@@ -154,15 +154,15 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
 AddAttr<std::vector<int>>(
 "ksize",
 "The pooling window size(depth, height, width) of pooling operator."
-"If globalPooling = true, ksize is ignored and need not be "
+"If global_pooling = true, ksize is ignored and need not be "
 "specified."); // TODO(Chengduo): Add checker. (Currently,
 // TypedAttrChecker don't support vector type.)
 AddAttr<bool>(
-"globalPooling",
-"Whether to use the globalPooling."
+"global_pooling",
+"Whether to use the global_pooling."
 "Bool constant equal to false or true."
 "Default false."
-"If globalPooling = true, ksize is ignored and need not be specified.")
+"If global_pooling = true, ksize is ignored and need not be specified.")
 .SetDefault(false);
 AddAttr<std::vector<int>>("strides",
 "Strides(depth, height, width) of pooling operator."
......
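When `global_pooling` is true, the kernels above overwrite `ksize` with the input's spatial extents, so the window covers the whole feature map. A sketch of that adjustment (a hypothetical Python helper mirroring the C++ loop):

```python
def effective_ksize(ksize, in_x_dims, global_pooling):
    # in_x_dims is NCHW (4-D) or NCDHW (5-D); spatial dims start at index 2.
    if global_pooling:
        return [int(d) for d in in_x_dims[2:]]
    return ksize

assert effective_ksize([2, 2], [8, 3, 32, 32], True) == [32, 32]
assert effective_ksize([2, 2], [8, 3, 32, 32], False) == [2, 2]
```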
@@ -59,11 +59,11 @@ class PoolKernel : public framework::OpKernel<T> {
 const Tensor* in_x = context.Input<Tensor>("X");
 Tensor* out = context.Output<Tensor>("Out");
-std::string pooling_type = context.Attr<std::string>("poolingType");
+std::string pooling_type = context.Attr<std::string>("pooling_type");
 std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
 std::vector<int> strides = context.Attr<std::vector<int>>("strides");
 std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-if (context.Attr<bool>("globalPooling")) {
+if (context.Attr<bool>("global_pooling")) {
 for (size_t i = 0; i < ksize.size(); ++i) {
 ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
 }
@@ -119,12 +119,12 @@ class PoolGradKernel : public framework::OpKernel<T> {
 context.Input<Tensor>(framework::GradVarName("Out"));
 Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
-std::string pooling_type = context.Attr<std::string>("poolingType");
+std::string pooling_type = context.Attr<std::string>("pooling_type");
 std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
 std::vector<int> strides = context.Attr<std::vector<int>>("strides");
 std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-if (context.Attr<bool>("globalPooling")) {
+if (context.Attr<bool>("global_pooling")) {
 for (size_t i = 0; i < ksize.size(); ++i)
 ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
 }
......
@@ -45,7 +45,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
 PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
 "Pooling intput should be 4-D or 5-D tensor.");
-if (ctx->Attrs().Get<bool>("globalPooling")) {
+if (ctx->Attrs().Get<bool>("global_pooling")) {
 ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
 for (size_t i = 0; i < ksize.size(); ++i)
 ksize[i] = static_cast<int>(in_x_dims[i + 2]);
@@ -108,15 +108,15 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
 AddAttr<std::vector<int>>(
 "ksize",
 "The pooling window size(height, width) of pooling operator."
-"If globalPooling = true, ksize is ignored and need not be "
+"If global_pooling = true, ksize is ignored and need not be "
 "specified."); // TODO(Chengduo): Add checker. (Currently,
 // TypedAttrChecker don't support vector type.)
 AddAttr<bool>(
-"globalPooling",
-"Whether to use the globalPooling."
+"global_pooling",
+"Whether to use the global_pooling."
 "Bool constant equal to false or true."
 "Default false."
-"If globalPooling = true, ksize is ignored and need not be specified.")
+"If global_pooling = true, ksize is ignored and need not be specified.")
 .SetDefault(false);
 AddAttr<std::vector<int>>("strides",
 "The strides(height, width) of pooling window."
@@ -179,15 +179,15 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
 AddAttr<std::vector<int>>(
 "ksize",
 "The pooling window size(depth, height, width) of pooling operator."
-"If globalPooling = true, ksize is ignored and need not be "
+"If global_pooling = true, ksize is ignored and need not be "
 "specified."); // TODO(Chengduo): Add checker. (Currently,
 // TypedAttrChecker don't support vector type.)
 AddAttr<bool>(
-"globalPooling",
-"Whether to use the globalPooling."
+"global_pooling",
+"Whether to use the global_pooling."
 "Bool constant equal to false or true."
 "Default false."
-"If globalPooling = true, ksize is ignored and need not be specified.")
+"If global_pooling = true, ksize is ignored and need not be specified.")
 .SetDefault(false);
 AddAttr<std::vector<int>>(
 "strides",
......
@@ -35,7 +35,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
 std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
 std::vector<int> strides = context.Attr<std::vector<int>>("strides");
 std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-if (context.Attr<bool>("globalPooling")) {
+if (context.Attr<bool>("global_pooling")) {
 for (size_t i = 0; i < ksize.size(); ++i) {
 ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
 }
@@ -70,7 +70,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
 std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
 std::vector<int> strides = context.Attr<std::vector<int>>("strides");
 std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-if (context.Attr<bool>("globalPooling")) {
+if (context.Attr<bool>("global_pooling")) {
 for (size_t i = 0; i < ksize.size(); ++i) {
 ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
 }
......
@@ -46,7 +46,7 @@ class SoftmaxWithCrossEntropyOpMaker
 "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
 "entropy loss with shape [N x 1].");
 AddAttr<bool>(
-"softLabel",
+"soft_label",
 "(bool, default: false), A flag to indicate whether to interpretate "
 "the given labels as soft labels.")
 .SetDefault(false);
@@ -100,13 +100,13 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
 "The labels should be a 2-D tensor.");
-if (ctx->Attrs().Get<bool>("softLabel")) {
+if (ctx->Attrs().Get<bool>("soft_label")) {
 PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
-"If Attr(softLabel) == true, the 2nd dimension of "
+"If Attr(soft_label) == true, the 2nd dimension of "
 "Input(X) and Input(Label) should be equal.");
 } else {
 PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
-"If Attr(softLabel) == false, the 2nd dimension of "
+"If Attr(soft_label) == false, the 2nd dimension of "
 "Input(Label) should be 1.");
 }
@@ -142,13 +142,13 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
 "The labels should be a 2-D tensor.");
-if (ctx->Attrs().Get<bool>("softLabel")) {
+if (ctx->Attrs().Get<bool>("soft_label")) {
 PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
-"When Attr(softLabel) == true, the 2nd dimension of "
+"When Attr(soft_label) == true, the 2nd dimension of "
 "Input(X) and Input(Label) should be equal.");
 } else {
 PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
-"When Attr(softLabel) == false, the 2nd dimension of "
+"When Attr(soft_label) == false, the 2nd dimension of "
 "Input(Label) should be 1.");
 }
......
@@ -70,7 +70,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 logits, softmax);
 math::CrossEntropyFunctor<platform::GPUPlace, T>()(
 context.device_context(), loss, softmax, labels,
-context.Attr<bool>("softLabel"));
+context.Attr<bool>("soft_label"));
 }
 };
@@ -93,7 +93,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 int block = 512;
 int grid = (batch_size * class_num + block - 1) / block;
-if (context.Attr<bool>("softLabel")) {
+if (context.Attr<bool>("soft_label")) {
 const T* label_data = labels->data<T>();
 SoftCrossEntropyGradientKernel<T><<<
 grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
......
@@ -44,7 +44,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 logits, softmax);
 math::CrossEntropyFunctor<platform::CPUPlace, T>()(
 context.device_context(), loss, softmax, labels,
-context.Attr<bool>("softLabel"));
+context.Attr<bool>("soft_label"));
 }
 };
@@ -60,7 +60,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
 const int class_num = logit_grad->dims()[1];
-if (context.Attr<bool>("softLabel")) {
+if (context.Attr<bool>("soft_label")) {
 auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
 auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
 auto lbl_mat = EigenMatrix<T>::From(*labels);
......
@@ -49,7 +49,7 @@ class TestCrossEntropyOp2(OpTest):
 self.inputs = {"X": X, "Label": label}
 self.outputs = {"Y": cross_entropy}
-self.attrs = {"softLabel": True}
+self.attrs = {"soft_label": True}
 def test_check_output(self):
 self.check_output()
@@ -82,7 +82,7 @@ class TestCrossEntropyOp3(OpTest):
 self.inputs = {"X": X, "Label": label.astype(np.float32)}
 self.outputs = {"Y": cross_entropy}
-self.attrs = {"softLabel": True}
+self.attrs = {"soft_label": True}
 def test_check_output(self):
 self.check_output()
......
@@ -56,8 +56,8 @@ class TestPool2d_Op(OpTest):
 'strides': self.strides,
 'paddings': self.paddings,
 'ksize': self.ksize,
-'poolingType': self.pool_type,
-'globalPooling': self.global_pool,
+'pooling_type': self.pool_type,
+'global_pooling': self.global_pool,
 }
 self.outputs = {'Out': output}
......
@@ -64,8 +64,8 @@ class TestPool3d_Op(OpTest):
 'strides': self.strides,
 'paddings': self.paddings,
 'ksize': self.ksize,
-'poolingType': self.pool_type,
-'globalPooling': self.global_pool,
+'pooling_type': self.pool_type,
+'global_pooling': self.global_pool,
 }
 self.outputs = {'Out': output}
......
@@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
 'strides': self.strides,
 'paddings': self.paddings,
 'ksize': self.ksize,
-'globalPooling': self.global_pool,
+'global_pooling': self.global_pool,
 }
 self.inputs = {'X': input}
......
@@ -57,7 +57,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
 self.inputs = {"Logits": logits, "Label": labels}
 self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
-self.attrs = {"softLabel": True}
+self.attrs = {"soft_label": True}
 def test_check_output(self):
 self.check_output()
......