未验证 提交 5690666c 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add use_cinn Flag and RunFromCinn in PE (#36107)

Add use_cinn flag and use it to control whether we run PaddlePaddle using CINN.

Also add:

Replace PaddlePaddle graph with a CINN graph in a pass
PE Method to feed data and run the graph by CINN
上级 9b987b3d
...@@ -351,7 +351,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h ...@@ -351,7 +351,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy bind_threaded_ssa_graph_executor collective_helper graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper cinn_runner)
cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor) cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
if(WITH_PSCORE) if(WITH_PSCORE)
......
...@@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass ...@@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
fix_op_run_order_pass) paddle_to_cinn_pass fix_op_run_order_pass)
if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif() endif()
......
...@@ -19,8 +19,9 @@ limitations under the License. */ ...@@ -19,8 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_printer.h" #include "paddle/fluid/framework/ir/graph_printer.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
DECLARE_bool(use_mkldnn);
DECLARE_bool(convert_all_blocks); DECLARE_bool(convert_all_blocks);
DECLARE_bool(use_cinn);
DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -71,6 +72,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -71,6 +72,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right. // Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass"); AppendPass("multi_devices_check_pass");
// Note: This pass is used to enable cinn.
if (FLAGS_use_cinn) {
AppendPass("paddle_to_cinn_pass");
}
SetCollectiveContext(); SetCollectiveContext();
} }
......
...@@ -59,6 +59,7 @@ cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) ...@@ -59,6 +59,7 @@ cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper) cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(paddle_to_cinn_pass base DEPS cinn_runner)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
...@@ -142,6 +143,7 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) ...@@ -142,6 +143,7 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(paddle_to_cinn_pass_test SRCS paddle_to_cinn_pass_test.cc DEPS paddle_to_cinn_pass proto_desc)
cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry) cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass) cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/paddle_to_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
namespace paddle {
namespace framework {
namespace ir {
void PaddleToCinnPass::ApplyImpl(ir::Graph* graph) const {
paddle2cinn::CinnRunner::GetInstance()->ReplaceWithCinn(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(paddle_to_cinn_pass, paddle::framework::ir::PaddleToCinnPass);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class PaddleToCinnPass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/paddle_to_cinn_pass.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(PaddleToCinnPassTest, TodoTest) {
ProgramDesc program;
Graph graph(program);
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
"paddle_to_cinn_pass");
pass->Apply(&graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(paddle_to_cinn_pass);
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" #include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
#include <map> #include <map>
#include <memory>
#include <mutex>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -26,6 +28,19 @@ namespace paddle2cinn { ...@@ -26,6 +28,19 @@ namespace paddle2cinn {
using ir::Graph; using ir::Graph;
std::once_flag CinnRunner::get_instance_once_flag_;
std::shared_ptr<CinnRunner> CinnRunner::instance_;
std::shared_ptr<CinnRunner> CinnRunner::GetInstance() {
std::call_once(get_instance_once_flag_,
[&]() { instance_.reset(new CinnRunner()); });
return instance_;
}
void CinnRunner::ReplaceWithCinn(Graph* graph) {
// TODO(zhhsplendid): call CINN Api when it is ready
}
std::map<std::string, FetchType*> CinnRunner::Run( std::map<std::string, FetchType*> CinnRunner::Run(
const Graph& graph, Scope* scope, const Graph& graph, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets) { std::map<std::string, const LoDTensor*>* feed_targets) {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -36,15 +37,24 @@ namespace paddle2cinn { ...@@ -36,15 +37,24 @@ namespace paddle2cinn {
// cache. // cache.
class CinnRunner { class CinnRunner {
public: public:
CinnRunner() {}
~CinnRunner() {} ~CinnRunner() {}
// Singleton
static std::shared_ptr<CinnRunner> GetInstance();
// Replace Paddle graph with some CINN subgraphs/ops
void ReplaceWithCinn(ir::Graph* graph);
// Feed LoDTensors to tun CINN compiled object and return fetched result // Feed LoDTensors to tun CINN compiled object and return fetched result
std::map<std::string, FetchType*> Run( std::map<std::string, FetchType*> Run(
const ir::Graph& graph, Scope* scope, const ir::Graph& graph, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets); std::map<std::string, const LoDTensor*>* feed_targets);
private: private:
CinnRunner() {}
static std::once_flag get_instance_once_flag_;
static std::shared_ptr<CinnRunner> instance_;
std::unordered_map<CinnCacheKey, std::shared_ptr<CinnCompiledObject>, std::unordered_map<CinnCacheKey, std::shared_ptr<CinnCompiledObject>,
CinnCacheKey::Hash> CinnCacheKey::Hash>
cache_; cache_;
......
...@@ -12,11 +12,13 @@ ...@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gtest/gtest.h" #include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -32,8 +34,9 @@ TEST(CinnRunnerTest, TodoTest) { ...@@ -32,8 +34,9 @@ TEST(CinnRunnerTest, TodoTest) {
Scope empty_scope; Scope empty_scope;
std::map<std::string, const LoDTensor*> empty_feed; std::map<std::string, const LoDTensor*> empty_feed;
CinnRunner cinn_runner; std::shared_ptr<CinnRunner> cinn_runner = CinnRunner::GetInstance();
cinn_runner.Run(empty_graph, &empty_scope, &empty_feed); cinn_runner->ReplaceWithCinn(&empty_graph);
cinn_runner->Run(empty_graph, &empty_scope, &empty_feed);
} }
} // namespace paddle2cinn } // namespace paddle2cinn
......
...@@ -34,6 +34,7 @@ limitations under the License. */ ...@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/event.h" #include "paddle/fluid/platform/event.h"
...@@ -43,6 +44,7 @@ limitations under the License. */ ...@@ -43,6 +44,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif #endif
DECLARE_bool(use_cinn);
DECLARE_double(eager_delete_tensor_gb); DECLARE_double(eager_delete_tensor_gb);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...@@ -943,6 +945,40 @@ void ParallelExecutor::RunWithoutFetch( ...@@ -943,6 +945,40 @@ void ParallelExecutor::RunWithoutFetch(
member_->executor_->Run(/*fetch_tensors*/ {}, /*return_merged*/ false); member_->executor_->Run(/*fetch_tensors*/ {}, /*return_merged*/ false);
} }
FetchResultType ParallelExecutor::RunFromCinn(
const std::unordered_map<std::string, LoDTensor> &feed_tensors,
const std::vector<std::string> &fetch_names) {
// Feed tensor to scope, now only support 1 scope
// TODO(zhhsplendid): handle multiple scope
size_t scope_id = 0;
std::map<std::string, const LoDTensor *> cinn_input_tensors;
for (auto &name_tensor_pair : feed_tensors) {
bool is_persistable = member_->IsPersistable(name_tensor_pair.first);
if (!is_persistable) {
member_->SetSkipMemoryReuse(scope_id, name_tensor_pair.first);
}
Scope *feed_scope = is_persistable ? member_->local_scopes_[scope_id]
: member_->local_exec_scopes_[scope_id];
Variable *feed_var = feed_scope->Var(name_tensor_pair.first);
LoDTensor *trg = feed_var->GetMutable<LoDTensor>();
trg->ShareDataWith(name_tensor_pair.second);
trg->set_lod(name_tensor_pair.second.lod());
cinn_input_tensors[name_tensor_pair.first] = trg;
}
// TODO(zhhsplendid): get correct API after CINN API is ready
// now only return empty fetch result;
std::shared_ptr<paddle2cinn::CinnRunner> cinn_runner =
paddle2cinn::CinnRunner::GetInstance();
cinn_runner->Run(Graph(), member_->local_exec_scopes_[scope_id],
&cinn_input_tensors);
paddle::framework::FetchResultType fetches = FetchList(fetch_names.size());
return fetches;
}
void ParallelExecutor::SkipMemoryReuse( void ParallelExecutor::SkipMemoryReuse(
size_t scope_idx, const std::vector<std::string> &skip_vars) { size_t scope_idx, const std::vector<std::string> &skip_vars) {
for (auto &var_name : skip_vars) { for (auto &var_name : skip_vars) {
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -92,6 +93,10 @@ class ParallelExecutor { ...@@ -92,6 +93,10 @@ class ParallelExecutor {
void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars); void RunWithoutFetch(const std::vector<std::string> &skip_eager_vars);
FetchResultType RunFromCinn(
const std::unordered_map<std::string, LoDTensor> &feed_tensors,
const std::vector<std::string> &fetch_names);
void ResetOpHandleScopeMapOfGraphs( void ResetOpHandleScopeMapOfGraphs(
const std::unordered_map<Scope *, Scope *> &scope_map); const std::unordered_map<Scope *, Scope *> &scope_map);
......
...@@ -681,6 +681,16 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -681,6 +681,16 @@ PADDLE_DEFINE_EXPORTED_bool(
apply_pass_to_program, false, apply_pass_to_program, false,
"It controls whether to apply IR pass to program when using Fleet APIs"); "It controls whether to apply IR pass to program when using Fleet APIs");
/**
* CINN related FLAG
* Name: FLAGS_use_cinn
* Since Version: 2.3
* Value Range: bool, default=false
* Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN
*/
PADDLE_DEFINE_EXPORTED_bool(
use_cinn, false, "It controls whether to run PaddlePaddle using CINN");
DEFINE_int32(record_pool_max_size, 2000000, DEFINE_int32(record_pool_max_size, 2000000,
"SlotRecordDataset slot record pool max size"); "SlotRecordDataset slot record pool max size");
DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num"); DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num");
......
...@@ -3293,6 +3293,18 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -3293,6 +3293,18 @@ All parameter, weight, gradient are variables in Paddle.
BOOST_GET(paddle::framework::FetchUnmergedList, ret))); BOOST_GET(paddle::framework::FetchUnmergedList, ret)));
} }
}) })
.def("run_from_cinn",
[](ParallelExecutor &self,
const std::unordered_map<std::string, LoDTensor> &feed_tensors,
const std::vector<std::string> &fetch_names) -> py::object {
paddle::framework::FetchResultType ret;
{
pybind11::gil_scoped_release release;
ret = self.RunFromCinn(feed_tensors, fetch_names);
}
return py::cast(
std::move(BOOST_GET(paddle::framework::FetchList, ret)));
})
.def("device_count", &ParallelExecutor::DeviceCount); .def("device_count", &ParallelExecutor::DeviceCount);
BindFleetWrapper(&m); BindFleetWrapper(&m);
......
...@@ -23,7 +23,8 @@ import numpy as np ...@@ -23,7 +23,8 @@ import numpy as np
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
import six import six
from .data_feeder import convert_dtype from .data_feeder import convert_dtype
from .framework import Program, default_main_program, Variable, Operator, convert_np_dtype_to_dtype_ from .framework import Program, default_main_program, Variable, Operator
from .framework import convert_np_dtype_to_dtype_, get_flags
from . import core from . import core
from . import unique_name from . import unique_name
from . import compiler from . import compiler
...@@ -1016,6 +1017,15 @@ class Executor(object): ...@@ -1016,6 +1017,15 @@ class Executor(object):
check_feed_shape_type(var, feed_tensor, exe.device_count()) check_feed_shape_type(var, feed_tensor, exe.device_count())
feed_tensor_dict[feed_name] = feed_tensor feed_tensor_dict[feed_name] = feed_tensor
#TODO(zhhsplendid): handle other feed data format case for CINN
use_cinn = get_flags("FLAGS_use_cinn")["FLAGS_use_cinn"]
if use_cinn:
fetch_var_names = list(map(_to_name_str, fetch_list))
fetch_tensors = exe.run_from_cinn(
feed_tensor_dict, fetch_var_names)._move_to_list()
return as_numpy(
fetch_tensors) if return_numpy else fetch_tensors
else:
exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple): elif isinstance(feed, list) or isinstance(feed, tuple):
res = list() res = list()
...@@ -1036,6 +1046,8 @@ class Executor(object): ...@@ -1036,6 +1046,8 @@ class Executor(object):
check_feed_shape_type(var, tensor) check_feed_shape_type(var, tensor)
res_dict[feed_name] = tensor res_dict[feed_name] = tensor
res.append(res_dict) res.append(res_dict)
use_cinn = get_flags("FLAGS_use_cinn")["FLAGS_use_cinn"]
exe.feed_tensors_into_local_scopes(res) exe.feed_tensors_into_local_scopes(res)
if hasattr(program._program, 'lr_sheduler'): if hasattr(program._program, 'lr_sheduler'):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import unittest
paddle.enable_static()
class TestParallelExecutorRunCinn(unittest.TestCase):
def test_run_from_cinn(self):
paddle.set_flags({'FLAGS_use_cinn': True})
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data = paddle.static.data(
name='X', shape=[None, 1], dtype='float32')
prediction = paddle.static.nn.fc(data, 2)
loss = paddle.mean(prediction)
adam = paddle.optimizer.Adam()
adam.minimize(loss)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
compiled_program = paddle.static.CompiledProgram(
main_program).with_data_parallel(loss_name=loss.name)
batch_size = 16
x = np.random.random(size=(batch_size, 1)).astype('float32')
fetch = exe.run(compiled_program,
feed={'X': x},
fetch_list=[prediction.name],
return_merged=False)
paddle.set_flags({'FLAGS_use_cinn': False})
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册