未验证 提交 f58c8db6 编写于 作者: A Aurelius84 提交者: GitHub

Require x.dims=label.dims in huber_loss (#20017)

* x.dims == y.dims test=develop

* refine comment
上级 cde73a7b
......@@ -32,23 +32,16 @@ class HuberLossOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int rank = x_dims.size();
if (rank == y_dims.size()) {
PADDLE_ENFORCE_EQ(y_dims[rank - 1], 1U,
"The last dimension of Input(Y) should be equal to 1.");
} else {
PADDLE_ENFORCE_EQ(rank, y_dims.size() + 1,
"The rank of Input(X) should be equal to "
"the rank of Input(Y) plus 1.");
}
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"The rank of Input(X) should be equal to "
"the rank of Input(Y).");
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(y_dims);
if (ctx->IsRuntime() || !contain_unknown_dim) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(y_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension.");
PADDLE_ENFORCE_EQ(
x_dims, y_dims,
"The Input(X) and Input(Label) should have the same shape.");
}
auto out_dims = y_dims;
......@@ -64,16 +57,16 @@ class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"The input value of huber loss op."
"X is a 2-D tensor with shape [batch_size, 1].");
"X is a N-D tensor with shape [N_1, N_2,..., N_n].");
AddInput("Y",
"The target value of huber loss op."
"Y is a 2-D tensor with shape [batch_size, 1].");
"Y is a N-D tensor with shape [N_1, N_2,..., N_n].");
AddOutput("Residual",
"Intermediate tensor to cache residual value between Y and X."
"The shape is same as Input(X) and will be reused in backward.")
.AsIntermediate();
AddOutput("Out",
"The output tensor with shape [batch_size, 1] "
"The output N-D tensor with shape [N_1, N_2,..., N_n] "
"which represents the huber loss.");
AddAttr<AttrType>("delta", "Hyper parameter in huber loss.");
AddComment(R"DOC(
......@@ -81,7 +74,7 @@ HuberLoss Operator.
Huber loss is a loss function used in robust regression. We define X as the
input value and Y as the target value. Huber loss can evaluate the fitness of
X to Y. Different from MSE loss, Huber loss is more robust for outliers. The
X to Y. Different from MSE loss, Huber loss is more robust for outliers. If the
shape of X and Y are [batch_size, 1]. The equation is:
$$
......
......@@ -30,27 +30,25 @@ def huber_loss_forward(val, delta):
class TestHuberLossOp(OpTest):
def setUp(self):
self.op_type = 'huber_loss'
self.samples_num = 64
self.delta = 1.0
self.init_input()
residual = self.inputs['Y'].reshape(
self.samples_num, 1) - self.inputs['X'].reshape(self.samples_num, 1)
shape = self.set_shape()
residual = self.inputs['Y'] - self.inputs['X']
loss = np.vectorize(huber_loss_forward)(residual,
self.delta).astype('float32')
self.attrs = {'delta': self.delta}
self.outputs = {
'Residual': residual,
'Out': loss.reshape((self.samples_num, 1))
}
self.outputs = {'Residual': residual, 'Out': loss.reshape(shape)}
def init_input(self):
shape = self.set_shape()
self.inputs = {
'X': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'X': np.random.uniform(0, 1., shape).astype('float32'),
'Y': np.random.uniform(0, 1., shape).astype('float32'),
}
def set_shape(self):
return (64, 1)
def test_check_output(self):
self.check_output()
......@@ -67,12 +65,18 @@ class TestHuberLossOp(OpTest):
def TestHuberLossOp1(TestHuberLossOp):
def init_input(self):
self.inputs = {
'X': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1., (self.samples_num)).astype('float32'),
}
def set_shape(self):
return (64)
def TestHuberLossOp2(TestHuberLossOp):
def set_shape(self):
return (6, 6)
def TestHuberLossOp2(TestHuberLossOp):
def set_shape(self):
return (6, 6, 1)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册