From f014e301197cc7bd9e101cd3b478da63f92557f3 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Fri, 3 Sep 2021 16:57:15 +0800 Subject: [PATCH] [NPU] add int64_t kernels for YoloV3, test=develop (#35045) * [NPU] add int64 kernels, test=develop * update ci scripts to be able to trun WITH_ASCEND_INT64 on, test=develop --- CMakeLists.txt | 1 + cmake/configure.cmake | 4 + .../fluid/operators/fill_constant_op_npu.cc | 39 ++++---- paddle/fluid/operators/increment_op_npu.cc | 29 ++---- paddle/fluid/operators/meshgrid_op_npu.cc | 23 +++-- paddle/fluid/operators/range_op_npu.cc | 33 ++----- paddle/fluid/operators/scale_op_npu.cc | 36 ++----- paddle/fluid/operators/stack_op_npu.cc | 25 ++--- paddle/scripts/paddle_build.sh | 2 + .../npu/test_fill_constant_op_npu.py | 26 ++++- .../unittests/npu/test_increment_op_npu.py | 26 +++++ .../unittests/npu/test_meshgrid_op_npu.py | 10 ++ .../tests/unittests/npu/test_range_npu.py | 98 +++++++++++++++++++ .../tests/unittests/npu/test_stack_op_npu.py | 24 +++-- 14 files changed, 258 insertions(+), 118 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/npu/test_range_npu.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f25e7d9dc..219f6fe20b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -224,6 +224,7 @@ option(WITH_UNITY_BUILD "Compile with UnityBuild mode" OFF) option(WITH_STRIP "Strip so files of Whl packages" OFF) option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF) option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF) +option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF) # PY_VERSION if(NOT PY_VERSION) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 3a7f269eaa..7f737cc189 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -90,6 +90,10 @@ if(WITH_ASCEND_CL) add_definitions(-DPADDLE_WITH_ASCEND_CL) endif() +if(WITH_ASCEND_INT64) + add_definitions(-DPADDLE_WITH_ASCEND_INT64) +endif() + if(WITH_XPU) message(STATUS "Compile with XPU!") add_definitions(-DPADDLE_WITH_XPU) diff --git a/paddle/fluid/operators/fill_constant_op_npu.cc b/paddle/fluid/operators/fill_constant_op_npu.cc index 2626e6d960..ae0148a9bf 100644 --- a/paddle/fluid/operators/fill_constant_op_npu.cc +++ b/paddle/fluid/operators/fill_constant_op_npu.cc @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include - #include "paddle/fluid/operators/fill_constant_op.h" #include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/utils.h" @@ -22,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class FillConstantNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -32,7 +29,6 @@ class FillConstantNPUKernel : public framework::OpKernel { auto float_value = ctx.Attr("value"); auto* out_var = ctx.Output("Out"); - auto place = ctx.GetPlace(); auto stream = ctx.template device_context() .stream(); @@ -63,25 +59,28 @@ class FillConstantNPUKernel : public framework::OpKernel { } auto shape = GetShape(ctx); - Tensor tensor_tmp(data_type); - tensor_tmp.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&tensor_tmp, value); + Tensor tensor_value(data_type); + tensor_value.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&tensor_value, value); + + out_var->mutable_data(shape, ctx.GetPlace()); - out_var->mutable_data(shape, place); - const auto& runner = NpuOpRunner("FillD", {tensor_tmp}, {*out_var}, - {{"dims", framework::vectorize(shape)}}); - runner.Run(stream); + NpuOpRunner runner; + runner.SetType("Fill") + .AddInput(framework::vectorize(shape)) + .AddInput(tensor_value) + .AddOutput(*out_var) + .Run(stream); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; - REGISTER_OP_NPU_KERNEL( - fill_constant, - ops::FillConstantNPUKernel, - ops::FillConstantNPUKernel, - ops::FillConstantNPUKernel, - ops::FillConstantNPUKernel); + fill_constant, paddle::operators::FillConstantNPUKernel, + paddle::operators::FillConstantNPUKernel, + paddle::operators::FillConstantNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + paddle::operators::FillConstantNPUKernel, +#endif + paddle::operators::FillConstantNPUKernel); diff --git a/paddle/fluid/operators/increment_op_npu.cc b/paddle/fluid/operators/increment_op_npu.cc index cdd82b55b7..85883325da 100644 --- a/paddle/fluid/operators/increment_op_npu.cc +++ b/paddle/fluid/operators/increment_op_npu.cc @@ -14,22 +14,11 @@ #include "paddle/fluid/operators/increment_op.h" #include "paddle/fluid/operators/npu_op_runner.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace framework { -class OpDesc; -class Variable; -} // namespace framework -namespace imperative { -class OpBase; -} // namespace imperative -} // namespace paddle namespace paddle { namespace operators { -template +template class IncrementalNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -56,13 +45,11 @@ class IncrementalNPUKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -namespace plat = paddle::platform; -namespace ops = paddle::operators; - REGISTER_OP_NPU_KERNEL( - increment, - ops::IncrementalNPUKernel, - ops::IncrementalNPUKernel, - ops::IncrementalNPUKernel, - ops::IncrementalNPUKernel) + increment, paddle::operators::IncrementalNPUKernel, + paddle::operators::IncrementalNPUKernel, + paddle::operators::IncrementalNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + paddle::operators::IncrementalNPUKernel, +#endif + paddle::operators::IncrementalNPUKernel) diff --git a/paddle/fluid/operators/meshgrid_op_npu.cc b/paddle/fluid/operators/meshgrid_op_npu.cc index a72c611a65..9605fa092f 100644 --- a/paddle/fluid/operators/meshgrid_op_npu.cc +++ b/paddle/fluid/operators/meshgrid_op_npu.cc @@ -18,7 +18,7 @@ limitations under the Licnse. */ namespace paddle { namespace operators { -template +template class MeshgridNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -65,9 +65,12 @@ class MeshgridNPUKernel : public framework::OpKernel { auto stream = context.template device_context() .stream(); - const auto& runner = NpuOpRunner("BroadcastToD", {reshape_ins_tensor}, - {*(outs[i])}, {{"shape", shape}}); - runner.Run(stream); + NpuOpRunner runner; + runner.SetType("BroadcastTo") + .AddInput(reshape_ins_tensor) + .AddInput(std::move(shape)) + .AddOutput(*(outs[i])) + .Run(stream); } } }; @@ -75,10 +78,10 @@ class MeshgridNPUKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -namespace ops = paddle::operators; -namespace plat = paddle::platform; - REGISTER_OP_NPU_KERNEL( - meshgrid, ops::MeshgridNPUKernel, - ops::MeshgridNPUKernel, - ops::MeshgridNPUKernel); + meshgrid, paddle::operators::MeshgridNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + paddle::operators::MeshgridNPUKernel, +#endif + paddle::operators::MeshgridNPUKernel, + paddle::operators::MeshgridNPUKernel); diff --git a/paddle/fluid/operators/range_op_npu.cc b/paddle/fluid/operators/range_op_npu.cc index a9a2effd2e..8379d45459 100644 --- a/paddle/fluid/operators/range_op_npu.cc +++ b/paddle/fluid/operators/range_op_npu.cc @@ -12,24 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_ASCEND_CL -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/dropout_op.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/range_op.h" -#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/operators/npu_op_runner.h" namespace paddle { namespace operators { -template +template class RangeNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { @@ -41,19 +30,19 @@ class RangeNPUKernel : public framework::OpKernel { framework::Tensor n; framework::TensorCopy( *start_t, platform::CPUPlace(), - context.template device_context(), &n); + context.template device_context(), &n); context.template device_context() .Wait(); T start = n.data()[0]; framework::TensorCopy( *end_t, platform::CPUPlace(), - context.template device_context(), &n); + context.template device_context(), &n); context.template device_context() .Wait(); T end = n.data()[0]; framework::TensorCopy( *step_t, platform::CPUPlace(), - context.template device_context(), &n); + context.template device_context(), &n); context.template device_context() .Wait(); T step = n.data()[0]; @@ -78,11 +67,9 @@ class RangeNPUKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -namespace ops = paddle::operators; - -REGISTER_OP_NPU_KERNEL( - range, ops::RangeNPUKernel, - ops::RangeNPUKernel, - ops::RangeNPUKernel) - +REGISTER_OP_NPU_KERNEL(range, paddle::operators::RangeNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + paddle::operators::RangeNPUKernel, #endif + paddle::operators::RangeNPUKernel, + paddle::operators::RangeNPUKernel) diff --git a/paddle/fluid/operators/scale_op_npu.cc b/paddle/fluid/operators/scale_op_npu.cc index 3892fcb758..2381719020 100644 --- a/paddle/fluid/operators/scale_op_npu.cc +++ b/paddle/fluid/operators/scale_op_npu.cc @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class ScaleNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -40,37 +40,21 @@ class ScaleNPUKernel : public framework::OpKernel { auto* scale_tensor = ctx.Input("ScaleTensor"); scale = static_cast(GetAttrFromTensor(scale_tensor)); } - if (bias_after_scale) { - out->mutable_data(ctx.GetPlace()); - const auto& runner = - NpuOpRunner("Power", {*x}, {*out}, - {{"power", power}, {"scale", scale}, {"shift", bias}}); - runner.Run(stream); - } else { - Tensor tmp_x(x->type()); - tmp_x.Resize(x->dims()); - tmp_x.mutable_data(ctx.GetPlace()); - const auto& runner_tmp = - NpuOpRunner("Adds", {*x}, {tmp_x}, {{"value", bias}}); - runner_tmp.Run(stream); - - out->mutable_data(ctx.GetPlace()); - float _bias = 0.0; - const auto& runner = - NpuOpRunner("Power", {tmp_x}, {*out}, - {{"power", power}, {"scale", scale}, {"shift", _bias}}); - runner.Run(stream); + if (!bias_after_scale) { + bias *= scale; } + out->mutable_data(ctx.GetPlace()); + const auto& runner = + NpuOpRunner("Power", {*x}, {*out}, + {{"power", power}, {"scale", scale}, {"shift", bias}}); + runner.Run(stream); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; - REGISTER_OP_NPU_KERNEL( - scale, ops::ScaleNPUKernel, - ops::ScaleNPUKernel); + scale, paddle::operators::ScaleNPUKernel, + paddle::operators::ScaleNPUKernel); diff --git a/paddle/fluid/operators/stack_op_npu.cc b/paddle/fluid/operators/stack_op_npu.cc index 3b685b3ab8..c6c54b7d81 100644 --- a/paddle/fluid/operators/stack_op_npu.cc +++ b/paddle/fluid/operators/stack_op_npu.cc @@ -20,7 +20,7 @@ namespace operators { using Tensor = framework::Tensor; -template +template class StackNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -49,7 +49,7 @@ class StackNPUKernel : public framework::OpKernel { } }; -template +template class StackGradNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -81,15 +81,18 @@ class StackGradNPUKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -namespace ops = paddle::operators; - REGISTER_OP_NPU_KERNEL( - stack, ops::StackNPUKernel, - ops::StackNPUKernel); + stack, paddle::operators::StackNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + paddle::operators::StackNPUKernel, +#endif + paddle::operators::StackNPUKernel, + paddle::operators::StackNPUKernel); REGISTER_OP_NPU_KERNEL( - stack_grad, - ops::StackGradNPUKernel, - ops::StackGradNPUKernel); + stack_grad, paddle::operators::StackNPUKernel, +#ifdef PADDLE_WITH_ASCEND_INT64 + paddle::operators::StackNPUKernel, +#endif + paddle::operators::StackGradNPUKernel, + paddle::operators::StackGradNPUKernel); diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d9c9ff4dec..729cc799b8 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -228,6 +228,7 @@ function cmake_base() { -DWITH_ARM=${WITH_ARM:-OFF} -DWITH_ASCEND=${WITH_ASCEND:-OFF} -DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF} + -DWITH_ASCEND_INT64=${WITH_ASCEND_INT64:-OFF} -DWITH_STRIP=${WITH_STRIP:-ON} -DON_INFER=${ON_INFER:-OFF} ======================================== @@ -269,6 +270,7 @@ EOF -DWITH_ARM=${WITH_ARM:-OFF} \ -DWITH_ASCEND=${WITH_ASCEND:-OFF} \ -DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF} \ + -DWITH_ASCEND_INT64=${WITH_ASCEND_INT64:-OFF} \ -DWITH_STRIP=${WITH_STRIP:-ON} \ -DON_INFER=${ON_INFER:-OFF} \ -DWITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF};build_error=$? diff --git a/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py index c8d7f2f9dc..2ab1521380 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_fill_constant_op_npu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -72,6 +72,30 @@ class TestFillConstantInt(OpTest): self.check_output_with_place(self.place) +class TestFillConstantInt64(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "fill_constant" + + self.inputs = {} + self.attrs = { + 'shape': [123, 92], + 'value': 1, + 'dtype': core.VarDesc.VarType.INT64 + } + self.outputs = {'Out': np.full((123, 92), 1).astype(self.dtype)} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.int64 + + def test_check_output(self): + self.check_output_with_place(self.place) + + class TestFillConstantFP16(OpTest): def setUp(self): self.set_npu() diff --git a/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py index dfb9b26d64..626dbfc52a 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py @@ -81,6 +81,32 @@ class TestIncrementFP16(OpTest): self.check_output_with_place(self.place) +class TestIncrementINT64(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(NPUPlace) + self.op_type = "increment" + self.init_dtype() + + self.inputs = { + 'X': + OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), + } + self.pre_input_id = id(self.inputs['X']) + + self.attrs = {"Step": 1} + self.outputs = {'Out': np.array([2])} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.int64 + + def test_check_output(self): + self.check_output_with_place(self.place) + + class TestIncrementInplace(unittest.TestCase): def test_npu(self): main_prog = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/npu/test_meshgrid_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_meshgrid_op_npu.py index 216a6418ac..39802602bf 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_meshgrid_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_meshgrid_op_npu.py @@ -75,6 +75,16 @@ class TestMeshgridOpFP16(TestMeshgridOp): return "float16" +class TestMeshgridOpINT32(TestMeshgridOp): + def get_dtype(self): + return "int32" + + +class TestMeshgridOpINT64(TestMeshgridOp): + def get_dtype(self): + return "int64" + + class TestMeshgridOp2(TestMeshgridOp): def get_x_shape(self): return [100, 300] diff --git a/python/paddle/fluid/tests/unittests/npu/test_range_npu.py b/python/paddle/fluid/tests/unittests/npu/test_range_npu.py new file mode 100644 index 0000000000..c6700a19c5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_range_npu.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle + +paddle.enable_static() + + +class TestRangeOp(OpTest): + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def setUp(self): + self.set_npu() + self.op_type = "range" + self.init_config() + self.inputs = { + 'Start': np.array([self.case[0]]).astype(self.dtype), + 'End': np.array([self.case[1]]).astype(self.dtype), + 'Step': np.array([self.case[2]]).astype(self.dtype) + } + + self.outputs = { + 'Out': np.arange(self.case[0], self.case[1], + self.case[2]).astype(self.dtype) + } + + def init_config(self): + self.dtype = np.float32 + self.case = (0, 1, 0.2) + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestFloatRangeOpCase0(TestRangeOp): + def init_config(self): + self.dtype = np.float32 + self.case = (0, 5, 1) + + +class TestInt32RangeOpCase0(TestRangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (0, 5, 2) + + +class TestInt32RangeOpCase1(TestRangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (10, 1, -2) + + +class TestInt32RangeOpCase2(TestRangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (-1, -10, -2) + + +class TestInt64RangeOpCase0(TestRangeOp): + def init_config(self): + self.dtype = np.int64 + self.case = (0, 5, 2) + + +class TestInt64RangeOpCase1(TestRangeOp): + def init_config(self): + self.dtype = np.int64 + self.case = (10, 1, -2) + + +class TestInt64RangeOpCase2(TestRangeOp): + def init_config(self): + self.dtype = np.int64 + self.case = (-1, -10, -2) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py index bdfc7a03c6..af5648f8f3 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_stack_op_npu.py @@ -72,6 +72,8 @@ class TestStackOpBase(OpTest): self.check_output_with_place(self.place) def test_check_grad(self): + if self.dtype == np.int32 or self.dtype == np.int64: + return self.check_grad_with_place(self.place, self.get_x_names(), 'Y') @@ -105,6 +107,16 @@ class TestStackOp6(TestStackOpBase): self.axis = 3 +class TestStackOpINT32(TestStackOpBase): + def init_dtype(self): + self.dtype = np.int32 + + +class TestStackOpINT64(TestStackOpBase): + def init_dtype(self): + self.dtype = np.int64 + + class TestStackAPIWithLoDTensorArray(unittest.TestCase): """ Test stack api when the input(x) is a LoDTensorArray. @@ -180,15 +192,15 @@ class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase): class API_test(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program(), fluid.Program()): - data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float64') - data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float64') - data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float64') + data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float32') + data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float32') + data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float32') result_stack = paddle.stack([data1, data2, data3], axis=0) place = paddle.NPUPlace(0) exe = fluid.Executor(place) - input1 = np.random.random([1, 2]).astype('float64') - input2 = np.random.random([1, 2]).astype('float64') - input3 = np.random.random([1, 2]).astype('float64') + input1 = np.random.random([1, 2]).astype('float32') + input2 = np.random.random([1, 2]).astype('float32') + input3 = np.random.random([1, 2]).astype('float32') result, = exe.run( feed={"data1": input1, "data2": input2, -- GitLab