未验证 提交 22301115 编写于 作者: A Aurelius84 提交者: GitHub

Remove constraint that last dimension is forced to be 1 in huber_loss op (#19562)

* Remove constraint that last dimension is forced to be 1 in huber_loss
test=develop

* add y[rank-1] == 1 when x_rank=y_rank test=develop

* modify into contain_unknown_dim test=develop
上级 5866a7a5
......@@ -25,27 +25,35 @@ class HuberLossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) must be initialized.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Input(Y) must be initialized.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
"The rank of Input(X) must be 2 and the shape is "
"[batch_size, 1].");
if (ctx->IsRuntime() ||
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
PADDLE_ENFORCE_EQ(x_dims, y_dims, "Shape of X and Y should be same");
int rank = x_dims.size();
if (rank == y_dims.size()) {
PADDLE_ENFORCE_EQ(y_dims[rank - 1], 1U,
"The last dimension of Input(Y) should be equal to 1.");
} else {
PADDLE_ENFORCE_EQ(rank, y_dims.size() + 1,
"The rank of Input(X) should be equal to "
"the rank of Input(Y) plus 1.");
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Each row of Input(X) contains a real value, "
"so the 2nd dimension of Input(X) must be 1.");
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(y_dims);
if (ctx->IsRuntime() || !contain_unknown_dim) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(y_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension.");
}
ctx->SetOutputDim("Residual", x_dims);
ctx->SetOutputDim("Out", {x_dims[0], 1});
auto out_dims = y_dims;
ctx->SetOutputDim("Residual", out_dims);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", "Out");
}
};
......@@ -98,8 +106,8 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) should not be null.");
auto residual_dims = ctx->GetInputDim("Residual");
......
......@@ -30,19 +30,25 @@ def huber_loss_forward(val, delta):
class TestHuberLossOp(OpTest):
def setUp(self):
self.op_type = 'huber_loss'
samples_num = 64
delta = 1.0
self.inputs = {
'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'),
}
residual = self.inputs['Y'] - self.inputs['X']
self.samples_num = 64
self.delta = 1.0
self.init_input()
residual = self.inputs['Y'].reshape(
self.samples_num, 1) - self.inputs['X'].reshape(self.samples_num, 1)
loss = np.vectorize(huber_loss_forward)(residual,
delta).astype('float32')
self.attrs = {'delta': delta}
self.delta).astype('float32')
self.attrs = {'delta': self.delta}
self.outputs = {
'Residual': residual,
'Out': loss.reshape((samples_num, 1))
'Out': loss.reshape((self.samples_num, 1))
}
def init_input(self):
self.inputs = {
'X': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
}
def test_check_output(self):
......@@ -60,5 +66,14 @@ class TestHuberLossOp(OpTest):
['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual'))
def TestHuberLossOp1(TestHuberLossOp):
def init_input(self):
self.inputs = {
'X': np.random.uniform(0, 1.,
(self.samples_num, 1)).astype('float32'),
'Y': np.random.uniform(0, 1., (self.samples_num)).astype('float32'),
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册