未验证 提交 68c3e2cb 编写于 作者: Z Zhen Wang 提交者: GitHub

Update the batch size used in test_resnet50_with_cinn.py. (#37013)

* Update the batch size used in test_resnet50_with_cinn.py.
* Enable more debug info.
上级 1653f99f
...@@ -27,7 +27,7 @@ add_definitions(-w) ...@@ -27,7 +27,7 @@ add_definitions(-w)
include(ExternalProject) include(ExternalProject)
set(CINN_SOURCE_DIR ${THIRD_PARTY_PATH}/CINN) set(CINN_SOURCE_DIR ${THIRD_PARTY_PATH}/CINN)
# TODO(zhhsplendid): Modify git tag after we have release tag # TODO(zhhsplendid): Modify git tag after we have release tag
set(CINN_GIT_TAG 2122413fc74f4020ff4397b54488a793529d581b) set(CINN_GIT_TAG develop)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} -DPUBLISH_LIBS=ON -DWITH_TESTING=ON)
set(CINN_BUILD_COMMAND $(MAKE) cinnapi -j) set(CINN_BUILD_COMMAND $(MAKE) cinnapi -j)
ExternalProject_Add( ExternalProject_Add(
......
...@@ -67,7 +67,7 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -67,7 +67,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph, const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) { const Target& target) {
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph); VLOG(1) << "-- The graph to be compiled is:\n" << VizGraph(graph);
CinnCacheKey cur_key(graph, input_tensors, target.arch_str()); CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
bool exist = false; bool exist = false;
{ {
...@@ -195,7 +195,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph( ...@@ -195,7 +195,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
frontend_program, target); frontend_program, target);
VLOG(4) << "-- The " << compiled_num << "-th compilation (" VLOG(1) << "-- The " << compiled_num << "-th compilation ("
<< target.arch_str() << "), and its related graph:\n" << target.arch_str() << "), and its related graph:\n"
<< cinn_graph->Visualize(); << cinn_graph->Visualize();
ApplyPass(cinn_graph.get(), "OpFusion"); ApplyPass(cinn_graph.get(), "OpFusion");
......
...@@ -40,8 +40,8 @@ def set_cinn_flag(val): ...@@ -40,8 +40,8 @@ def set_cinn_flag(val):
class TestResnet50Accuracy(unittest.TestCase): class TestResnet50Accuracy(unittest.TestCase):
def reader(self, limit): def reader(self, limit):
for _ in range(limit): for _ in range(limit):
yield np.random.randint(0, 256, size=[1, 3, 224, 224]).astype('float32'), \ yield np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \
np.random.randint(0, 1000, size=[1]).astype('int64') np.random.randint(0, 1000, size=[32]).astype('int64')
def generate_random_data(self, loop_num=10): def generate_random_data(self, loop_num=10):
feed = [] feed = []
...@@ -54,8 +54,8 @@ class TestResnet50Accuracy(unittest.TestCase): ...@@ -54,8 +54,8 @@ class TestResnet50Accuracy(unittest.TestCase):
def build_program(self, main_program, startup_program): def build_program(self, main_program, startup_program):
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data( image = paddle.static.data(
name='image', shape=[1, 3, 224, 224], dtype='float32') name='image', shape=[32, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[1], dtype='int64') label = paddle.static.data(name='label', shape=[32], dtype='int64')
model = paddle.vision.models.resnet50() model = paddle.vision.models.resnet50()
prediction = model(image) prediction = model(image)
...@@ -80,7 +80,7 @@ class TestResnet50Accuracy(unittest.TestCase): ...@@ -80,7 +80,7 @@ class TestResnet50Accuracy(unittest.TestCase):
loss = self.build_program(main_program, startup_program) loss = self.build_program(main_program, startup_program)
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
parallel_exec = paddle.static.CompiledProgram( compiled_prog = paddle.static.CompiledProgram(
main_program).with_data_parallel(loss_name=loss.name) main_program).with_data_parallel(loss_name=loss.name)
loss_vals = [] loss_vals = []
scope = paddle.static.Scope() scope = paddle.static.Scope()
...@@ -88,7 +88,7 @@ class TestResnet50Accuracy(unittest.TestCase): ...@@ -88,7 +88,7 @@ class TestResnet50Accuracy(unittest.TestCase):
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
exe.run(startup_program) exe.run(startup_program)
for step in range(iters): for step in range(iters):
loss_v = exe.run(parallel_exec, loss_v = exe.run(compiled_prog,
feed=feed[step], feed=feed[step],
fetch_list=[loss], fetch_list=[loss],
return_numpy=True) return_numpy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册