未验证 提交 568cc2ff 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize while_op for test (#14764)

* Simplify the compare op for CPU.

* Use asynchronous tensor copy in reshape_op's kernel.

* Optimize while_op for test, avoiding creating variables every time.
test=develop

* Enable the cache of kernel type and kernel function.
test=develop

* Enable profiling with gperftools.

* Remove flags for testing, and fix the linking error.
test=develop

* Delete the code of ChooseKernel.
test=develop

* Fix bug when preparing ExecutorPrepareContext for while_op.

* Fix the missing dependency on the grpc libraries.

* Remove the redundant print.
test=develop

* Follow comments.

* Remove the code related to preparing the ExecutorPrepareContext for while_op.
test=develop
上级 3759c1db
...@@ -91,9 +91,9 @@ ...@@ -91,9 +91,9 @@
include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE AND NOT ANDROID) if(NOT APPLE AND NOT ANDROID)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt") set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt")
endif(NOT APPLE AND NOT ANDROID) endif(NOT APPLE AND NOT ANDROID)
set_property(GLOBAL PROPERTY FLUID_MODULES "") set_property(GLOBAL PROPERTY FLUID_MODULES "")
...@@ -304,7 +304,7 @@ function(cc_library TARGET_NAME) ...@@ -304,7 +304,7 @@ function(cc_library TARGET_NAME)
if(cc_library_DEPS) if(cc_library_DEPS)
merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) merge_static_libs(${TARGET_NAME} ${cc_library_DEPS})
else() else()
message(FATAL "Please specify source file or library in cc_library.") message(FATAL_ERROR "Please specify source files or libraries in cc_library(${TARGET_NAME} ...).")
endif() endif()
endif(cc_library_SRCS) endif(cc_library_SRCS)
endfunction(cc_library) endfunction(cc_library)
......
if(WITH_TESTING) if(WITH_TESTING)
include(tests/test.cmake) # some generic cmake funtion for inference include(tests/test.cmake) # some generic cmake funtion for inference
endif() endif()
# analysis and tensorrt must be added before creating static library,
# otherwise, there would be undefined reference to them in static library.
add_subdirectory(analysis)
add_subdirectory(utils)
if (TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor) set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
...@@ -16,6 +9,14 @@ cc_library(paddle_fluid_api ...@@ -16,6 +9,14 @@ cc_library(paddle_fluid_api
SRCS io.cc SRCS io.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS}) DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS})
# analysis and tensorrt must be added before creating static library,
# otherwise, there would be undefined reference to them in static library.
add_subdirectory(analysis)
add_subdirectory(utils)
if (TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
get_property(fluid_third_partys GLOBAL PROPERTY FLUID_THRID_PARTYS) get_property(fluid_third_partys GLOBAL PROPERTY FLUID_THRID_PARTYS)
...@@ -40,10 +41,10 @@ set(SHARED_INFERENCE_SRCS ...@@ -40,10 +41,10 @@ set(SHARED_INFERENCE_SRCS
if(WIN32) if(WIN32)
sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array
analysis_config paddle_pass_builder) analysis_config paddle_pass_builder)
else(WIN32) else(WIN32)
cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS}
analysis_config paddle_pass_builder) zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder)
endif(WIN32) endif(WIN32)
if(NOT APPLE) if(NOT APPLE)
...@@ -55,11 +56,13 @@ endif() ...@@ -55,11 +56,13 @@ endif()
# Create shared library # Create shared library
if(WIN32) if(WIN32)
sep_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} sep_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array analysis_config paddle_pass_builder) DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array
analysis_config paddle_pass_builder)
target_link_libraries(paddle_fluid_shared shlwapi) target_link_libraries(paddle_fluid_shared shlwapi)
else(WIN32) else(WIN32)
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array analysis_config paddle_pass_builder) DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array
analysis_config paddle_pass_builder)
endif() endif()
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
......
...@@ -18,21 +18,22 @@ if(APPLE) ...@@ -18,21 +18,22 @@ if(APPLE)
endif(APPLE) endif(APPLE)
set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor analysis_predictor ${GLOB_PASS_LIB}) set(inference_deps paddle_inference_api paddle_fluid_api analysis pass
ir_pass_manager naive_executor analysis_predictor ${GLOB_PASS_LIB})
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter) set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
endif() endif()
cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) add_subdirectory(details)
cc_library(analysis_config SRCS analysis_config.cc DEPS lod_tensor paddle_pass_builder) cc_library(analysis_config SRCS analysis_config.cc DEPS lod_tensor paddle_pass_builder)
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc) cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager)
cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS scope lod_tensor enforce)
cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS
lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config
analysis_config paddle_pass_builder zero_copy_tensor reset_tensor_array) analysis_config paddle_pass_builder zero_copy_tensor
reset_tensor_array)
cc_test(test_paddle_inference_api cc_test(test_paddle_inference_api
SRCS api_tester.cc SRCS api_tester.cc
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Helper libraries declared in this directory via the project's cc_library().
# Each target is built from the single source file of the same name.
cc_library(reset_tensor_array
  SRCS reset_tensor_array.cc
  DEPS lod_tensor scope)

cc_library(zero_copy_tensor
  SRCS zero_copy_tensor.cc
  DEPS scope lod_tensor enforce)

# The dummy variant declares no dependencies.
cc_library(zero_copy_tensor_dummy
  SRCS zero_copy_tensor_dummy.cc)
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#ifdef WITH_GPERFTOOLS
#include <gperftools/profiler.h>
#endif
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -215,13 +218,19 @@ void TestOneThreadPrediction( ...@@ -215,13 +218,19 @@ void TestOneThreadPrediction(
{ {
Timer run_timer; Timer run_timer;
run_timer.tic(); run_timer.tic();
#ifdef WITH_GPERFTOOLS
ProfilerStart("paddle_inference.prof");
#endif
for (int i = 0; i < num_times; i++) { for (int i = 0; i < num_times; i++) {
for (size_t j = 0; j < inputs.size(); j++) { for (size_t j = 0; j < inputs.size(); j++) {
predictor->Run(inputs[j], outputs, batch_size); predictor->Run(inputs[j], outputs, batch_size);
} }
} }
#ifdef WITH_GPERFTOOLS
ProfilerStop();
#endif
double latency = run_timer.toc() / num_times; double latency = run_timer.toc() / (num_times > 1 ? num_times : 1);
PrintTime(batch_size, num_times, 1, 0, latency, inputs.size()); PrintTime(batch_size, num_times, 1, 0, latency, inputs.size());
if (FLAGS_record_benchmark) { if (FLAGS_record_benchmark) {
Benchmark benchmark; Benchmark benchmark;
......
...@@ -18,6 +18,30 @@ limitations under the License. */ ...@@ -18,6 +18,30 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// CPU specialization of CompareOpKernel.
//
// Reads the input tensors "X" and "Y", applies the comparison Functor, and
// writes boolean results to the output tensor "Out". When both inputs hold
// exactly one element, the functor is invoked directly on that single pair;
// in every other case the work is delegated to ElementwiseComputeEx together
// with the "axis" attribute.
template <typename Functor>
class CompareOpKernel<platform::CPUDeviceContext, Functor>
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;
    using Tensor = framework::Tensor;

    auto* lhs = context.Input<Tensor>("X");
    auto* rhs = context.Input<Tensor>("Y");
    auto* out = context.Output<Tensor>("Out");
    // The attribute is read up front, regardless of which path is taken.
    int axis = context.Attr<int>("axis");

    const bool single_pair = lhs->numel() == 1 && rhs->numel() == 1;
    if (!single_pair) {
      // General case: hand both tensors to the element-wise helper.
      ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
          context, lhs, rhs, axis, Functor(), out);
      return;
    }

    // Single-element case: compare the two values directly.
    bool* out_data = out->mutable_data<bool>(context.GetPlace());
    out_data[0] = Functor()(lhs->data<T>()[0], rhs->data<T>()[0]);
  }
};
template <typename OpComment> template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
...@@ -51,7 +75,7 @@ calculated by $%s$ ...@@ -51,7 +75,7 @@ calculated by $%s$
template <typename OpComment> template <typename OpComment>
class CompareOpInferShape : public framework::InferShapeBase { class CompareOpInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext* context) const override {
OpComment comment; OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X", PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X",
comment.type); comment.type);
...@@ -73,7 +97,7 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -73,7 +97,7 @@ class CompareOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place // CompareOp kernel's device type is decided by input tensor place
bool force_cpu = ctx.Attr<bool>("force_cpu"); bool force_cpu = ctx.Attr<bool>("force_cpu");
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -58,6 +58,7 @@ class WhileOp : public framework::OperatorBase { ...@@ -58,6 +58,7 @@ class WhileOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>(); auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
...@@ -72,18 +73,27 @@ class WhileOp : public framework::OperatorBase { ...@@ -72,18 +73,27 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory."); "Condition of while op must in CPU memory.");
bool is_test = Attr<bool>("is_test");
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars); auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
bool is_test = Attr<bool>("is_test");
auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
while (cond.data<bool>()[0]) {
if (!is_test) {
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true);
}
} else {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); executor.CreateVariables(*program, &current_scope, block->ID());
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true); while (cond.data<bool>()[0]) {
if (is_test) { executor.RunPreparedContext(ctx.get(), &current_scope, false, false,
scope.DeleteScope(&current_scope); false);
} }
scope.DeleteScope(&current_scope);
} }
} }
}; };
......
...@@ -12,6 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto @O ...@@ -12,6 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto @O
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC) if(WITH_GRPC)
set(GRPC_DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(GRPC_SRCS grpc/grpc_client.cc grpc/grpc_server.cc grpc/grpc_serde.cc grpc/grpc_bytebuffer_stream.cc grpc/grpc_variable_response.cc) set(GRPC_SRCS grpc/grpc_client.cc grpc/grpc_server.cc grpc/grpc_serde.cc grpc/grpc_bytebuffer_stream.cc grpc/grpc_variable_response.cc)
grpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc grpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc request_handler_impl.cc rpc_client.cc rpc_server.cc
...@@ -19,10 +20,10 @@ if(WITH_GRPC) ...@@ -19,10 +20,10 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc collective_client.cc collective_server.cc
${GRPC_SRCS} ${GRPC_SRCS}
PROTO ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto PROTO ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto
DEPS lod_tensor selected_rows_functor memory) DEPS lod_tensor selected_rows_functor memory ${GRPC_DEPS})
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
cc_test(grpc_serde_test SRCS grpc/grpc_serde_test.cc cc_test(grpc_serde_test SRCS grpc/grpc_serde_test.cc
DEPS ${RPC_DEPS} scope profiler math_function SERIAL) DEPS ${RPC_DEPS} scope profiler math_function SERIAL)
......
...@@ -226,7 +226,9 @@ class ReshapeKernel { ...@@ -226,7 +226,9 @@ class ReshapeKernel {
} }
out->mutable_data(ctx.GetPlace(), in->type()); out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopySync(*in, ctx.GetPlace(), out); framework::TensorCopy(
*in, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims); out->Resize(out_dims);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册