提交 8ef6280c 编写于 作者: D dengkaipeng

Add operator double support. test=develop

上级 f115eb0d
......@@ -215,9 +215,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
ops::Yolov3LossGradMaker);
REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
REGISTER_OP_CPU_KERNEL(
yolov3_loss,
ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
yolov3_loss_grad,
ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
ops::Yolov3LossKernel<double>);
REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
ops::Yolov3LossGradKernel<double>);
......@@ -323,7 +323,7 @@ static void AddAllGradToInputGrad(
}
}
template <typename DeviceContext, typename T>
template <typename T>
class Yolov3LossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -392,7 +392,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T>
class Yolov3LossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -195,7 +195,7 @@ class TestYolov3LossOp(OpTest):
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set("GTBox"),
no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.06)
def initTestCase(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册