提交 5663fbfb 编写于 作者: X xuezhong

fix infershape bug

test=develop
上级 c96ee47d
...@@ -152,12 +152,19 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -152,12 +152,19 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
auto transition_dims = ctx->GetInputDim("Transition"); auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(transition_dims.size(), 2, PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
"The Input(Transition) should be a 2-D tensor."); "The Input(Transition) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ( bool check = true;
transition_dims[0] - 2, transition_dims[1], if ((!ctx->IsRuntime()) &&
"An invalid dimension for the Input(Transition), which should " (transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
"be a 2-D tensor with shape [(D + 2) x D]."); check = false;
PADDLE_ENFORCE_EQ( }
emission_dims[1], transition_dims[1], if (check) {
PADDLE_ENFORCE_EQ(
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) " "The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number."); "should be equal to the tag number.");
...@@ -165,8 +172,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -165,8 +172,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd " "The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1."); "dimensions fixed to 1.");
PADDLE_ENFORCE_EQ( PADDLE_INFERSHAPE_ENFORCE_EQ(
emission_dims[0], label_dims[0], ctx, emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) " "The height of Input(Emission) and the height of Input(Label) "
"should be the same."); "should be the same.");
...@@ -211,12 +218,19 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -211,12 +218,19 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
auto transition_exps_dims = ctx->GetInputDim("TransitionExps"); auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2, PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
"The Input(TransitionExps) should be a 2-D tensor."); "The Input(TransitionExps) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ( bool check = true;
transition_exps_dims[0] - 2, transition_exps_dims[1], if ((!ctx->IsRuntime()) &&
"An invalid dimension for the Input(TransitionExps), which should " (transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
"be a 2-D tensor with shape [(D + 2) x D]."); check = false;
PADDLE_ENFORCE_EQ( }
emission_exps_dims[1], transition_exps_dims[1], if (check) {
PADDLE_ENFORCE_EQ(
transition_exps_dims[0] - 2, transition_exps_dims[1],
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[1], transition_exps_dims[1],
"The 2nd dimension of the Input(EmissionExps) and the " "The 2nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number."); "Input(TransitionExps) should be equal to the tag number.");
...@@ -224,8 +238,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -224,8 +238,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL, PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd " "The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1."); "dimensions fixed to 1.");
PADDLE_ENFORCE_EQ( PADDLE_INFERSHAPE_ENFORCE_EQ(
emission_exps_dims[0], label_dims[0], ctx, emission_exps_dims[0], label_dims[0],
"The height of Input(EmissionExps) and the height of Input(Label) " "The height of Input(EmissionExps) and the height of Input(Label) "
"should be the same."); "should be the same.");
......
...@@ -41,10 +41,11 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -41,10 +41,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
// it's the output of topk. // it's the output of topk.
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2."); PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1"); PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, label_dim[1], 1,
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0], "label's second dimension must be 1");
"the inference tensor's num_rows must be" PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, inference_dim[0], label_dim[0],
" the same as label."); "the inference tensor's num_rows must be"
" the same as label.");
ctx->SetOutputDim("Accuracy", {1}); ctx->SetOutputDim("Accuracy", {1});
ctx->SetOutputDim("Correct", {1}); ctx->SetOutputDim("Correct", {1});
......
...@@ -32,8 +32,8 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -32,8 +32,8 @@ class AucOp : public framework::OperatorWithKernel {
auto predict_height = ctx->GetInputDim("Predict")[0]; auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0]; auto label_height = ctx->GetInputDim("Label")[0];
PADDLE_ENFORCE_EQ(predict_height, label_height, PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_height, label_height,
"Out and Label should have same height."); "Out and Label should have same height.");
int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1; int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps"); int slide_steps = ctx->Attrs().Get<int>("slide_steps");
......
...@@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sample_logits_op.h" #include "paddle/fluid/operators/sample_logits_op.h"
#include "paddle/fluid/operators/math/sample_prob.h" #include "paddle/fluid/operators/math/sample_prob.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -132,7 +133,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel { ...@@ -132,7 +133,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
"The labels should be a 2-D tensor."); "The labels should be a 2-D tensor.");
const int num_samples = ctx->Attrs().Get<int>("num_samples"); const int num_samples = ctx->Attrs().Get<int>("num_samples");
const int num_sampled_classes = labels_dims[1] + num_samples; int num_sampled_classes = labels_dims[1] + num_samples;
if ((!ctx->IsRuntime()) && labels_dims[1] <= 0) {
num_sampled_classes = -1;
}
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes}); ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/smooth_l1_loss_op.h" #include "paddle/fluid/operators/smooth_l1_loss_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,7 +29,14 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { ...@@ -27,7 +29,14 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims); bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims, y_dims);
}
PADDLE_ENFORCE_GE(x_dims.size(), 2, PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The tensor rank of Input(X) should not be less than 2."); "The tensor rank of Input(X) should not be less than 2.");
if (ctx->HasInput("InsideWeight")) { if (ctx->HasInput("InsideWeight")) {
...@@ -110,11 +119,11 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { ...@@ -110,11 +119,11 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(out_dims.size(), 2, PADDLE_ENFORCE_GE(out_dims.size(), 2,
"The tensor rank of Input(Out@Grad) should be 2."); "The tensor rank of Input(Out@Grad) should be 2.");
PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0], PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], in_dims[0],
"The 1st dimension of Input(Out@Grad) must be " "The 1st dimension of Input(Out@Grad) must be "
"same as input."); "same as input.");
PADDLE_ENFORCE_EQ(out_dims[1], 1, PADDLE_INFERSHAPE_ENFORCE_EQ(
"The 2nd dimension of Input(Out@Grad) must be 1."); ctx, out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1.");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
......
...@@ -45,13 +45,26 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { ...@@ -45,13 +45,26 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
int rank = framework::arity(x_dims); int rank = framework::arity(x_dims);
PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2."); PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0], product(y_dims) / y_dims[0], bool check = true;
"Product of dimensions expcet the first dimension of " if ((!ctx->IsRuntime()) &&
"input and target must be equal."); (framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0], check = false;
"First dimension of target must be equal to input " }
"or to 1."); if (check) {
PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0],
product(y_dims) / y_dims[0],
"Product of dimensions expcet the first dimension of "
"input and target must be equal.");
}
check = true;
if ((!ctx->IsRuntime()) && (y_dims[0] <= 0 || x_dims[0] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
"First dimension of target must be equal to input "
"or to 1.");
}
ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]}); ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]});
ctx->SetOutputDim("Out", {x_dims[0], 1}); ctx->SetOutputDim("Out", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -124,12 +137,12 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { ...@@ -124,12 +137,12 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0], PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], x_dims[0],
"First dimension of output gradient and " "First dimension of output gradient and "
"input value must be equal."); "input value must be equal.");
PADDLE_ENFORCE_EQ(out_dims[1], 1, PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[1], 1,
"Second dimension of output gradient " "Second dimension of output gradient "
"must be 1."); "must be 1.");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims); if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);
......
...@@ -356,5 +356,46 @@ using CommonType2 = typename std::add_lvalue_reference< ...@@ -356,5 +356,46 @@ using CommonType2 = typename std::add_lvalue_reference<
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
#define __PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL1, __VAL2, __CMP, \
__INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
if (!__CTX->IsRuntime()) { \
if (__val1 == -1 || __val2 == -1) { \
break; \
} \
} \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
" %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \
#__VAL1, #__VAL2, #__VAL1, \
::paddle::string::to_string(__val1), #__VAL2, \
::paddle::string::to_string(__val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)
#define PADDLE_INFERSHAPE_ENFORCE_EQ(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_NE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_GT(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_GE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_LT(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_LE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <=, >, __VA_ARGS__)
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册