未验证 提交 ca7f1cd2 编写于 作者: C Chen Weihang 提交者: GitHub

fix test_scale_op skipped test (#37153)

上级 d7d22640
......@@ -73,13 +73,19 @@ class ScaleOp : public framework::OperatorWithKernel {
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.HasInput("ScaleTensor")) {
return framework::KernelSignature("scale.host", {"X", "ScaleTensor"},
{"bias", "bias_after_scale"}, {"Out"});
} else {
return framework::KernelSignature(
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
if (ctx.InputVar("X")->IsType<framework::LoDTensor>() ||
ctx.InputVar("X")->IsType<framework::Tensor>()) {
if (ctx.HasInput("ScaleTensor")) {
return framework::KernelSignature("scale.host", {"X", "ScaleTensor"},
{"bias", "bias_after_scale"},
{"Out"});
} else {
return framework::KernelSignature(
"scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
}
}
// TODO(chenweihang): support other cases after selected rows added
return framework::KernelSignature("scale.unregistered", {}, {}, {});
}
};
......
......@@ -109,9 +109,7 @@ class TestScaleOpSelectedRows(unittest.TestCase):
assert (in_array * scale == result_array).all()
assert in_height == out_height
# TODO(chenweihang): output rows and height cannot be shared into
# fluid output tensor
# assert in_rows == out_rows
assert in_rows == out_rows
def test_scale_selected_rows(self):
places = [core.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册