Unverified commit 91bd5835 authored by Q qingqing01, committed by GitHub

Fix fill_constant_batch_size_like_op when input is LoDTensor. (#10943)

Parent bf869e45
......@@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* in = ctx.Input<framework::LoDTensor>("Input");
    if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
      // set the correct batch size for the LoDTensor.
      auto odims = out->dims();
      int output_dim_idx = ctx.Attr<int>("output_dim_idx");
      odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
      out->mutable_data<T>(odims, ctx.GetPlace());
    }
    out->mutable_data<T>(ctx.GetPlace());
    auto value = ctx.Attr<float>("value");
......
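For reference, the batch dimension the kernel derives from a LoDTensor is the number of sequences in the last LoD level, i.e. one less than the number of offsets in `in->lod().back()`. A minimal Python sketch of that arithmetic (the helper name is hypothetical, not part of the operator):

```python
def batch_size_from_lod(lod_offsets):
    """Number of sequences encoded by an offset-style LoD level.

    Offsets mark sequence boundaries in the underlying rows, so N + 1
    offsets describe N sequences; mirrors `in->lod().back().size() - 1`
    in the kernel above.
    """
    return len(lod_offsets) - 1


# The test below feeds a (31, 28) input with LoD [[0, 9, 23, 31]],
# so the filled output gets shape (3, 16) for attr shape=[-1, 16].
assert batch_size_from_lod([0, 9, 23, 31]) == 3
```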
......@@ -50,5 +50,27 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest):
        self.check_output()


class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest):
    def setUp(self):
        self.op_type = "fill_constant_batch_size_like"
        self.inputs = {
            'Input': (np.random.random((31, 28)).astype("float32"),
                      [[0, 9, 23, 31]])
        }
        self.attrs = {
            'value': 3.5,
            'shape': [-1, 16],
            'input_dim_idx': 0,
            'output_dim_idx': 0
        }

        out = np.random.random((3, 16)).astype("float32")
        out.fill(3.5)
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
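At the Python layer this means the filled output now follows the number of sequences when its input carries LoD information. A hedged usage sketch, assuming the `paddle.fluid` layers API contemporary with this commit (names and defaults may differ in later releases):

```python
import paddle.fluid as fluid

# A 1-level LoD input: rows are grouped into variable-length sequences,
# so the effective batch size is the number of sequences, not of rows.
x = fluid.layers.data(name='x', shape=[28], dtype='float32', lod_level=1)

# With this fix, dim 0 of the filled output tracks the sequence count
# taken from x's LoD instead of x's dim 0.
out = fluid.layers.fill_constant_batch_size_like(
    input=x, shape=[-1, 16], dtype='float32', value=3.5)
```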