diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 038fcfcfee490508f9aa866a2f9819b184f1fbaf..c24f924313fb90d33b17f727260578271f67ae88 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -73,13 +73,19 @@ class ScaleOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { - if (ctx.HasInput("ScaleTensor")) { - return framework::KernelSignature("scale.host", {"X", "ScaleTensor"}, - {"bias", "bias_after_scale"}, {"Out"}); - } else { - return framework::KernelSignature( - "scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); + if (ctx.InputVar("X")->IsType() || + ctx.InputVar("X")->IsType()) { + if (ctx.HasInput("ScaleTensor")) { + return framework::KernelSignature("scale.host", {"X", "ScaleTensor"}, + {"bias", "bias_after_scale"}, + {"Out"}); + } else { + return framework::KernelSignature( + "scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); + } } + // TODO(chenweihang): support other cases after selected rows added + return framework::KernelSignature("scale.unregistered", {}, {}, {}); } }; diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py index baedc2b095914ecc1c034c99e89e93d230aa981b..c1ce032f506127e495dfd3231471fdabe6dfa26b 100644 --- a/python/paddle/fluid/tests/unittests/test_scale_op.py +++ b/python/paddle/fluid/tests/unittests/test_scale_op.py @@ -109,9 +109,7 @@ class TestScaleOpSelectedRows(unittest.TestCase): assert (in_array * scale == result_array).all() assert in_height == out_height - # TODO(chenweihang): output rows and height cannot be shared into - # fluid output tensor - # assert in_rows == out_rows + assert in_rows == out_rows def test_scale_selected_rows(self): places = [core.CPUPlace()]