未验证 提交 ca7f1cd2 编写于 作者: C Chen Weihang 提交者: GitHub

fix test_scale_op skipped test (#37153)

上级 d7d22640
...@@ -73,13 +73,19 @@ class ScaleOp : public framework::OperatorWithKernel { ...@@ -73,13 +73,19 @@ class ScaleOp : public framework::OperatorWithKernel {
framework::KernelSignature GetExpectedPtenKernelArgs( framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
if (ctx.HasInput("ScaleTensor")) { if (ctx.InputVar("X")->IsType<framework::LoDTensor>() ||
return framework::KernelSignature("scale.host", {"X", "ScaleTensor"}, ctx.InputVar("X")->IsType<framework::Tensor>()) {
{"bias", "bias_after_scale"}, {"Out"}); if (ctx.HasInput("ScaleTensor")) {
} else { return framework::KernelSignature("scale.host", {"X", "ScaleTensor"},
return framework::KernelSignature( {"bias", "bias_after_scale"},
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); {"Out"});
} else {
return framework::KernelSignature(
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
}
} }
// TODO(chenweihang): support other cases after selected rows added
return framework::KernelSignature("scale.unregistered", {}, {}, {});
} }
}; };
......
...@@ -109,9 +109,7 @@ class TestScaleOpSelectedRows(unittest.TestCase): ...@@ -109,9 +109,7 @@ class TestScaleOpSelectedRows(unittest.TestCase):
assert (in_array * scale == result_array).all() assert (in_array * scale == result_array).all()
assert in_height == out_height assert in_height == out_height
# TODO(chenweihang): output rows and height cannot be shared into assert in_rows == out_rows
# fluid output tensor
# assert in_rows == out_rows
def test_scale_selected_rows(self): def test_scale_selected_rows(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册