// Patch summary (three files under src/serialization):
//
// 1. impl/extern_c_opr.cpp, hunk in ExternCOprRunner::scn_do_execute():
//    removes the function-local declaration `SmallVector cpu_inp, cpu_out;`.
//
// 2. include/megbrain/serialization/extern_c_opr_io.h, hunk right after the
//    `m_param` member: adds the declaration `SmallVector cpu_inp, cpu_out;`
//    at class scope, commented as the HostTensorND holder for
//    scn_do_execute.
//    NOTE(review): at class scope the two vectors keep their contents between
//    scn_do_execute() calls unless that function resets them; its body is not
//    part of this patch -- confirm.
//    NOTE(review): the neighbouring member is named `m_param`; the new
//    members carry no `m_` prefix -- confirm this is intended.
//
// 3. test/extern_c_opr.cpp, hunk after TEST(TestExternCOpr, GPUCompute):
//    adds, guarded by `#if MGB_OPENCL`, an include of "megcore_opencl.h", a
//    REQUIRE_OPENCL() macro that returns from the enclosing function when
//    CompNode::get_device_count(CompNode::DeviceType::OPENCL) is zero, and
//    TEST(TestExternCOpr, OPENCLCompute), which calls
//    run_compute_test(CompNode::load("openclx"), MGB_DTYPE_FLOAT32).
//    With no OpenCL device the test body therefore ends before
//    run_compute_test is reached.
//
// NOTE(review): template argument lists (e.g. on SmallVector and
// std::shared_ptr) appear to be missing and line breaks collapsed in this
// rendering of the patch -- confirm against the repository before applying.
diff --git a/src/serialization/impl/extern_c_opr.cpp b/src/serialization/impl/extern_c_opr.cpp index 2a2da69e6182ffbfee080b8e9cfea725bb2ab410..e9b49597c6f2a0759538011c4a87de8fadc332bd 100644 --- a/src/serialization/impl/extern_c_opr.cpp +++ b/src/serialization/impl/extern_c_opr.cpp @@ -381,7 +381,6 @@ void ExternCOprRunner::check_param() { void ExternCOprRunner::scn_do_execute() { SmallVector c_inp(input().size()), c_out(output().size()); - SmallVector cpu_inp, cpu_out; check_param(); bool need_copy = false; diff --git a/src/serialization/include/megbrain/serialization/extern_c_opr_io.h b/src/serialization/include/megbrain/serialization/extern_c_opr_io.h index 9d4d4e2f50fe7aea4937a514c9534e50eb698592..f95729dd3244d4d041def7aaeaf3d45e6ddd633b 100644 --- a/src/serialization/include/megbrain/serialization/extern_c_opr_io.h +++ b/src/serialization/include/megbrain/serialization/extern_c_opr_io.h @@ -16,6 +16,9 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT( //! store dynamic store param std::shared_ptr m_param; + //! 
HostTensorND holder for scn_do_execute + SmallVector cpu_inp, cpu_out; + void get_output_var_shape( const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const override; diff --git a/src/serialization/test/extern_c_opr.cpp b/src/serialization/test/extern_c_opr.cpp index a5be9582756aae38c67e2f8940881252ddc34635..191788e00d482a90e21fda599c1b8def82817212 100644 --- a/src/serialization/test/extern_c_opr.cpp +++ b/src/serialization/test/extern_c_opr.cpp @@ -445,6 +445,22 @@ TEST(TestExternCOpr, GPUCompute) { run_compute_test(CompNode::load("gpux"), MGB_DTYPE_FLOAT32); } +#if MGB_OPENCL +#include "megcore_opencl.h" + +#define REQUIRE_OPENCL() \ + do { \ + if (!CompNode::get_device_count(CompNode::DeviceType::OPENCL)) { \ + return; \ + } \ + } while (0) + +TEST(TestExternCOpr, OPENCLCompute) { + REQUIRE_OPENCL(); + run_compute_test(CompNode::load("openclx"), MGB_DTYPE_FLOAT32); +} +#endif + TEST(TestExternCOpr, CPUComputeMultiDtype) { run_compute_test(CompNode::load("cpux"), MGB_DTYPE_INT32); #if !MEGDNN_DISABLE_FLOAT16