Commit cdda0cf3 authored by Yang yaming, committed by GitHub

Merge pull request #3913 from pkuyym/fix-3789

Complete smooth_l1_loss_op.
// paddle/operators/smooth_l1_loss_op.cc

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/smooth_l1_loss_op.h"
namespace paddle {
namespace operators {
class SmoothL1LossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Y must be initialized.");
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
"The shape of X and Y must be the same.");
PADDLE_ENFORCE_GE(x->dims().size(), 2,
"The tensor rank of X must be at least 2.");
auto* inside_weight = ctx.Input<framework::Tensor>("InsideWeight");
if (inside_weight) {
auto* outside_weight = ctx.Input<framework::Tensor>("OutsideWeight");
PADDLE_ENFORCE_NOT_NULL(outside_weight,
"If weights are provided, must specify both "
"inside and outside weights.");
PADDLE_ENFORCE_EQ(inside_weight->dims(), x->dims(),
"The shape of InsideWeight must be same as X.");
PADDLE_ENFORCE_EQ(outside_weight->dims(), x->dims(),
"The shape of OutsideWeight must be same as X.");
}
auto* diff = ctx.Output<framework::LoDTensor>("Diff");
auto* out = ctx.Output<framework::LoDTensor>("Out");
diff->Resize(x->dims());
// loss is a two-rank tensor
out->Resize({x->dims()[0], 1});
}
};
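As a reading aid, a minimal numpy sketch of the shape contract enforced above (illustrative values only, not Paddle API code):

    import numpy as np

    # X and Y (and the optional weights) must share one shape of rank >= 2;
    # Diff mirrors that shape, while Out keeps one loss value per instance.
    x = np.zeros((5, 10), dtype="float32")   # [batch_size, value_dim]
    y = np.zeros_like(x)
    assert x.shape == y.shape and x.ndim >= 2
    diff_shape = x.shape          # Diff: same shape as X
    out_shape = (x.shape[0], 1)   # Out:  [batch_size, 1]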
template <typename AttrType>
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SmoothL1LossOpMaker(framework::OpProto* proto,
                      framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The input tensor of smooth l1 loss op. "
             "The rank should be greater than or equal to 2 with shape "
             "[batch_size, value_dim1, value_dim2, ..., value_dimN].");
    AddInput("Y",
             "The target tensor of smooth l1 loss op "
             "with the same shape as X.");
    AddInput("InsideWeight",
             "Optional input tensor of smooth l1 loss op with the same shape "
             "as X. If provided, the result of (X - Y) will be multiplied "
             "by this tensor element by element.");
    AddInput("OutsideWeight",
             "Optional input of smooth l1 loss op with the same shape as X. "
             "If provided, the output smooth l1 loss will be multiplied by "
             "this tensor element by element.");
    AddOutput("Diff", "Intermediate variable to cache InsideWeight * (X - Y).")
        .AsIntermediate();
    AddOutput("Out", "Smooth l1 loss.");
    AddAttr<AttrType>("sigma",
                      "Hyper parameter of smooth l1 loss op. "
                      "A float scalar with default value 3.0.")
        .SetDefault(3.0);
    AddComment(R"DOC(
Compute smooth l1 loss for input and target. The operator takes the 1st
dimension of the input as the batch size. For each instance, it computes the
smooth l1 loss element by element first and then sums all the losses to one
value, so the shape of the output is [batch_size, 1].

The equation is:

    loss = 0.5 * (sigma * (x - y))^2    if abs(x - y) < 1 / sigma^2
           abs(x - y) - 0.5 / sigma^2   otherwise
)DOC");
  }
};
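The DOC equation maps directly onto numpy; a minimal sketch of the unweighted forward pass (a reference implementation, not the kernel itself):

    import numpy as np

    def smooth_l1_reference(x, y, sigma=3.0):
        # piecewise transform from the DOC block above
        sigma2 = sigma * sigma
        diff = x - y
        abs_diff = np.abs(diff)
        loss = np.where(abs_diff < 1.0 / sigma2,
                        0.5 * diff * diff * sigma2,   # quadratic region
                        abs_diff - 0.5 / sigma2)      # linear region
        # sum each instance's elementwise losses -> [batch_size, 1]
        return loss.reshape(x.shape[0], -1).sum(axis=1, keepdims=True)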
class SmoothL1LossGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    auto in_dims = ctx.Input<framework::Tensor>("X")->dims();
    auto out_dims =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->dims();
    auto* x_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto* y_grad =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));

    PADDLE_ENFORCE_GE(out_dims.size(), 2,
                      "The tensor rank of Input(Out@Grad) should be at "
                      "least 2.");
    PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
                      "The 1st dimension of Input(Out@Grad) must be the "
                      "same as the input.");
    PADDLE_ENFORCE_EQ(out_dims[1], 1,
                      "The 2nd dimension of Input(Out@Grad) must be 1.");

    if (x_grad) x_grad->Resize(in_dims);
    if (y_grad) y_grad->Resize(in_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
            ops::SmoothL1LossOpMaker<float>, smooth_l1_loss_grad,
            ops::SmoothL1LossGradOp);
REGISTER_OP_CPU_KERNEL(
    smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    smooth_l1_loss_grad,
    ops::SmoothL1LossGradKernel<paddle::platform::CPUPlace, float>);
// paddle/operators/smooth_l1_loss_op.cu

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/smooth_l1_loss_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    smooth_l1_loss_grad,
    ops::SmoothL1LossGradKernel<paddle::platform::GPUPlace, float>);
// paddle/operators/smooth_l1_loss_op.h

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
struct SmoothL1LossForward {
  HOSTDEVICE SmoothL1LossForward(const T& sigma2) : sigma2(sigma2) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val < 1.0 / sigma2) {
      return 0.5 * val * val * sigma2;
    } else {
      return abs_val - 0.5 / sigma2;
    }
  }

  T sigma2;
};
template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto* in1 = context.Input<Tensor>("Y");
    auto* in2 = context.Input<Tensor>("InsideWeight");
    auto* in3 = context.Input<Tensor>("OutsideWeight");
    auto* out0 = context.Output<Tensor>("Diff");
    auto* out1 = context.Output<Tensor>("Out");
    out0->mutable_data<T>(context.GetPlace());
    out1->mutable_data<T>(context.GetPlace());
    auto place = context.GetEigenDevice<Place>();

    auto sigma = static_cast<T>(context.Attr<AttrType>("sigma"));
    T sigma2 = sigma * sigma;
    bool has_weight = (in2 != nullptr) && (in3 != nullptr);

    auto x = EigenVector<T>::Flatten(*in0);
    auto y = EigenVector<T>::Flatten(*in1);
    auto diff = EigenVector<T>::Flatten(*out0);

    diff.device(place) = x - y;
    // multiply by the inside weight
    if (has_weight) {
      auto inside_weight = EigenVector<T>::Flatten(*in2);
      // cache diff, which is reused in the backward pass
      diff.device(place) = diff * inside_weight;
    }

    auto in_counts = in0->numel();
    Tensor ptensor_errors;
    ptensor_errors.mutable_data<T>({static_cast<int>(in_counts)},
                                   context.GetPlace());
    auto errors = EigenVector<T>::Flatten(ptensor_errors);
    // apply the smooth l1 forward transform
    errors.device(place) = diff.unaryExpr(SmoothL1LossForward<T>(sigma2));

    // multiply by the outside weight
    if (has_weight) {
      auto outside_weight = EigenVector<T>::Flatten(*in3);
      errors.device(place) = errors * outside_weight;
    }

    auto loss = EigenVector<T>::Flatten(*out1);
    // the first dimension of 'X' is the number of samples
    auto mat_dims =
        framework::make_ddim({static_cast<int>(in0->dims()[0]),
                              static_cast<int>(in_counts / in0->dims()[0])});
    auto errors_mat_view = EigenMatrix<T>::From(ptensor_errors, mat_dims);
    loss.device(place) = errors_mat_view.sum(Eigen::array<int, 1>({{1}}));
  }
};
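The kernel above first caches InsideWeight * (X - Y) in Diff, applies the piecewise transform, scales by OutsideWeight, and finally row-sums. A numpy sketch that mirrors that order of operations (a sketch, not the kernel itself):

    import numpy as np

    def smooth_l1_forward_weighted(x, y, inside_w, outside_w, sigma=3.0):
        sigma2 = sigma * sigma
        diff = (x - y) * inside_w             # cached as "Diff" for backprop
        abs_diff = np.abs(diff)
        errors = np.where(abs_diff < 1.0 / sigma2,
                          0.5 * diff * diff * sigma2,
                          abs_diff - 0.5 / sigma2)
        errors = errors * outside_w           # outside weight after transform
        out = errors.reshape(x.shape[0], -1).sum(axis=1, keepdims=True)
        return diff, out                      # matches outputs Diff and Out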
template <typename T>
struct SmoothL1LossBackward {
  HOSTDEVICE SmoothL1LossBackward(const T& sigma2) : sigma2(sigma2) {}

  HOSTDEVICE T operator()(const T& val) const {
    T abs_val = std::abs(val);
    if (abs_val < 1.0 / sigma2) {
      return sigma2 * val;
    } else {
      // (0 < val) - (val < 0) evaluates to sign(val)
      return (0 < val) - (val < 0);
    }
  }

  T sigma2;
};
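SmoothL1LossBackward is the elementwise derivative of the forward transform; in numpy terms (a sketch under the same sigma2 convention):

    import numpy as np

    def smooth_l1_backward(d, sigma2):
        # sigma^2 * d in the quadratic region, sign(d) in the linear region
        return np.where(np.abs(d) < 1.0 / sigma2, sigma2 * d, np.sign(d))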
template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("InsideWeight");
    auto* in1 = context.Input<Tensor>("OutsideWeight");
    auto* in2 = context.Input<Tensor>("Diff");
    auto* og = context.Input<Tensor>(framework::GradVarName("Out"));
    auto sigma = static_cast<T>(context.Attr<AttrType>("sigma"));
    T sigma2 = sigma * sigma;
    bool has_weight = (in0 != nullptr) && (in1 != nullptr);

    auto place = context.GetEigenDevice<Place>();

    auto in_dims = in2->dims();
    auto counts = in2->numel();
    auto cols = counts / in_dims[0];
    auto mat_dims = framework::make_ddim(
        {static_cast<int>(in_dims[0]), static_cast<int>(cols)});

    Tensor ptensor_diff;
    ptensor_diff.mutable_data<T>({static_cast<int>(counts)},
                                 context.GetPlace());
    auto diff = EigenVector<T>::Flatten(ptensor_diff);
    // apply the smooth l1 backward transform
    diff.device(place) = EigenVector<T>::Flatten(*in2).unaryExpr(
        SmoothL1LossBackward<T>(sigma2));

    // compute the combined weights
    Tensor ptensor_weights;
    ptensor_weights.mutable_data<T>(mat_dims, context.GetPlace());
    auto weights = EigenMatrix<T>::From(ptensor_weights);
    // initialize to 1.0
    weights.device(place) = weights.constant(static_cast<T>(1.0));
    if (has_weight) {
      auto inside_weight = EigenMatrix<T>::From(*in0, mat_dims);
      auto outside_weight = EigenMatrix<T>::From(*in1, mat_dims);
      weights.device(place) = inside_weight * outside_weight;
    }

    // compute the gradients
    auto out_grad = EigenMatrix<T>::From(*og);
    auto diff_mat_view = EigenMatrix<T>::From(ptensor_diff, mat_dims);
    auto gradients = out_grad.broadcast(
                         Eigen::array<int, 2>({{1, static_cast<int>(cols)}})) *
                     weights * diff_mat_view;

    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
    if (out0) {
      out0->mutable_data<T>(context.GetPlace());
      auto x_grad = EigenMatrix<T>::From(*out0, mat_dims);
      x_grad.device(place) = gradients;
    }
    if (out1) {
      out1->mutable_data<T>(context.GetPlace());
      auto y_grad = EigenMatrix<T>::From(*out1, mat_dims);
      y_grad.device(place) = -1 * gradients;
    }
  }
};

}  // namespace operators
}  // namespace paddle
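Putting the backward pieces together: the gradient kernel broadcasts Out@Grad of shape [batch_size, 1] across each flattened row, multiplies by the combined weights and by the derivative of the cached Diff, and negates the result for Y. A numpy sketch of that computation (helper names here are illustrative):

    import numpy as np

    def smooth_l1_grad(diff, out_grad, inside_w=None, outside_w=None,
                       sigma=3.0):
        # diff is the cached Diff output of the forward pass
        sigma2 = sigma * sigma
        d = diff.reshape(diff.shape[0], -1)
        # elementwise derivative (SmoothL1LossBackward)
        d = np.where(np.abs(d) < 1.0 / sigma2, sigma2 * d, np.sign(d))
        weights = np.ones_like(d)
        if inside_w is not None and outside_w is not None:
            weights = (inside_w * outside_w).reshape(d.shape)
        # out_grad: [batch_size, 1], broadcasts across each row
        grad = (out_grad * weights * d).reshape(diff.shape)
        return grad, -grad   # X@Grad, Y@Grad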
# Python unit tests for smooth_l1_loss, built on the op_test framework.
import unittest
import numpy as np
from op_test import OpTest


def smooth_l1_loss_forward(val, sigma2):
    abs_val = abs(val)
    if abs_val < 1.0 / sigma2:
        return 0.5 * val * val * sigma2
    else:
        return abs_val - 0.5 / sigma2


class TestSmoothL1LossOp1(OpTest):
    def setUp(self):
        self.op_type = "smooth_l1_loss"
        dims = (5, 10)
        self.inputs = {
            'X': np.random.random(dims).astype("float32"),
            'Y': np.random.random(dims).astype("float32")
        }
        sigma = 3.0
        self.attrs = {'sigma': sigma}
        sigma2 = sigma * sigma
        diff = self.inputs['X'] - self.inputs['Y']
        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2).sum(1)
        loss = loss.reshape((dims[0], 1))
        self.outputs = {'Diff': diff, 'Out': loss}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02)

    def test_check_grad_ignore_x(self):
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.03, no_grad_set=set("X"))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.03, no_grad_set=set('Y'))


class TestSmoothL1LossOp2(OpTest):
    def setUp(self):
        self.op_type = "smooth_l1_loss"
        dims = (5, 10)
        self.inputs = {
            'X': np.random.random(dims).astype("float32"),
            'Y': np.random.random(dims).astype("float32"),
            'InsideWeight': np.random.random(dims).astype("float32"),
            'OutsideWeight': np.random.random(dims).astype("float32")
        }
        sigma = 3.0
        self.attrs = {'sigma': sigma}
        sigma2 = sigma * sigma
        diff = self.inputs['X'] - self.inputs['Y']
        diff = diff * self.inputs['InsideWeight']
        loss = np.vectorize(smooth_l1_loss_forward)(diff, sigma2)
        loss = loss * self.inputs['OutsideWeight']
        loss = loss.sum(1).reshape((dims[0], 1))
        self.outputs = {'Diff': diff, 'Out': loss}

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.03)

    def test_check_grad_ignore_x(self):
        self.check_grad(
            ['Y'],
            'Out',
            max_relative_error=0.03,
            no_grad_set=set(['X', 'InsideWeight', 'OutsideWeight']))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'],
            'Out',
            max_relative_error=0.03,
            no_grad_set=set(['Y', 'InsideWeight', 'OutsideWeight']))


if __name__ == '__main__':
    unittest.main()
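check_grad compares the operator's analytic gradient against a numeric one. A standalone sketch of that idea for the unweighted loss, using central differences (helper names are illustrative, not the op_test API):

    import numpy as np

    def smooth_l1_sum(x, y, sigma2=9.0):
        d = x - y
        a = np.abs(d)
        return np.where(a < 1.0 / sigma2,
                        0.5 * d * d * sigma2,
                        a - 0.5 / sigma2).sum()

    def numeric_grad(f, x, eps=1e-4):
        # central-difference gradient of a scalar-valued f at x
        g = np.zeros_like(x)
        it = np.nditer(x, flags=['multi_index'])
        while not it.finished:
            i = it.multi_index
            orig = x[i]
            x[i] = orig + eps
            hi = f(x)
            x[i] = orig - eps
            lo = f(x)
            x[i] = orig
            g[i] = (hi - lo) / (2 * eps)
            it.iternext()
        return g

    x = np.random.random((5, 10))
    y = np.random.random((5, 10))
    sigma2 = 9.0
    analytic = np.where(np.abs(x - y) < 1.0 / sigma2,
                        sigma2 * (x - y), np.sign(x - y))
    numeric = numeric_grad(lambda v: smooth_l1_sum(v, y), x)
    # skip points too close to the kink at |x - y| = 1 / sigma^2
    mask = np.abs(np.abs(x - y) - 1.0 / sigma2) > 1e-3
    assert np.allclose(analytic[mask], numeric[mask], atol=1e-4)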