提交 e0f7bf4f 编写于 作者: T tink2123

polish the code

test=develop
上级 ffe81af0
...@@ -80,14 +80,11 @@ class AffineChannelOp : public framework::OperatorWithKernel { ...@@ -80,14 +80,11 @@ class AffineChannelOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL); PADDLE_ENFORCE_EQ(scale_dims.size(), 1UL);
PADDLE_ENFORCE_EQ(b_dims.size(), 1UL); PADDLE_ENFORCE_EQ(b_dims.size(), 1UL);
if (ctx->IsRuntime()) { if (ctx->IsRuntime() || scale_dims[0] > 0) {
PADDLE_ENFORCE_EQ(scale_dims[0], C); PADDLE_ENFORCE_EQ(scale_dims[0], C);
}
if (ctx->IsRuntime() || b_dims[0] > 0) {
PADDLE_ENFORCE_EQ(b_dims[0], C); PADDLE_ENFORCE_EQ(b_dims[0], C);
} else {
if (scale_dims[0] > 0 && b_dims[0] > 0) {
PADDLE_ENFORCE_EQ(scale_dims[0], C);
PADDLE_ENFORCE_EQ(b_dims[0], C);
}
} }
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
......
...@@ -69,7 +69,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -69,7 +69,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
if ((!ctx->IsRuntime()) && if ((!ctx->IsRuntime()) &&
(in_dims[i + 2] == -1 || filter_dims[i + 2] == -1)) { (in_dims[i + 2] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1); output_shape.push_back(-1);
} else { } else {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
......
...@@ -50,14 +50,12 @@ class ROIPoolOp : public framework::OperatorWithKernel { ...@@ -50,14 +50,12 @@ class ROIPoolOp : public framework::OperatorWithKernel {
int pooled_width = ctx->Attrs().Get<int>("pooled_width"); int pooled_width = ctx->Attrs().Get<int>("pooled_width");
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale"); float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
if (ctx->IsRuntime()) { PADDLE_ENFORCE_GT(pooled_height, 0,
PADDLE_ENFORCE_GT(pooled_height, 0, "The pooled output height must greater than 0");
"The pooled output height must greater than 0"); PADDLE_ENFORCE_GT(pooled_width, 0,
PADDLE_ENFORCE_GT(pooled_width, 0, "The pooled output width must greater than 0");
"The pooled output width must greater than 0"); PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
PADDLE_ENFORCE_GT(spatial_scale, 0.0f, "The spatial scale must greater than 0");
"The spatial scale must greater than 0");
}
auto out_dims = input_dims; auto out_dims = input_dims;
out_dims[0] = rois_dims[0]; out_dims[0] = rois_dims[0];
......
...@@ -45,16 +45,10 @@ class RowConvOp : public framework::OperatorWithKernel { ...@@ -45,16 +45,10 @@ class RowConvOp : public framework::OperatorWithKernel {
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2."); PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
if (ctx->IsRuntime()) { if (ctx->IsRuntime() || (x_dims[1] > 0 && filter_dims[1] > 0)) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[1], filter_dims[1], x_dims[1], filter_dims[1],
"The 2nd dimension of Input(X) and Input(Filter) should be same."); "The 2nd dimension of Input(X) and Input(Filter) should be same.");
} else {
if (x_dims[1] > 0 && filter_dims[1] > 0) {
PADDLE_ENFORCE_EQ(
x_dims[1], filter_dims[1],
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
}
} }
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
......
...@@ -102,7 +102,7 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -102,7 +102,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]}); std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
if (!ctx->IsRuntime() && in_x_dims[i + 2] == -1) { if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1); output_shape.push_back(-1);
} else { } else {
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i], output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册