未验证 提交 22301115 编写于 作者: A Aurelius84 提交者: GitHub

Remove constraint that last dimension is forced to be 1 in huber_loss op (#19562)

* Remove constraint that last dimension is forced to be 1 in huber_loss
test=develop

* add y[rank-1] == 1 when x_rank=y_rank test=develop

* modify into contain_unknown_dim test=develop
上级 5866a7a5
...@@ -25,27 +25,35 @@ class HuberLossOp : public framework::OperatorWithKernel { ...@@ -25,27 +25,35 @@ class HuberLossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized."); PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized."); "Input(X) must be initialized.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Input(Y) must be initialized.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of Input(X) must be 2 and the shape is " if (rank == y_dims.size()) {
"[batch_size, 1]."); PADDLE_ENFORCE_EQ(y_dims[rank - 1], 1U,
if (ctx->IsRuntime() || "The last dimension of Input(Y) should be equal to 1.");
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) { } else {
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same"); PADDLE_ENFORCE_EQ(rank, y_dims.size() + 1,
"The rank of Input(X) should be equal to "
"the rank of Input(Y) plus 1.");
} }
if (ctx->IsRuntime()) { bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
PADDLE_ENFORCE_EQ(x_dims[1], 1, framework::contain_unknown_dim(y_dims);
"Each row of Input(X) contains a real value, " if (ctx->IsRuntime() || !contain_unknown_dim) {
"so the 2nd dimension of Input(X) must be 1."); PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(y_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension.");
} }
ctx->SetOutputDim("Residual", x_dims); auto out_dims = y_dims;
ctx->SetOutputDim("Out", {x_dims[0], 1}); ctx->SetOutputDim("Residual", out_dims);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
} }
}; };
...@@ -98,8 +106,8 @@ class HuberLossGradOp : public framework::OperatorWithKernel { ...@@ -98,8 +106,8 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
auto residual_dims = ctx->GetInputDim("Residual"); auto residual_dims = ctx->GetInputDim("Residual");
......
...@@ -30,19 +30,25 @@ def huber_loss_forward(val, delta): ...@@ -30,19 +30,25 @@ def huber_loss_forward(val, delta):
class TestHuberLossOp(OpTest): class TestHuberLossOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'huber_loss' self.op_type = 'huber_loss'
samples_num = 64 self.samples_num = 64
delta = 1.0 self.delta = 1.0
self.inputs = { self.init_input()
'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), residual = self.inputs['Y'].reshape(
'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), self.samples_num, 1) - self.inputs['X'].reshape(self.samples_num, 1)
}
residual = self.inputs['Y'] - self.inputs['X']
loss = np.vectorize(huber_loss_forward)(residual, loss = np.vectorize(huber_loss_forward)(residual,
delta).astype('float32') self.delta).astype('float32')
self.attrs = {'delta': delta} self.attrs = {'delta': self.delta}
self.outputs = { self.outputs = {
'Residual': residual, 'Residual': residual,
'Out': loss.reshape((samples_num, 1)) 'Out': loss.reshape((self.samples_num, 1))
}
def init_input(self):
self.inputs = {
'X': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
} }
def test_check_output(self): def test_check_output(self):
...@@ -60,5 +66,14 @@ class TestHuberLossOp(OpTest): ...@@ -60,5 +66,14 @@ class TestHuberLossOp(OpTest):
['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual')) ['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual'))
def TestHuberLossOp1(TestHuberLossOp):
def init_input(self):
self.inputs = {
'X': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1., (self.samples_num)).astype('float32'),
}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册