Commit bdc82956 authored by Yang Yu

Merge branch 'develop' of github.com:baidu/Paddle into feature/make_lod_a_share_ptr

......@@ -63,9 +63,30 @@ ExternalProject_Add(
-DMKLROOT:PATH=${MKLML_ROOT}
)
ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
ADD_LIBRARY(shared_mkldnn SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET shared_mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB})
ADD_DEPENDENCIES(shared_mkldnn ${MKLDNN_PROJECT})
MESSAGE(STATUS "MKLDNN library: ${MKLDNN_LIB}")
add_definitions(-DPADDLE_WITH_MKLDNN)
LIST(APPEND external_project_dependencies mkldnn)
LIST(APPEND external_project_dependencies shared_mkldnn)
# generate a static dummy target to track mkldnn dependencies
# for cc_library(xxx SRCS xxx.c DEPS mkldnn)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/mkldnn_dummy.c)
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
ADD_LIBRARY(mkldnn STATIC ${dummyfile})
TARGET_LINK_LIBRARIES(mkldnn ${MKLDNN_LIB} ${MKLML_LIB} ${MKLML_IOMP_LIB})
ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT})
# copy the real so.0 lib to install dir
# it can be directly contained in wheel or capi
SET(MKLDNN_SHARED_LIB ${MKLDNN_INSTALL_DIR}/libmkldnn.so.0)
ADD_CUSTOM_COMMAND(OUTPUT ${MKLDNN_SHARED_LIB}
COMMAND cp ${MKLDNN_LIB} ${MKLDNN_SHARED_LIB}
DEPENDS mkldnn)
ADD_CUSTOM_TARGET(mkldnn_shared_lib ALL DEPENDS ${MKLDNN_SHARED_LIB})
IF(WITH_C_API)
INSTALL(FILES ${MKLDNN_SHARED_LIB} DESTINATION lib)
ENDIF()
......@@ -66,3 +66,7 @@ ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB})
ADD_DEPENDENCIES(mklml ${MKLML_PROJECT})
LIST(APPEND external_project_dependencies mklml)
IF(WITH_C_API)
INSTALL(FILES ${MKLML_LIB} ${MKLML_IOMP_LIB} DESTINATION lib)
ENDIF()
......@@ -7,11 +7,9 @@ Every new PaddlePaddle release follows this process:
1. Create a new branch from the `develop` branch, named `release/<version>`, e.g. `release/0.10.0`.
1. Tag the version on the new branch as `<version>rc.<patch number>`; the first tag is `0.10.0rc1`, the second `0.10.0rc2`, and so on.
1. For this release, perform the following steps:
  * Use the Regression Test List as a checklist to verify the correctness of this release.
    * If it fails, record all failing cases, fix all bugs on this `release/<version>` branch, increase the patch number by one, and go back to step two.
  * Update the version information in `python/setup.py.in` and set the `istaged` field to `True`.
  * Build the Docker release image for this version and publish it to DockerHub. If it fails, fix the Docker image build issue, increase the patch number by one, and go back to step two.
  * Build the Ubuntu deb package for this version. If it fails, fix the deb packaging issue, increase the patch number by one, and go back to step two.
  * Use the Regression Test List as a checklist to verify the functional correctness of the Docker image and the Ubuntu package.
    * If it fails, record all failing cases, fix all bugs on this `release/<version>` branch, increase the patch number by one, and go back to step two.
  * Build the Python wheel package for this version and publish it to PyPI.
    * Because pypi.python.org currently enforces the [strict naming convention PEP 513](https://www.python.org/dev/peps/pep-0513), the platform-related suffix of the wheel file must be renamed before uploading with twine, e.g. change `linux_x86_64` to `manylinux1_x86_64`.
    * The package names on PyPI are paddlepaddle and paddlepaddle_gpu. To upload the GPU package, change the name in build/python/setup.py to "paddlepaddle_gpu" and rebuild the wheel: `python setup.py bdist_wheel`.
......@@ -21,8 +19,8 @@ Every new PaddlePaddle release follows this process:
pip install twine
twine upload dist/[package to upload]
```
  * Build the Docker release image for this version and publish it to DockerHub. If it fails, fix the Docker image build issue, increase the patch number by one, and go back to step two.
1. After step three is complete, merge the `release/<version>` branch into the master branch. Tag the merge commit on master with `<version>`. Also merge the `master` branch back into the `develop` branch. Finally, delete the `release/<version>` branch.
1. Build the Docker release image from the master branch and publish it to DockerHub. Build the Ubuntu deb package and publish it on the GitHub release page.
1. Collaborate on writing the Release Notes.
......@@ -31,6 +29,30 @@ Every new PaddlePaddle release follows this process:
* Once a `release/<version>` branch is created, merging from `develop` into `release/<version>` is generally no longer allowed. This keeps the functionality of the `release/<version>` branch frozen and makes it easier for testers to verify PaddlePaddle's behavior.
* While a `release/<version>` branch exists, any bugfix branch must be merged into all three branches: `master`, `develop`, and `release/<version>`.
## Publishing wheel packages to PyPI

Use [PaddlePaddle CI](https://paddleci.ngrok.io/project.html?projectId=Manylinux1&tab=projectOverview)
to run the automated binary builds. As shown in the figure below, select the versions to release (usually one CPU build and one GPU build), click the "..." button to the right of "run" to open the dialog shown below, pick the branch to release in the second tab (Changes), in this case 0.11.0, and click "Run Build". After the build finishes, the three generated binaries can be found in the "Artifacts" drop-down on the same page, corresponding to the C-API, `cp27m`, and `cp27mu` builds. Then upload them with the `twine` tool as described above.
<img src="ci_build_whl.png">
* Note: the CI environment uses the Docker images from https://github.com/PaddlePaddle/buildtools as the build environment in order to support more Linux distributions; these images can also be used for manual builds, and can be pulled from https://hub.docker.com/r/paddlepaddle/paddle_manylinux_devel/tags/.
* PyPI does not allow overwriting an upload, so once a wheel for a given version number has been published it cannot be changed; the next wheel must carry a new version number before it can be uploaded.
## Publishing Docker images

After the PaddlePaddle CI wheel builds above finish, the Docker images are pushed to DockerHub automatically, so publishing a Docker image only requires tagging the automatically pushed image with the corresponding version number:

1. Go to https://hub.docker.com/r/paddlepaddle/paddle/tags/ and check that the latest tag was updated after the wheel builds above completed.
1. Run `docker pull paddlepaddle/paddle:[latest tag]`, where the latest tag can be latest, latest-gpu, etc.
1. Run `docker tag paddlepaddle/paddle:[latest tag] paddlepaddle/paddle:[version]`
1. Run `docker push paddlepaddle/paddle:[version]`

## PaddlePaddle branching conventions

PaddlePaddle development follows the [git-flow](http://nvie.com/posts/a-successful-git-branching-model/) branching model, with some adjustments to fit GitHub's features.
......
......@@ -32,7 +32,9 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(scope SRCS scope.cc DEPS glog threadpool)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto)
cc_library(device_data_transform SRCS device_data_transform.cc DEPS tensor)
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows device_data_transform)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
......@@ -41,7 +43,7 @@ device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
......@@ -73,9 +75,10 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operator)
cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
nv_test(device_data_transform_test SRCS device_data_transform_test.cu
DEPS operator op_registry init math_function)
......@@ -427,7 +427,8 @@ std::vector<std::unique_ptr<OpDesc>> MakeBlockBackward(
VLOG(5) << "Making backward " << (*it)->Type() << " op";
std::vector<std::unique_ptr<OpDesc>> op_grads;
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while" ||
(*it)->Type() == "parallel_do") {
int step_block_idx = (*it)->GetBlockAttr("sub_block");
BlockDesc* backward_block = CreateStepBlock(program_desc, no_grad_vars,
grad_to_var, step_block_idx);
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#include <functional>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
namespace paddle {
......@@ -25,6 +27,37 @@ DataTransformFnMap& DataTransformFnMap::Instance() {
return data_transform_map;
}
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor) {
Tensor* out = nullptr;
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
out = DeviceTransform(input_tensor, expected_kernel_type.place_);
}
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
return out;
}
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var) {
if (in_var.IsType<LoDTensor>()) {
auto& in_lod_tensor = in_var.Get<LoDTensor>();
auto* tran_lod_tensor = out_var.GetMutable<LoDTensor>();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType<SelectedRows>()) {
auto& in_selected_rows = in_var.Get<SelectedRows>();
auto* trans_selected_rows = out_var.GetMutable<SelectedRows>();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
} else {
PADDLE_THROW("unknown var type");
}
}
auto KernelFP32 = OpKernelType(proto::DataType::FP32, platform::CPUPlace(),
DataLayout::kNHWC, LibraryType::kPlain);
......@@ -37,6 +70,28 @@ auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
DataLayout::kNCHW, LibraryType::kPlain);
// TODO(dzhwinter): Only for testing multiple op kernel.
// Dummy transform function for library_type
// should be removed.
auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
DataLayout::kAnyLayout, LibraryType::kPlain);
auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
DataLayout::kAnyLayout, LibraryType::kCUDNN);
void DummyTrans(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_pair.first.place_,
kernel_pair.second.place_),
"TransDataType Only Support DataType transform on same place!");
auto src = in.Get<Tensor>();
auto* dst = out->GetMutable<Tensor>();
*dst = src;
}
void TransDataType(const platform::DeviceContext* ctx,
const KernelTypePair& kernel_pair, const Variable& in,
Variable* out) {
......@@ -121,6 +176,8 @@ std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
}
REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans);
REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans);
REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW,
std::bind(f::TransDataLayout, NHWC2NCHW,
std::placeholders::_1,
......
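For orientation, the new `DataTransform`/`CopyVariableWithTensor` pair is consumed by `OperatorWithKernel::Run` later in this diff: when an input's kernel type differs from the expected one, the tensor is copied to the expected place and its LoD/SelectedRows metadata is re-attached. A minimal sketch of that call pattern, assuming the framework headers above; the helper name `TransformInputForKernel` is hypothetical and not part of this commit:

```cpp
#include "paddle/framework/data_transform.h"
#include "paddle/framework/lod_tensor.h"

namespace paddle {
namespace framework {

// Hypothetical helper: move in_var's tensor to the place required by
// expected_kernel_key and mirror its metadata (LoD, layout) into out_var.
// Only call this when the two kernel types actually differ; DataTransform
// enforces that a device copy was performed.
void TransformInputForKernel(const OpKernelType& expected_kernel_key,
                             const OpKernelType& kernel_type_for_var,
                             const Variable& in_var, Variable* out_var) {
  const Tensor& tensor_in = in_var.Get<LoDTensor>();
  // Allocates a new Tensor on expected_kernel_key.place_ and copies into it.
  Tensor* transformed =
      DataTransform(expected_kernel_key, kernel_type_for_var, tensor_in);
  // Re-attach LoD/layout (or SelectedRows metadata) from the source variable
  // and share the transformed buffer into out_var.
  CopyVariableWithTensor(in_var, *transformed, *out_var);
  // Safe to drop the wrapper: ShareDataWith keeps the buffer alive through
  // the tensor's reference-counted holder.
  delete transformed;
}

}  // namespace framework
}  // namespace paddle
```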
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/math/math_function.h"
......@@ -49,6 +50,13 @@ struct KernelTypePairHash {
}
};
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);
template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/device_data_transform.h"
namespace paddle {
namespace framework {
static const platform::DeviceContext* GetDeviceContext(
const platform::Place& src_place, const platform::Place& dst_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
return pool.Get(src_place);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
return pool.Get(dst_place);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place;
Tensor* out = new Tensor();
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait();
CopyFrom(in, dst_place, *dev_ctx, out);
dev_ctx->Wait();
return out;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);
} // namespace framework
} // namespace paddle
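A small usage sketch for `DeviceTransform` (hedged; the function `CopyToGPUExample` below is hypothetical, and it assumes a CUDA build with devices already initialized, e.g. via `InitDevices({"CPU", "GPU:0"})` as in the test that follows):

```cpp
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

// Hypothetical example: copy a small CPU tensor to GPU 0.
Tensor* CopyToGPUExample() {
  Tensor cpu_tensor;
  float* data = cpu_tensor.mutable_data<float>({2, 3}, platform::CPUPlace());
  for (int i = 0; i < 2 * 3; ++i) {
    data[i] = static_cast<float>(i);
  }

  // DeviceTransform picks the device context for the CPU<->GPU pair, waits
  // for pending work, copies into a newly allocated Tensor, and waits again.
  Tensor* gpu_tensor = DeviceTransform(cpu_tensor, platform::CUDAPlace(0));
  return gpu_tensor;  // the caller owns the returned tensor
}

}  // namespace framework
}  // namespace paddle
```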
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input1 of test op");
AddOutput("output", "output of test op");
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);
AddComment("This is test op");
}
};
class TestOpWithKernel : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel";
return OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0));
} else {
VLOG(3) << "use default kernel";
return OpKernelType(proto::DataType::FP32,
ctx.Input<Tensor>("input")->place());
}
}
};
template <typename DeviceContext, typename T>
class TestKernel : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << ctx.op().DebugString() << std::endl;
const Tensor* input = ctx.Input<Tensor>("input");
std::cout << "input place:" << input->place() << std::endl;
auto* output = ctx.Output<framework::LoDTensor>("output");
output->Resize(input->dims());
output->mutable_data<T>(ctx.GetPlace());
operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
input, input, output, ctx.template device_context<DeviceContext>(),
AddFunctor<T>());
functor.Run();
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
test_op, paddle::framework::TestOpWithKernel,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
test_op,
paddle::framework::TestKernel<paddle::platform::CUDADeviceContext, float>);
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
TEST(Operator, CPUtoGPU) {
using namespace paddle::framework;
using namespace paddle::platform;
ASSERT_EQ(InitDevices({"CPU", "GPU:0"}), true);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
// create an op to run on CPU
paddle::framework::proto::OpDesc cpu_op_desc;
cpu_op_desc.set_type("test_op");
BuildVar("input", {"IN1"}, cpu_op_desc.add_inputs());
BuildVar("output", {"OUT1"}, cpu_op_desc.add_outputs());
auto cpu_op = paddle::framework::OpRegistry::CreateOp(cpu_op_desc);
// prepare input
auto* in_t = scope.Var("IN1")->GetMutable<LoDTensor>();
auto* src_ptr = in_t->mutable_data<float>({2, 3}, CPUPlace());
for (int i = 0; i < 2 * 3; ++i) {
src_ptr[i] = static_cast<float>(i);
}
// get output
auto* output = scope.Var("OUT1");
cpu_op->Run(scope, cpu_place);
auto* output_ptr = output->Get<LoDTensor>().data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output_ptr[i], static_cast<float>(i) * 2);
}
// create an op to run on GPU
paddle::framework::proto::OpDesc gpu_op_desc;
gpu_op_desc.set_type("test_op");
BuildVar("input", {"OUT1"}, gpu_op_desc.add_inputs());
BuildVar("output", {"OUT2"}, gpu_op_desc.add_outputs());
auto attr = gpu_op_desc.mutable_attrs()->Add();
attr->set_name("use_gpu");
attr->set_type(paddle::framework::proto::AttrType::BOOLEAN);
attr->set_b(true);
auto gpu_op = paddle::framework::OpRegistry::CreateOp(gpu_op_desc);
paddle::platform::CUDAPlace cuda_place(0);
// get output
auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place);
// auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
DeviceContextPool& pool = DeviceContextPool::Instance();
auto dev_ctx = pool.Get(cuda_place);
paddle::framework::Tensor output_tensor;
CopyFrom(output2->Get<LoDTensor>(), paddle::platform::CPUPlace(), *dev_ctx,
&output_tensor);
dev_ctx->Wait();
float* output2_ptr = output_tensor.data<float>();
for (int i = 0; i < 2 * 3; ++i) {
ASSERT_EQ(output2_ptr[i], static_cast<float>(i) * 4);
}
}
......@@ -111,7 +111,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
VLOG(3) << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <string>
#include "paddle/framework/init.h"
#include "paddle/framework/operator.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/string/piece.h"
......@@ -24,7 +25,6 @@ namespace framework {
std::once_flag gflags_init_flag;
// TODO(qijun) move init gflags to init.cc
void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() {
int argc = argv.size();
......@@ -72,6 +72,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
platform::DeviceContextPool::Init(places);
framework::UseALL();
return true;
}
......
......@@ -43,6 +43,22 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
return os;
}
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
PADDLE_ENFORCE(platform::is_cpu_place(t.place()));
PADDLE_ENFORCE(t.type().hash_code() == typeid(float).hash_code());
os << "dim: " << t.dims() << "\n";
os << "lod: " << t.lod() << "\n";
// only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) {
os << t.data<float>()[i] << " ";
}
return os;
}
LoD SliceLevels(const LoD &in, size_t level_begin, size_t level_end) {
LoD new_lod;
new_lod.reserve(level_end - level_begin);
......@@ -244,5 +260,69 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
DeserializeFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
}
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const {
check_memory_size();
// PADDLE_ENFORCE(lod().empty() || (lod().size() == 1 && lod()[0].empty())
// , "Disable parallel lod for now");
PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now");
PADDLE_ENFORCE(dims()[0] % places.size() == 0,
"Batch size should be divided by places size");
std::vector<LoDTensor> lods;
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
size_t begin = place_idx * dims()[0] / places.size();
size_t end = (place_idx + 1) * dims()[0] / places.size();
auto src = Slice(static_cast<int>(begin), static_cast<int>(end));
LoDTensor dst;
dst.Resize(src.dims());
auto &dst_place = places[place_idx];
auto dst_ptr = dst.mutable_data(dst_place, src.type());
// TODO(tonyyang-svail):
// change the following to framework::CopyFrom
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} else {
PADDLE_THROW("Not Implemented");
}
lods.emplace_back(dst);
}
return lods;
}
void LoDTensor::MergeLoDTensor(
const std::vector<const LoDTensor *> &lod_tensors, platform::Place place) {
PADDLE_ENFORCE(platform::is_cpu_place(place));
PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type();
for (auto *lod : lod_tensors) {
PADDLE_ENFORCE(new_dim == lod->dims());
PADDLE_ENFORCE(new_type == lod->type());
PADDLE_ENFORCE(platform::is_cpu_place(lod->place()));
}
new_dim[0] *= lod_tensors.size();
Resize(new_dim);
auto *dst_ptr = reinterpret_cast<uint8_t *>(mutable_data(place, new_type));
for (auto *src : lod_tensors) {
auto size = src->numel() * SizeOfType(src->type());
memory::Copy(boost::get<platform::CPUPlace>(place), dst_ptr,
boost::get<platform::CPUPlace>(src->place()),
src->data<void>(), size);
dst_ptr += size;
}
}
} // namespace framework
} // namespace paddle
......@@ -58,6 +58,7 @@ using Vector = thrust::host_vector<
using LoD = std::vector<Vector<size_t>>;
std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
/*
* Slice levels from a LoD.
......@@ -144,6 +145,12 @@ class LoDTensor : public Tensor {
*/
void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end);
std::vector<LoDTensor> SplitLoDTensor(
const std::vector<platform::Place> places) const;
void MergeLoDTensor(const std::vector<const LoDTensor*>& lod_tensors,
platform::Place place);
private:
LoD lod_;
};
......
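The new `SplitLoDTensor`/`MergeLoDTensor` pair is what the `parallel_do` operator added later in this diff uses to scatter a mini-batch across places and gather results back. A minimal CPU-only round-trip sketch, assuming the headers above (the function name `SplitMergeExample` is hypothetical); note that the current implementation requires an empty LoD and a batch size divisible by the number of places:

```cpp
#include <vector>
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

// Hypothetical example: split a [4 x 2] batch across two CPU places
// and merge the pieces back into a single tensor.
void SplitMergeExample() {
  LoDTensor batch;
  float* src = batch.mutable_data<float>({4, 2}, platform::CPUPlace());
  for (int i = 0; i < 8; ++i) src[i] = static_cast<float>(i);

  std::vector<platform::Place> places(2, platform::CPUPlace());
  // Each piece gets dims()[0] / places.size() rows, here [2 x 2].
  std::vector<LoDTensor> pieces = batch.SplitLoDTensor(places);

  std::vector<const LoDTensor*> piece_ptrs;
  for (auto& piece : pieces) piece_ptrs.push_back(&piece);

  // MergeLoDTensor concatenates along dimension 0 on the target place.
  LoDTensor merged;
  merged.MergeLoDTensor(piece_ptrs, platform::CPUPlace());
}

}  // namespace framework
}  // namespace paddle
```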
......@@ -12,13 +12,16 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
namespace pd = paddle::framework;
namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
......@@ -215,7 +218,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
......@@ -252,7 +255,6 @@ TEST(OperatorRegistrar, CPU) {
op->Run(scope, cpu_place);
}
#ifdef PADDLE_WITH_CUDA
TEST(OperatorRegistrar, CUDA) {
paddle::framework::proto::OpDesc op_desc;
paddle::platform::CUDAPlace cuda_place(0);
......@@ -263,4 +265,127 @@ TEST(OperatorRegistrar, CUDA) {
op->Run(scope, cuda_place);
}
#endif
static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext;
namespace paddle {
namespace framework {
class OpWithMultiKernelTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};
template <typename DeviceContext, typename T>
class OpMultiKernelTest : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const;
};
template <typename T>
class OpMultiKernelTest<CPUDeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
++op_test_value;
}
};
template <typename T>
class OpMultiKernelTest<CUDADeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
--op_test_value;
}
};
template <typename DeviceContext, typename T>
class OpMultiKernelTest2 : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const;
};
template <typename T>
class OpMultiKernelTest2<CPUDeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
op_test_value += 10;
}
};
template <typename T>
class OpMultiKernelTest2<CUDADeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
op_test_value -= 10;
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(op_with_multi_kernel,
paddle::framework::OpWithMultiKernelTest,
paddle::framework::OpKernelTestMaker);
REGISTER_OP_KERNEL(
op_with_multi_kernel, CPU, paddle::platform::CPUPlace,
paddle::framework::OpMultiKernelTest<CPUDeviceContext, float>);
REGISTER_OP_KERNEL(
op_with_multi_kernel, MKLDNN, paddle::platform::CPUPlace,
paddle::framework::OpMultiKernelTest2<CPUDeviceContext, float>);
REGISTER_OP_KERNEL(
op_with_multi_kernel, CUDA, paddle::platform::CUDAPlace,
paddle::framework::OpMultiKernelTest<CUDADeviceContext, float>);
REGISTER_OP_KERNEL(
op_with_multi_kernel, CUDNN, paddle::platform::CUDAPlace,
paddle::framework::OpMultiKernelTest2<CUDADeviceContext, float>);
TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::proto::OpDesc op_desc;
paddle::platform::CUDAPlace cuda_place(0);
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
op_desc.set_type("op_with_multi_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
// TODO(qiao) add priority back
// use all available kernels
paddle::framework::UseALL();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -10);
// remove cuda kernels
paddle::framework::UseCPU();
op->Run(scope, cpu_place);
EXPECT_EQ(op_test_value, -20);
// add cuda kernels
paddle::framework::UseCUDA();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -30);
// use cudnn kernel
paddle::framework::UseCUDNN();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -40);
}
......@@ -11,13 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/device_data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/var_type.h"
......@@ -25,6 +25,64 @@ limitations under the License. */
namespace paddle {
namespace framework {
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
void UseCPU() {
kKernelPriority.clear();
/*Plain CPU*/
auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kPlain);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
}
void UseMKLDNN() {
UseCPU();
#if PADDLE_WITH_MKLML
{
/*MKLDNN Kernel*/
auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
}
#endif
}
void UseCUDA() {
UseMKLDNN();
#if PADDLE_WITH_CUDA
/*Plain GPU*/
auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
#endif
}
void UseCUDNN() {
UseCUDA();
#if PADDLE_WITH_CUDA
if (platform::dynload::HasCUDNN()) {
/*CUDNN Kernel*/
auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
}
#endif
}
void UseALL() {
UseCPU();
UseMKLDNN();
UseCUDA();
UseCUDNN();
}
static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
return DDim({-1});
}
}
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
......@@ -57,7 +115,7 @@ const std::vector<std::string>& OperatorBase::Outputs(
return it->second;
}
std::string OperatorBase::DebugString() const {
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{";
for (auto it = inputs_.begin(); it != inputs_.end();) {
......@@ -65,6 +123,9 @@ std::string OperatorBase::DebugString() const {
ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (scope) {
ss << "(" << GetDims(*scope, input.second[i]) << ")";
}
if (i != input.second.size() - 1) {
ss << ", ";
}
......@@ -81,6 +142,9 @@ std::string OperatorBase::DebugString() const {
ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (scope) {
ss << "(" << GetDims(*scope, output.second[i]) << ")";
}
if (i != output.second.size() - 1) {
ss << ", ";
}
......@@ -178,6 +242,10 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
}
static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
......@@ -185,7 +253,8 @@ static const Tensor* GetTensorFromVar(const Variable* var) {
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
}
return t;
}
......@@ -197,7 +266,8 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
} else if (var->IsType<SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
}
return t;
}
......@@ -359,7 +429,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name());
}
}
......@@ -370,7 +441,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name());
}
}
......@@ -384,24 +456,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
const platform::DeviceContext* GetDeviceContext(
framework::KernelTypePair& kernel_pair) {
auto& actual_kernel_key = kernel_pair.first;
auto& expected_kernel_key = kernel_pair.second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(actual_kernel_key.place_) &&
platform::is_cpu_place(expected_kernel_key.place_)) {
return pool.Get(actual_kernel_key.place_);
} else if (platform::is_cpu_place(actual_kernel_key.place_) &&
platform::is_gpu_place(expected_kernel_key.place_)) {
return pool.Get(expected_kernel_key.place_);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
......@@ -417,71 +471,43 @@ void OperatorWithKernel::Run(const Scope& scope,
"There are no kernels which are registered in the %s operator.", type_);
}
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx);
auto actual_kernel_key = GetActualKernelType(ctx);
auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
auto kernel_iter = kernels.find(expected_kernel_key);
if (kernel_iter == kernels.end()) {
PADDLE_THROW("The operator %s does not support %s", type_,
expected_kernel_key);
}
if (actual_kernel_key == expected_kernel_key) {
PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
"Currently, model parallelism is only supported between "
"CPU and other devices. For example, multi-GPU model "
"parallelism will failed.");
} else {
auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
const DataTransformFn* trans_fun =
DataTransformFnMap::Instance().GetNullable(kernel_pair);
if (trans_fun) {
auto input_vars = this->InputVars();
// TODO(qijun) filter the input vars that do not need to be transformed
// filter vars that has been transformed
std::vector<std::string> need_trans;
for (auto var_name : input_vars) {
auto var_name_trans =
var_name + framework::KernelTypeToString(expected_kernel_key);
if (!scope.FindVar(var_name_trans)) {
const_cast<Scope&>(scope).Var(var_name_trans);
need_trans.push_back(var_name);
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
Scope& new_scope = scope.NewScope();
for (auto& var_name_item : this->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope.FindVar(var_name);
if (var && VarIsTensor(var)) {
auto* tensor_in = GetTensorFromVar(var);
if (tensor_in->IsInitialized()) {
auto kernel_type_for_var = this->GetKernelTypeForVar(
var_name_item.first, *tensor_in, expected_kernel_key);
if (kernel_type_for_var != expected_kernel_key) {
auto out_var_names = OutputVars(true);
if (std::find(out_var_names.begin(), out_var_names.end(),
var_name) != out_var_names.end()) {
PADDLE_THROW(
"var %s is both input and output, "
"does not support transform",
var_name);
}
VLOG(3) << "need to do transform for var " << var_name;
auto* trans_var = new_scope.Var(var_name);
auto* out = DataTransform(expected_kernel_key, kernel_type_for_var,
*tensor_in);
CopyVariableWithTensor(*var, *out, *trans_var);
}
}
}
if (!need_trans.empty()) {
auto trans_dev_ctx = GetDeviceContext(kernel_pair);
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : need_trans) {
(*trans_fun)(trans_dev_ctx, kernel_pair, *(scope.FindVar(var_name)),
scope.FindVar(var_name + framework::KernelTypeToString(
expected_kernel_key)));
}
// Wait for data transform finishing
trans_dev_ctx->Wait();
}
}
}
kernel_iter->second->Compute(ctx);
}
OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
OpKernelMap& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const {
return actual_kernel_type;
kernel_iter->second->Compute(ExecutionContext(*this, new_scope, *dev_ctx));
}
proto::DataType OperatorWithKernel::IndicateDataType(
......@@ -513,5 +539,16 @@ proto::DataType OperatorWithKernel::IndicateDataType(
return static_cast<proto::DataType>(data_type);
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
OpKernelType OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const {
return OpKernelType(expected_kernel_type.data_type_, tensor.place());
}
} // namespace framework
} // namespace paddle
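With `GetActualKernelType` folded into `GetExpectedKernelType` and the new per-variable `GetKernelTypeForVar` hook above, an operator can both choose its kernel and opt individual inputs out of the automatic device transform. A hedged sketch of such an override; `MyOp` and its inputs "X" and "Indices" are hypothetical and not part of this commit:

```cpp
#include <string>
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

// Hypothetical operator illustrating the two hooks.
class MyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Pick the kernel from the data type of "X" and the place Run was given,
    // mirroring the default implementation in operator.cc.
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "Indices") {
      // Returning the expected kernel type makes kernel_type_for_var equal
      // to expected_kernel_key in OperatorWithKernel::Run, so no transform
      // is scheduled for this input and it stays where it already lives.
      return expected_kernel_type;
    }
    // Default behaviour: kernel data type plus the tensor's own place.
    return framework::OperatorWithKernel::GetKernelTypeForVar(
        var_name, tensor, expected_kernel_type);
  }
};

}  // namespace operators
}  // namespace paddle
```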
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
......@@ -52,10 +53,33 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";
// define some kernel hint
const std::string kUseCPU = "use_cpu";
const std::string kUseCUDNN = "use_cudnn";
const std::string kUseMKLDNN = "use_mkldnn";
// define some kernel priority
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
/**
* @brief Use cpu kernel only
*/
void UseCPU();
/**
* @brief Prefer MKLDNN kernels over plain CPU kernels
*/
void UseMKLDNN();
/**
* @brief Prefer CUDA kernels over plain CPU kernels
*/
void UseCUDA();
/**
* @brief Prefer cuDNN kernels over plain CUDA kernels
*/
void UseCUDNN();
/**
* @brief Use all available kernels
*/
void UseALL();
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
......@@ -84,7 +108,10 @@ class OperatorBase {
return boost::get<T>(attrs_.at(name));
}
virtual std::string DebugString() const;
/// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const;
std::string DebugString() const { return DebugStringEx(nullptr); }
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
......@@ -381,9 +408,10 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const OpKernelType& expected_kernel_type) const;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
......
......@@ -114,7 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
......
......@@ -109,6 +109,7 @@ std::string Scope::Rename(const std::string& origin_name) const {
Rename(origin_name, var_name);
return var_name;
}
Variable* Scope::FindVarLocally(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) return it->second;
......
......@@ -75,9 +75,9 @@ class Scope {
// Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const;
private:
Variable* FindVarLocally(const std::string& name) const;
private:
// Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {}
......
......@@ -55,6 +55,10 @@ class Tensor {
template <typename T>
inline const T* data() const;
inline bool IsInitialized() const;
inline void switch_place(platform::Place new_place);
/**
* @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation.
......@@ -200,6 +204,15 @@ class Tensor {
size_t offset_;
};
inline void Tensor::switch_place(platform::Place new_place) {
if (holder_->place() == new_place) {
return;
}
// TODO(tonyyang-svail): do memcpy here.
PADDLE_THROW("Not Implemented");
}
} // namespace framework
} // namespace paddle
......
......@@ -84,6 +84,8 @@ inline const T* Tensor::data() const {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
inline bool Tensor::IsInitialized() const { return holder_ != nullptr; }
template <typename T>
inline T* Tensor::data() {
check_memory_size();
......
......@@ -32,6 +32,8 @@ class Variable {
return *static_cast<const T*>(holder_->Ptr());
}
bool IsInitialized() const { return holder_ != nullptr; }
template <typename T>
T* GetMutable() {
if (!IsType<T>()) {
......
......@@ -38,23 +38,16 @@ void InferenceEngine::LoadInferenceModel(
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
// PicklingTools cannot parse the vector of strings correctly.
#else
// program_desc_str
// the inference.model is stored by following python codes:
// inference_program = fluid.io.get_inference_program(predict)
// model_filename = "recognize_digits_mlp.inference.model/inference.model"
// with open(model_filename, "w") as f:
// program_str = inference_program.desc.serialize_to_string()
// f.write(struct.pack('q', len(program_str)))
// f.write(program_str)
std::string model_filename = dirname + "/inference.model";
std::string model_filename = dirname + "/__model__.dat";
LOG(INFO) << "loading model from " << model_filename;
std::ifstream fs(model_filename, std::ios_base::binary);
int64_t size = 0;
fs.read(reinterpret_cast<char*>(&size), sizeof(int64_t));
LOG(INFO) << "program_desc_str's size: " << size;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
program_desc_str.resize(size);
fs.read(&program_desc_str[0], size);
inputfs.seekg(0, std::ios::end);
program_desc_str.resize(inputfs.tellg());
inputfs.seekg(0, std::ios::beg);
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
inputfs.read(&program_desc_str[0], program_desc_str.size());
inputfs.close();
#endif
program_ = new framework::ProgramDesc(program_desc_str);
GenerateLoadProgram(dirname);
......
......@@ -152,6 +152,7 @@ op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
......
......@@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
......@@ -47,8 +47,7 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
......@@ -56,8 +55,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(*place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
if (platform::is_cpu_place(ctx.GetPlace())) {
auto* lr = learning_rate->data<T>();
param_out.device(*place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
param_out.device(*place) =
param -
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param_tensor = ctx.Input<framework::Tensor>("Param");
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
......
......@@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
......@@ -306,7 +306,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
......
......@@ -55,10 +55,10 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
platform::CPUPlace());
}
};
......
......@@ -145,6 +145,7 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
context.Attr<std::vector<int>>("excluded_chunk_types").end());
auto* inference = context.Input<LoDTensor>("Inference");
auto place = inference->place();
auto* label = context.Input<LoDTensor>("Label");
auto* precision = context.Output<Tensor>("Precision");
auto* recall = context.Output<Tensor>("Recall");
......@@ -155,15 +156,15 @@ class ChunkEvalKernel : public framework::OpKernel<T> {
const int64_t* inference_data = inference->data<int64_t>();
const int64_t* label_data = label->data<int64_t>();
T* precision_data = precision->mutable_data<T>(context.GetPlace());
T* racall_data = recall->mutable_data<T>(context.GetPlace());
T* f1_data = f1->mutable_data<T>(context.GetPlace());
T* precision_data = precision->mutable_data<T>(place);
T* racall_data = recall->mutable_data<T>(place);
T* f1_data = f1->mutable_data<T>(place);
int64_t* num_infer_chunks_data =
num_infer_chunks->mutable_data<int64_t>(context.GetPlace());
num_infer_chunks->mutable_data<int64_t>(place);
int64_t* num_label_chunks_data =
num_label_chunks->mutable_data<int64_t>(context.GetPlace());
num_label_chunks->mutable_data<int64_t>(place);
int64_t* num_correct_chunks_data =
num_correct_chunks->mutable_data<int64_t>(context.GetPlace());
num_correct_chunks->mutable_data<int64_t>(place);
*num_infer_chunks_data = 0;
*num_label_chunks_data = 0;
*num_correct_chunks_data = 0;
......
......@@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
......
......@@ -315,10 +315,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(conv2d, CUDNN, paddle::platform::CUDAPlace,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
// TODO(dzhwinter) : below register should be removed
REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
......
......@@ -120,17 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}
framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& actual_kernel_type) const override {
return framework::OpKernelType(actual_kernel_type.data_type_,
platform::CPUPlace());
platform::CPUPlace());
}
};
} // namespace operators
......
......@@ -28,9 +28,6 @@ template <typename DeviceContext, typename T>
class CRFDecodingOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"The crf_decoding operator can only run on CPU.");
auto* emission_weights = ctx.Input<LoDTensor>("Emission");
auto* transition_weights = ctx.Input<Tensor>("Transition");
auto* label = ctx.Input<LoDTensor>("Label");
......
......@@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......@@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
......@@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
......@@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......@@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
......@@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
......@@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
......@@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
......
......@@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......@@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
......@@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
......
......@@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......@@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......
......@@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......@@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......
......@@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......@@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......
......@@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......@@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......
......@@ -56,11 +56,11 @@ void NetOp::CompleteAddOp(bool calc) {
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
}
std::string NetOp::DebugString() const {
std::string NetOp::DebugStringEx(const framework::Scope* scope) const {
std::ostringstream os;
os << OperatorBase::DebugString() << std::endl;
os << OperatorBase::DebugStringEx(scope) << std::endl;
for (auto& op : ops_) {
std::istringstream is(op->DebugString());
std::istringstream is(op->DebugStringEx(scope));
for (std::string line; std::getline(is, line);) {
os << " " << line << std::endl;
}
......
......@@ -106,7 +106,8 @@ class NetOp : public framework::OperatorBase {
void CompleteAddOp(bool calculate = true);
std::string DebugString() const override;
std::string DebugStringEx(
const framework::Scope* scope = nullptr) const override;
bool IsNetOp() const override;
std::vector<std::string> OutputVars(bool has_intermediate) const override;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/framework/executor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/threadpool.h"
namespace paddle {
namespace operators {
static constexpr char kInputs[] = "inputs";
static constexpr char kParameters[] = "parameters";
static constexpr char kPlaces[] = "places";
static constexpr char kOutputs[] = "outputs";
static constexpr char kParallelScopes[] = "parallel_scopes";
static constexpr char kParallelBlock[] = "sub_block";
// using ParallelScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor;
using OperatorBase = framework::OperatorBase;
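// Split each LoDTensor named in `names` across `places` and move the
// resulting pieces into the corresponding sub-scopes.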
void SplitTensorAndMoveTensorToScopes(
const framework::Scope &scope,
const std::vector<framework::Scope *> &sub_scopes,
const std::vector<platform::Place> &places,
const std::vector<std::string> &names) {
for (auto &argu : names) {
auto *var = scope.FindVar(argu);
const auto &tensor = var->Get<LoDTensor>();
auto lod_tensors = tensor.SplitLoDTensor(places);
for (auto &lod : lod_tensors) {
VLOG(3) << lod.dims();
}
for (size_t i = 0; i < sub_scopes.size(); ++i) {
*sub_scopes[i]->Var(argu)->GetMutable<LoDTensor>() = lod_tensors[i];
}
}
}
class ParallelDoOp : public framework::OperatorBase {
public:
ParallelDoOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
auto *program = block->Program();
// TODO(tonyyang-svail): get places from input
std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace());
places.emplace_back(platform::CPUPlace());
auto &sub_scopes = *scope.FindVar(Output(kParallelScopes))
->GetMutable<std::vector<framework::Scope *>>();
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
sub_scopes.push_back(&scope.NewScope());
}
SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
Inputs(kInputs));
std::vector<std::future<void>> workers;
workers.reserve(places.size());
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
VLOG(3) << "Run " << place_idx;
auto &place = places[place_idx];
auto *cur_scope = sub_scopes[place_idx];
// copy parameter
// some versions of boost lack operator!= for boost::variant
if (!(dev_ctx.GetPlace() == place)) {
PADDLE_THROW("Not Implemented");
}
workers.emplace_back(framework::Async([program, cur_scope, place, block] {
framework::Executor executor(place);
executor.Run(*program, cur_scope, block->ID(),
false /*create_local_scope*/);
}));
}
for (auto &worker : workers) {
worker.wait();
}
// merge output
for (auto &o_name : Outputs(kOutputs)) {
std::vector<const framework::LoDTensor *> lod_tensors;
lod_tensors.reserve(sub_scopes.size());
for (auto *sub_scope : sub_scopes) {
lod_tensors.emplace_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
}
auto *lod_tensor_to_be_merged =
scope.FindVar(o_name)->GetMutable<LoDTensor>();
lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace());
}
}
};
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ParallelDoOpProtoMaker(OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInputs, "").AsDuplicable();
AddInput(kParameters, "").AsDuplicable();
AddInput(kPlaces, "");
AddOutput(kOutputs, "").AsDuplicable();
AddOutput(kParallelScopes, "");
AddAttr<framework::BlockDesc *>(kParallelBlock, "");
AddComment(R"DOC(
ParallelDo Operator.
)DOC");
}
};
class ParallelDoGradOp : public OperatorBase {
public:
ParallelDoGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::Place &place) const override {
// // get device context from pool
// platform::DeviceContextPool &pool =
// platform::DeviceContextPool::Instance();
// auto &dev_ctx = *pool.Get(place);
auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
auto *program = block->Program();
auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();
// TODO(tonyyang-svail): get places from input
std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace());
places.emplace_back(platform::CPUPlace());
// feed output@grad
SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
Inputs(framework::GradVarName(kOutputs)));
for (auto &s : Inputs(framework::GradVarName(kOutputs))) {
VLOG(3) << s;
VLOG(3) << scope.FindVar(s)->Get<LoDTensor>();
for (auto *sub_scope : sub_scopes) {
VLOG(3) << sub_scope->FindVar(s)->Get<LoDTensor>();
}
}
// exe run
std::vector<std::future<void>> workers;
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
VLOG(3) << "Run " << place_idx;
auto &place = places[place_idx];
auto *cur_scope = sub_scopes[place_idx];
// execute
workers.emplace_back(framework::Async([program, cur_scope, place, block] {
framework::Executor executor(place);
executor.Run(*program, cur_scope, block->ID(),
false /*create_local_scope*/);
}));
}
for (auto &worker : workers) {
worker.wait();
}
// merge grad
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
VLOG(3) << s;
auto &t = sub_scopes[0]->FindVar(s)->Get<LoDTensor>();
VLOG(3) << t;
std::string s_buf = s + "@BUF";
auto *t_buf = sub_scopes[0]->Var(s_buf)->GetMutable<LoDTensor>();
for (size_t place_idx = 1; place_idx < places.size(); ++place_idx) {
auto &tt = sub_scopes[place_idx]->FindVar(s)->Get<LoDTensor>();
VLOG(3) << place_idx;
VLOG(3) << tt;
framework::CopyFrom(tt, places[0], t_buf);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {s, s_buf}}}, {{"Out", {s}}},
framework::AttributeMap{});
sum_op->Run(*sub_scopes[0], place);
}
VLOG(3) << t;
framework::CopyFrom(t, place, scope.FindVar(s)->GetMutable<LoDTensor>());
}
}
};
class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<framework::OpDesc> Apply() const {
auto *grad = new framework::OpDesc();
grad->SetType("parallel_do_grad");
for (auto &input_param : this->InputNames()) {
VLOG(3) << input_param;
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param, false));
}
for (auto &output_param : this->OutputNames()) {
if (output_param == kParallelScopes) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->Output(output_param));
} else {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
}
}
grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);
return std::unique_ptr<framework::OpDesc>(grad);
}
};
class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
std::vector<std::string> input{kParameters, kInputs};
std::vector<std::string> output{kOutputs};
for (auto &s : input) {
PADDLE_ENFORCE(ctx->HasInputs(s));
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
"Cannot find the gradient variable %s",
framework::GradVarName(s));
}
for (auto &s : output) {
PADDLE_ENFORCE(ctx->HasInputs(s));
}
for (auto &s : input) {
ctx->SetOutputsDim(framework::GradVarName(s), ctx->GetInputsDim(s));
}
if (ctx->HasInputs(kParameters)) {
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
ctx->SetOutputsDim(framework::GradVarName(kParameters),
ctx->GetInputsDim(kParameters));
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
paddle::operators::ParallelDoOpProtoMaker,
paddle::operators::ParallelDoGradOpDescMaker);
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
paddle::operators::ParallelDoGradOpShapeInference);
......@@ -69,7 +69,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......@@ -90,7 +90,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
......@@ -85,7 +85,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
......
......@@ -80,7 +80,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
......
......@@ -68,7 +68,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......@@ -89,7 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
......@@ -49,7 +49,7 @@ class ScatterOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......
......@@ -107,7 +107,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
......@@ -48,7 +48,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......@@ -69,7 +69,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
......@@ -118,7 +118,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
......@@ -159,7 +159,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
......
......@@ -53,7 +53,7 @@ class SumOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X");
if (x_vars[0]->IsType<framework::LoDTensor>()) {
......
......@@ -63,7 +63,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
......@@ -71,7 +71,7 @@ int OutputSize(int input_size, int ksize, int padding, int stride) {
class UnpoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......@@ -110,7 +110,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetActualKernelType(
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
......@@ -21,10 +21,16 @@ ELSE()
set(GPU_CTX_DEPS)
ENDIF()
IF(WITH_MKLDNN)
set(MKLDNN_CTX_DEPS mkldnn)
ELSE()
set(MKLDNN_CTX_DEPS)
ENDIF()
# memcpy depends on device_context; add the deps individually here
# to avoid cyclic dependencies
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
......
......@@ -168,5 +168,69 @@ cudaStream_t CUDADeviceContext::stream() const { return stream_; }
#endif
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), ready_(false) {
stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
}
template <typename T>
void MKLDNNDeviceContext::AddElement(const std::string& op_key,
const T& value) {
if (GetElement<T>(op_key)) {
return;
}
GetElementPool<T>().emplace(op_key, std::move(value));
}
template <typename T>
const T& MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
auto it = GetElementPool<T>().find(op_key);
return it == GetElementPool<T>().end() ? nullptr : it->second;
}
template <>
const std::unordered_map<const std::string, const MKLDNNMemoryPtr,
std::hash<std::string>>&
MKLDNNDeviceContext::GetElementPool<MKLDNNMemoryPtr>() const {
return memory_pool_;
}
template <>
const std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
std::hash<std::string>>&
MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitivePtr>() const {
return primitive_pool_;
}
template <>
const std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
std::hash<std::string>>&
MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitiveDescPtr>() const {
return primitive_desc_pool_;
}
void MKLDNNDeviceContext::Execute(bool block) {
if (pipeline_.empty()) {
return;
}
ResetStream();
stream_->submit(pipeline_).wait(block);
ready_ = false;
pipeline_.clear();
}
void MKLDNNDeviceContext::ResetStream() {
if (ready_) {
return;
}
// TODO(TJ): change me when mkldnn have specific method to reset this state
stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
ready_ = true;
}
#endif
} // namespace platform
} // namespace paddle
......@@ -21,6 +21,10 @@ limitations under the License. */
#define EIGEN_USE_GPU
#endif
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/platform/mkldnn_helper.h"
#endif
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -105,6 +109,54 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
#endif
#ifdef PADDLE_WITH_MKLDNN
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
explicit MKLDNNDeviceContext(CPUPlace place);
/* \brief Add new element: memory, primitive or primitive desc */
template <typename T>
void AddElement(const std::string& op_key, const T& value);
/* \brief Get existed element: memory, primitive or primitive desc */
template <typename T>
const T& GetElement(const std::string& op_key) const;
/* \brief Get element pool: memory, primitive or primitive desc pool */
template <typename T>
const std::unordered_map<const std::string, const T, std::hash<std::string>>&
GetElementPool() const;
/* \brief Get the active engine */
const MKLDNNEngine& engine() const { return *engine_; }
/* \brief Submit primitive to pipeline */
void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
/*! \brief Execute all submitted primitives in pipeline */
void Execute(bool block = true);
protected:
/*! \brief Reset the stream to prepare for the next execute */
void ResetStream();
private:
std::unordered_map<const std::string, const MKLDNNMemoryPtr,
std::hash<std::string>>
memory_pool_;
std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
std::hash<std::string>>
primitive_pool_;
std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
std::hash<std::string>>
primitive_desc_pool_;
std::vector<MKLDNNPrimitive> pipeline_;
MKLDNNStreamPtr stream_;
MKLDNNEnginePtr engine_;
bool ready_;
};
#endif
/*! \brief device context pool singleton */
class DeviceContextPool {
public:
......
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkldnn.hpp>
namespace paddle {
namespace platform {
using MKLDNNStream = mkldnn::stream;
using MKLDNNEngine = mkldnn::engine;
using MKLDNNMemory = mkldnn::memory;
using MKLDNNPrimitive = mkldnn::primitive;
using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
} // namespace platform
} // namespace paddle
......@@ -51,6 +51,18 @@ bool places_are_same_class(const Place &p1, const Place &p2) {
return p1.which() == p2.which();
}
bool is_same_place(const Place &p1, const Place &p2) {
if (places_are_same_class(p1, p2)) {
if (is_cpu_place(p1)) {
return true;
} else {
return boost::get<CUDAPlace>(p1) == boost::get<CUDAPlace>(p2);
}
} else {
return false;
}
}
std::ostream &operator<<(std::ostream &os, const Place &p) {
detail::PlacePrinter printer(os);
boost::apply_visitor(printer, p);
......
......@@ -61,6 +61,7 @@ const CPUPlace default_cpu();
bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &);
bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &);
std::ostream &operator<<(std::ostream &, const Place &);
......
......@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/profiler.h"
#include <iomanip>
#include <map>
#include "glog/logging.h"
namespace paddle {
namespace platform {
// The profiler state, the initial value is ProfilerState::kDisabled
static ProfilerState g_state = ProfilerState::kDisabled;
// To record which timer the profiler used, CUDA or CPU.
static std::string g_profiler_place = "";
// The thread local event list only can be accessed by the specific thread
// The thread index of each thread
static thread_local int32_t g_thread_id;
......@@ -43,10 +48,7 @@ inline uint64_t GetTimeInNsec() {
Event::Event(EventKind kind, std::string name, uint32_t thread_id,
DeviceContext* dev_ctx)
: kind_(kind),
name_(std::move(name)),
thread_id_(thread_id),
has_cuda_(false) {
: kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) {
#ifdef PADDLE_WITH_CUDA
auto* cuda_dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctx);
if (cuda_dev_ctx) {
......@@ -72,11 +74,11 @@ std::string Event::kind() const {
PADDLE_THROW("Unknown EventKind.");
}
double Event::CpuElapsedUs(const Event& e) const {
return (e.cpu_ns_ - cpu_ns_) / (1000.0);
double Event::CpuElapsedMs(const Event& e) const {
return (e.cpu_ns_ - cpu_ns_) / (1000000.0);
}
double Event::CudaElapsedUs(const Event& e) const {
double Event::CudaElapsedMs(const Event& e) const {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(e.has_cuda() && has_cuda());
PADDLE_ENFORCE(e.device() == device());
......@@ -84,7 +86,7 @@ double Event::CudaElapsedUs(const Event& e) const {
PADDLE_ENFORCE(cudaEventSynchronize(e.event()));
float ms;
PADDLE_ENFORCE(cudaEventElapsedTime(&ms, event_, e.event()));
return ms * 1000.0;
return ms;
#else
PADDLE_THROW("CUDA is not enabled");
#endif
......@@ -113,21 +115,27 @@ inline EventList& GetEventList() {
}
void Mark(const std::string& name, DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kMark, std::move(name), g_thread_id,
dev_ctx);
GetEventList().Record(EventKind::kMark, name, g_thread_id, dev_ctx);
}
void PushEvent(const std::string& name, DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPushRange, name, g_thread_id, dev_ctx);
}
void PopEvent(const std::string& name, DeviceContext* dev_ctx) {
GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx);
}
RecordEvent::RecordEvent(const std::string& name, DeviceContext* dev_ctx) {
if (g_state == ProfilerState::kDisabled) return;
dev_ctx_ = dev_ctx;
GetEventList().Record(EventKind::kPushRange, std::move(name), g_thread_id,
dev_ctx_);
name_ = name;
PushEvent(name_, dev_ctx_);
}
RecordEvent::~RecordEvent() {
if (g_state == ProfilerState::kDisabled) return;
GetEventList().Record(EventKind::kPopRange, std::string(), g_thread_id,
dev_ctx_);
PopEvent(name_, dev_ctx_);
}
void EnableProfiler(ProfilerState state) {
......@@ -138,6 +146,7 @@ void EnableProfiler(ProfilerState state) {
"The profiling state should be disabled when calling ",
"EnableProfiler.");
g_state = state;
g_profiler_place = (g_state == ProfilerState::kCUDA) ? "CUDA" : "CPU";
#ifdef PADDLE_WITH_CUDA
if (g_state == ProfilerState::kCUDA) {
// Generate some dummy events first to reduce the startup overhead.
......@@ -169,5 +178,152 @@ std::vector<std::vector<Event>> DisableProfiler() {
return result;
}
void ParseEvents(std::vector<std::vector<Event>>& events,
EventSortingKey sorted_by) {
if (g_profiler_place == "") return;
std::string sorted_domain;
std::function<bool(EventItem&, EventItem&)> sorted_func;
switch (sorted_by) {
case EventSortingKey::kCalls:
sorted_domain = "number of calls";
sorted_func = [](EventItem& a, EventItem& b) {
return a.calls > b.calls;
};
break;
case EventSortingKey::kTotal:
sorted_domain = "total time";
sorted_func = [](EventItem& a, EventItem& b) {
return a.total_time > b.total_time;
};
break;
case EventSortingKey::kMin:
sorted_domain = "minimum time";
sorted_func = [](EventItem& a, EventItem& b) {
return a.min_time > b.min_time;
};
break;
case EventSortingKey::kMax:
sorted_domain = "maximum time";
sorted_func = [](EventItem& a, EventItem& b) {
return a.max_time > b.max_time;
};
break;
case EventSortingKey::kAve:
sorted_domain = "average time";
sorted_func = [](EventItem& a, EventItem& b) {
return a.ave_time > b.ave_time;
};
break;
default:
sorted_domain = "event end time";
}
std::vector<std::vector<EventItem>> events_table;
size_t max_name_width = 0;
for (size_t i = 0; i < events.size(); i++) {
std::list<Event> pushed_events;
std::vector<EventItem> event_items;
std::unordered_map<std::string, int> event_idx;
for (size_t j = 0; j < events[i].size(); j++) {
if (events[i][j].kind() == "push") {
pushed_events.push_back(events[i][j]);
} else if (events[i][j].kind() == "pop") {
std::list<Event>::reverse_iterator rit = pushed_events.rbegin();
while (rit != pushed_events.rend() &&
rit->name() != events[i][j].name()) {
++rit;
}
if (rit != pushed_events.rend()) {
double event_time = (g_profiler_place == "CUDA")
? rit->CudaElapsedMs(events[i][j])
: rit->CpuElapsedMs(events[i][j]);
std::string event_name =
"thread" + std::to_string(rit->thread_id()) + "::" + rit->name();
max_name_width = std::max(max_name_width, event_name.size());
if (event_idx.find(event_name) == event_idx.end()) {
event_idx[event_name] = event_items.size();
EventItem event_item = {event_name, 1, event_time,
event_time, event_time, event_time};
event_items.push_back(event_item);
} else {
int index = event_idx[event_name];
event_items[index].calls += 1;
// total time
event_items[index].total_time += event_time;
// min time
event_items[index].min_time =
std::min(event_time, event_items[index].min_time);
// max time
event_items[index].max_time =
std::max(event_time, event_items[index].max_time);
}
// remove the push marker from the list
pushed_events.erase((++rit).base());
} else {
LOG(WARNING) << "Cannot find the push marker of event \'"
<< events[i][j].name()
<< "\', which will be ignored in profiling report.";
}
}
}
// average time
for (auto& item : event_items) {
item.ave_time = item.total_time / item.calls;
}
// sort
if (sorted_by != EventSortingKey::kDefault) {
std::sort(event_items.begin(), event_items.end(), sorted_func);
}
events_table.push_back(event_items);
// log warning if there are events with `push` but without `pop`
std::list<Event>::reverse_iterator rit = pushed_events.rbegin();
while (rit != pushed_events.rend()) {
LOG(WARNING) << "Cannot find the pop marker of event \'" << rit->name()
<< "\', which will be ignored in profiling report.";
++rit;
}
}
// Print report
PrintProfilingReport(events_table, sorted_domain, max_name_width + 4, 12);
}
void PrintProfilingReport(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width) {
// Output header information
std::cout << "\n------------------------->"
<< " Profiling Report "
<< "<-------------------------\n\n";
std::cout << "Place: " << g_profiler_place << std::endl;
std::cout << "Time unit: ms" << std::endl;
std::cout << "Sorted by " << sorted_domain
<< " in descending order in the same thread\n\n";
// Output events table
std::cout.setf(std::ios::left);
std::cout << std::setw(name_width) << "Event" << std::setw(data_width)
<< "Calls" << std::setw(data_width) << "Total"
<< std::setw(data_width) << "Min." << std::setw(data_width)
<< "Max." << std::setw(data_width) << "Ave." << std::endl;
for (size_t i = 0; i < events_table.size(); ++i) {
for (size_t j = 0; j < events_table[i].size(); ++j) {
EventItem& event_item = events_table[i][j];
std::cout << std::setw(name_width) << event_item.name
<< std::setw(data_width) << event_item.calls
<< std::setw(data_width) << event_item.total_time
<< std::setw(data_width) << event_item.min_time
<< std::setw(data_width) << event_item.max_time
<< std::setw(data_width) << event_item.ave_time << std::endl;
}
}
std::cout << std::endl;
}
} // namespace platform
} // namespace paddle
......@@ -33,6 +33,7 @@ class Event {
std::string kind() const;
std::string name() const { return name_; }
uint32_t thread_id() const { return thread_id_; }
bool has_cuda() const { return has_cuda_; }
#ifdef PADDLE_WITH_CUDA
......@@ -40,8 +41,8 @@ class Event {
int device() const { return device_; }
#endif
double CpuElapsedUs(const Event& e) const;
double CudaElapsedUs(const Event& e) const;
double CpuElapsedMs(const Event& e) const;
double CudaElapsedMs(const Event& e) const;
private:
EventKind kind_;
......@@ -94,6 +95,10 @@ enum ProfilerState {
void Mark(const std::string& name, DeviceContext* dev_ctx);
void PushEvent(const std::string& name, DeviceContext* dev_ctx);
void PopEvent(const std::string& name, DeviceContext* dev_ctx);
struct RecordEvent {
explicit RecordEvent(const std::string& name, DeviceContext* dev_ctx);
......@@ -101,6 +106,8 @@ struct RecordEvent {
// The device context is used by Event to get the current cuda stream.
DeviceContext* dev_ctx_;
// Event name
std::string name_;
};
// Enable the profiling function.
......@@ -110,5 +117,26 @@ void EnableProfiler(ProfilerState state);
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
std::vector<std::vector<Event>> DisableProfiler();
// The information of each event given in the profiling report
struct EventItem {
std::string name;
int calls;
double total_time;
double min_time;
double max_time;
double ave_time;
};
// Candidate keys to sort the profiling report
enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve };
// Parse the event list and output the profiling report
void ParseEvents(std::vector<std::vector<Event>>&,
EventSortingKey sorted_by = EventSortingKey::kDefault);
// Print results
void PrintProfilingReport(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width);
} // namespace platform
} // namespace paddle
......@@ -26,7 +26,7 @@ TEST(Event, CpuElapsedTime) {
counter++;
}
Event stop_event(EventKind::kPopRange, "test", 0, nullptr);
EXPECT_GT(start_event.CpuElapsedUs(stop_event), 0);
EXPECT_GT(start_event.CpuElapsedMs(stop_event), 0);
}
#ifdef PADDLE_WITH_CUDA
......@@ -45,7 +45,7 @@ TEST(Event, CudaElapsedTime) {
counter++;
}
Event stop_event(EventKind::kPopRange, "test", 0, dev_ctx);
EXPECT_GT(start_event.CudaElapsedUs(stop_event), 0);
EXPECT_GT(start_event.CudaElapsedMs(stop_event), 0);
}
#endif
......@@ -55,6 +55,7 @@ TEST(RecordEvent, RecordEvent) {
using paddle::platform::EventKind;
using paddle::platform::RecordEvent;
using paddle::platform::ProfilerState;
using paddle::platform::EventSortingKey;
ProfilerState state = ProfilerState::kCPU;
DeviceContext* dev_ctx = nullptr;
......@@ -67,13 +68,45 @@ TEST(RecordEvent, RecordEvent) {
#endif
EnableProfiler(state);
/* Usage 1:
* PushEvent(evt_name, dev_ctx);
* ...
* code to be analyzed
* ...
* PopEvent(evt_name, dev_ctx);
*/
for (int loop = 0; loop < 3; ++loop) {
for (int i = 1; i < 5; ++i) {
std::string name = "op_" + std::to_string(i);
PushEvent(name, dev_ctx);
int counter = 1;
while (counter != i * 1000) counter++;
PopEvent(name, dev_ctx);
}
}
/* Usage 2:
* {
* RecordEvent record_event(name, dev_ctx);
* ...
* code to be analyzed
* ...
* }
*/
for (int i = 1; i < 5; ++i) {
std::string name = "op_" + std::to_string(i);
std::string name = "evs_op_" + std::to_string(i);
RecordEvent record_event(name, dev_ctx);
int counter = 1;
while (counter != i * 1000) counter++;
}
// Bad Usage:
PushEvent("event_without_pop", dev_ctx);
PopEvent("event_without_push", dev_ctx);
std::vector<std::vector<Event>> events = paddle::platform::DisableProfiler();
// Will remove parsing-related code from test later
ParseEvents(events, EventSortingKey::kTotal);
int cuda_startup_count = 0;
int start_profiler_count = 0;
int stop_profiler_count = 0;
......@@ -85,9 +118,9 @@ TEST(RecordEvent, RecordEvent) {
if (events[i][j].name() == "push") {
EXPECT_EQ(events[i][j + 1].name(), "pop");
#ifdef PADDLE_WITH_CUDA
EXPECT_GT(events[i][j].CudaElapsedUs(events[i][j + 1]), 0);
EXPECT_GT(events[i][j].CudaElapsedMs(events[i][j + 1]), 0);
#else
EXPECT_GT(events[i][j].CpuElapsedUs(events[i][j + 1]), 0);
EXPECT_GT(events[i][j].CpuElapsedMs(events[i][j + 1]), 0);
#endif
}
}
......
......@@ -23,11 +23,6 @@ void BindConstValue(pybind11::module& m) {
m.def("kTempVarName", [] { return framework::kTempVarName; });
m.def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
m.def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
// for kernel_hint key
m.def("kUseCPU", [] { return framework::kUseCPU; });
m.def("kUseCUDNN", [] { return framework::kUseCUDNN; });
m.def("kUseMKLDNN", [] { return framework::kUseMKLDNN; });
}
} // namespace pybind
......
......@@ -430,6 +430,12 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_glog", framework::InitGLOG);
m.def("init_devices", &framework::InitDevices);
m.def("use_cpu", framework::UseCPU);
m.def("use_mkldnn", framework::UseMKLDNN);
m.def("use_cuda", framework::UseCUDA);
m.def("use_cudnn", framework::UseCUDNN);
m.def("use_all", framework::UseALL);
m.def("is_compile_gpu", IsCompileGPU);
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
......
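These new bindings expose the kernel-preference hints on the `core` module. A minimal, illustrative sketch of how they are called from Python (it mirrors the conv2d unit-test changes later in this change; the exact effect of `use_all` is assumed from the binding name):
```
import paddle.v2.fluid.core as core

# Prefer plain CUDA kernels for ops created after this call.
core.use_cuda()
# Narrow the preference further to cuDNN kernels (as the conv2d_cudnn tests do).
core.use_cudnn()
# Assumption: use_all() lifts the restriction and allows any available kernel.
core.use_all()
```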
......@@ -193,6 +193,16 @@ EOF
EOF
}
function gen_capi_package() {
if [[ ${WITH_C_API} == "ON" ]]; then
install_prefix="/paddle/build/capi_output"
rm -rf $install_prefix
make DESTDIR="$install_prefix" install
cd $install_prefix/usr/local
ls | egrep -v "^Found.*item$" | xargs tar -cf /paddle/build/paddle.tgz
fi
}
set -xe
cmake_gen ${PYTHON_ABI:-""}
......@@ -200,6 +210,11 @@ run_build
run_test
gen_docs
gen_dockerfile
printf "If you need to install PaddlePaddle in develop docker image,"
printf "please make install or pip install build/python/dist/*.whl.\n"
gen_capi_package
if [[ ${WITH_C_API:-OFF} == "ON" ]]; then
printf "PaddlePaddle C-API libraries was generated on build/paddle.tgz\n"
else
printf "If you need to install PaddlePaddle in develop docker image,"
printf "please make install or pip install build/python/dist/*.whl.\n"
fi
......@@ -29,8 +29,8 @@ if(WITH_MKLML)
endif()
if(WITH_MKLDNN)
list(APPEND MKL_SHARED_LIBS "${MKLDNN_LIB}" "${MKLDNN_LIB}.0")
list(APPEND MKL_DEPENDS mkldnn)
list(APPEND MKL_SHARED_LIBS "${MKLDNN_SHARED_LIB}")
list(APPEND MKL_DEPENDS mkldnn mkldnn_shared_lib)
endif()
if(WITH_GPU)
......
......@@ -58,12 +58,12 @@ def is_compatible_with(x, Type):
class HookAttribute(object):
"""
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
......@@ -71,10 +71,10 @@ class HookAttribute(object):
The pruning details can be found at https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
:param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None
"""
def __init__(self, type, sparsity_ratio=None):
......@@ -130,10 +130,12 @@ class ParameterAttribute(object):
:param sparse_update: Enable sparse update for this parameter. It will
enable both local and remote sparse update.
:type sparse_update: bool
:param update_hooks: A HookAttribute object.
:type update_hooks: HookAttribute
:param initializer: If not None, it should be a callable object which accepts
a parameter name and returns numpy array for the initial
value of the parameter
:param initializer: callable object
:type initializer: callable object
"""
def __init__(self,
......
......@@ -7,7 +7,7 @@ __all__ = ['append_backward']
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
"""
Traverse all ops in op_descs[begin_idx : end_idx],
Traverse all ops in op_descs[begin_idx : end_idx],
if any op has inputs/outputs named "old_name", rename it as 'new_name'
"""
if begin_idx is None:
......@@ -162,7 +162,7 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
if core.grad_var_suffix() in arg and arg in no_grad_set:
to_insert.append((_create_op_desc_("fill_zeros_like", {
"X": [_strip_grad_suffix_(arg)]
}, {"Y": [arg]}, {}), idx))
}, {"Out": [arg]}, {}), idx))
map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
......@@ -182,7 +182,7 @@ def _append_backward_ops_(target,
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
no_grad_dict(dict):
key(int) block index
val(set) a set of varibale names. These varibales have no gradient
grad_to_var(dict)(output argument):
......@@ -205,6 +205,7 @@ def _append_backward_ops_(target,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
......@@ -275,8 +276,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
loss(Variable): The variable generated by cost function.
parameter_list(list): Parameters that need to be updated by optimizer.
If None, it means all parameters need to be updated.
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
contains all variables with `step_gradient=True` from all blocks.
Return:
......
......@@ -17,10 +17,6 @@ TEMP_VAR_NAME = core.kTempVarName()
GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
USE_CPU = core.kUseCPU()
USE_CUDNN = core.kUseMKLDNN()
USE_MKLDNN = core.kUseMKLDNN()
def grad_var_name(var_name):
"""
......@@ -452,7 +448,7 @@ class Operator(object):
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv'
'recv', 'parallel_do'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
......@@ -212,6 +212,11 @@ def save_inference_model(dirname,
"fetch_var_names": fetch_var_names
}, f, -1)
# Save only programDesc of inference_program in binary format
# in another file: __model__.dat
with open(model_file_name + ".dat", "wb") as fp:
fp.write(inference_program.desc.serialize_to_string())
save_params(executor, dirname, main_program)
......
......@@ -6,12 +6,13 @@ import contextlib
from ..registry import autodoc
__all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table',
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
'increment', 'array_write', 'create_array', 'less_than', 'array_read',
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
'StaticRNN', 'reorder_lod_tensor_by_rank'
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank',
'ParallelDo'
]
......@@ -132,29 +133,129 @@ class BlockGuard(object):
return True
class StaticRNNGuard(BlockGuard):
class ParallelDo(object):
"""
StaticRNNGuard class.
ParallelDo class.
StaticRNNGuard class is used to create a StaticRNN block in a program.
ParallelDo class is used to create a parallel_do block in a program: ops built inside
its do() block run on every place in `places`, with inputs split across places and
per-place outputs merged back (see the usage sketch after complete_op below).
"""
def __init__(self, places, name=None):
self.helper = LayerHelper("parallel_do", name=name)
self.inputs = []
self.places = places
self.outputs = []
self.status = StaticRNN.BEFORE_RNN_BLOCK
def do(self):
return BlockGuardWithCompletion(self)
def parent_block(self):
prog = self.helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def __call__(self, *args, **kwargs):
if self.status != StaticRNN.AFTER_RNN_BLOCK:
raise ValueError("RNN output can only be retrieved after rnn block")
if len(self.outputs) == 0:
raise ValueError("RNN has no output")
elif len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def read_input(self, var):
self.inputs.append(var)
return var
def write_output(self, var):
self.outputs.append(var)
def get_parameters(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
local_inputs = set()
for op in current_block.ops:
for oname in op.output_names:
for out_var_name in op.output(oname):
local_inputs.add(out_var_name)
for var in self.inputs:
local_inputs.add(var.name)
params = list()
for op in current_block.ops:
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in local_inputs:
params.append(in_var_name)
return [parent_block.var(name) for name in params]
def complete_op(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
self.outputs = [
parent_block.create_var(
name=o.name,
shape=o.shape,
dtype=o.dtype,
lod_level=o.lod_level,
persistable=o.persistable,
stop_gradient=o.stop_gradient) for o in self.outputs
]
inputs = [parent_block.var(i.name) for i in self.inputs]
outputs = [parent_block.var(o.name) for o in self.outputs]
parent_block.append_op(
type='parallel_do',
inputs={
'inputs': inputs,
'parameters': self.get_parameters(),
'places': self.places
},
outputs={'outputs': outputs,
'parallel_scopes': [step_scope]},
attrs={'sub_block': current_block})
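For orientation, a minimal usage sketch of the `ParallelDo` layer defined above (it mirrors the `ParallelOpTest` added at the end of this change; the shapes and the fc layer are illustrative only):
```
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers

x = layers.data(shape=[-1, 30, 40], dtype='float32', name='x',
                append_batch_size=False, stop_gradient=False)
places = fluid.default_main_program().global_block().create_var()

pd = layers.ParallelDo(places=places)
with pd.do():                  # enters the sub_block via BlockGuardWithCompletion
    data = pd.read_input(x)    # x is split across places into the sub-scopes
    hidden = layers.fc(input=data, size=7)
    pd.write_output(hidden)    # per-place outputs are merged back at run time
loss = layers.mean(x=pd())     # pd() returns the merged output(s)
```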
class BlockGuardWithCompletion(BlockGuard):
"""
BlockGuardWithCompletion class.
BlockGuardWithCompletion class is used to create an op with a block in a program.
"""
def __init__(self, rnn):
if not isinstance(rnn, StaticRNN):
raise TypeError("StaticRNNGuard takes a StaticRNN")
super(StaticRNNGuard, self).__init__(rnn.helper.main_program)
if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)):
raise TypeError(
"BlockGuardWithCompletion takes a StaticRNN or ParallelDo")
super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program)
self.rnn = rnn
def __enter__(self):
self.rnn.status = StaticRNN.IN_RNN_BLOCK
return super(StaticRNNGuard, self).__enter__()
return super(BlockGuardWithCompletion, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
self.rnn.complete_op()
return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val,
exc_tb)
class StaticRNNMemoryLink(object):
......@@ -200,7 +301,7 @@ class StaticRNN(object):
self.seq_len = None
def step(self):
return StaticRNNGuard(self)
return BlockGuardWithCompletion(self)
def _assert_in_rnn_block_(self, method):
if self.status != StaticRNN.IN_RNN_BLOCK:
......@@ -316,7 +417,7 @@ class StaticRNN(object):
else:
return self.outputs
def complete_rnn_op(self):
def complete_op(self):
main_program = self.helper.main_program
rnn_block = main_program.current_block()
parent_block = self.parent_block()
......@@ -897,7 +998,7 @@ class ConditionalBlock(object):
out_list = [
parent_block.var(var_name) for var_name in parent_block.vars
if var_name not in intermediate
if var_name in intermediate
]
step_scope = parent_block.create_var(
......
......@@ -64,14 +64,14 @@ def fc(input,
is flattened: the first `num_flatten_dims`
dimensions will be flattened to form the first
dimension of the final matrix (height of the
matrix), and the rest `rank(X) - num_col_dims`
matrix), and the rest `rank(X) - num_flatten_dims`
dimensions are flattened to form the second
dimension of the final matrix (width of the matrix).
For example, suppose `X` is a 6-dimensional tensor
with a shape [2, 3, 4, 5, 6], and
`x_num_col_dims` = 3. Then, the flattened matrix
`num_flatten_dims` = 3. Then, the flattened matrix
will have a shape [2 x 3 x 4, 5 x 6] = [24, 30].
By default, `x_num_col_dims` is set to 1.
By default, `num_flatten_dims` is set to 1.
param_attr(ParamAttr|list): The parameter attribute for learnable
parameters/weights of the fully connected
layer.
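The flattening rule described above can be checked with a few lines of numpy (shapes taken from the docstring's own example; this is an illustration only, not part of the layer):
```
import numpy as np

x = np.zeros([2, 3, 4, 5, 6], dtype='float32')
num_flatten_dims = 3
height = int(np.prod(x.shape[:num_flatten_dims]))   # 2 * 3 * 4 = 24
width = int(np.prod(x.shape[num_flatten_dims:]))    # 5 * 6 = 30
mat = x.reshape(height, width)                      # shape (24, 30), as in the docstring
```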
......@@ -243,18 +243,21 @@ def gru_unit(input,
r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)
ch_t & = actNode(xc_t + W_c dot(r_t, h_{t-1}) + b_c)
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot((1-u_t), ch_{t-1}) + dot(u_t, h_t)
h_t & = dot((1-u_t), m_t) + dot(u_t, h_{t-1})
The inputs of gru unit includes :math:`z_t`, :math:`h_{t-1}`. In terms
of the equation above, the :math:`z_t` is split into 3 parts -
:math:`xu_t`, :math:`xr_t` and :math:`xc_t`. This means that in order to
:math:`xu_t`, :math:`xr_t` and :math:`xm_t`. This means that in order to
implement a full GRU unit operator for an input, a fully
connected layer has to be applied, such that :math:`z_t = W_{fc}x_t`.
This layer has three outputs :math:`h_t`, :math:`dot(r_t, h_{t - 1})`
and concatenation of :math:`u_t`, :math:`r_t` and :math:`ch_t`.
The terms :math:`u_t` and :math:`r_t` represent the update and reset gates
of the GRU cell. Unlike LSTM, GRU has one fewer gate. However, there is
an intermediate candidate hidden output, which is denoted by :math:`m_t`.
This layer has three outputs :math:`h_t`, :math:`dot(r_t, h_{t-1})`
and concatenation of :math:`u_t`, :math:`r_t` and :math:`m_t`.
Args:
input (Variable): The fc transformed input value of current step.
......
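A tiny element-wise numpy check of the corrected update equation above; sigmoid and tanh are only assumed here as the gate and candidate activations, and the weight/bias terms are folded into the pre-activations for brevity:
```
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

h_prev = np.array([0.1, -0.2, 0.3])          # h_{t-1}
xu = np.array([0.5, 0.1, -0.4])              # xu_t (W_u h_{t-1} + b_u folded in)
xr = np.array([0.2, 0.0, 0.3])               # xr_t (W_r h_{t-1} + b_r folded in)
xm = np.array([-0.1, 0.4, 0.2])              # xm_t (b_m folded in)

u = sigmoid(xu)                              # update gate u_t
r = sigmoid(xr)                              # reset gate r_t
m = np.tanh(xm + r * h_prev)                 # candidate m_t (W_c taken as identity)
h = (1.0 - u) * m + u * h_prev               # h_t = (1-u_t)*m_t + u_t*h_{t-1}
```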
......@@ -4,6 +4,7 @@ import numpy as np
import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05
import paddle.v2.fluid as fluid
import time
word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict)
......@@ -160,7 +161,8 @@ def main():
paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
#place = fluid.CPUPlace()
place = fluid.CUDAPlace(0)
feeder = fluid.DataFeeder(
feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
......@@ -174,6 +176,7 @@ def main():
embedding_param.set(
load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place)
start_time = time.time()
batch_id = 0
for pass_id in xrange(PASS_NUM):
chunk_evaluator.reset(exe)
......@@ -191,6 +194,9 @@ def main():
f1_score) + " pass_precision:" + str(
pass_precision) + " pass_recall:" + str(pass_recall)
+ " pass_f1_score:" + str(pass_f1_score))
if batch_id != 0:
print("second per batch: " + str((time.time() - start_time)
/ batch_id))
# exit early for CI
exit(0)
......
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest
......@@ -47,6 +49,7 @@ def conv2d_forward_naive(input, filter, group, conv_param):
class TestConv2dOp(OpTest):
def setUp(self):
core.use_cuda()
self.init_op_type()
self.init_group()
self.init_dilation()
......@@ -167,26 +170,31 @@ class TestWithDilation(TestConv2dOp):
#----------------Conv2dCudnn----------------
class TestCudnn(TestConv2dOp):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
class TestCudnnWithPad(TestWithPad):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
class TestCudnnWithStride(TestWithStride):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
class TestCudnnWithGroup(TestWithGroup):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
class TestCudnnWith1x1(TestWith1x1):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
......
import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.backward import append_backward
import numpy as np
import paddle.v2.fluid.core as core
class ParallelOpTest(unittest.TestCase):
def setUp(self):
x = layers.data(
shape=[-1, 30, 40],
dtype='float32',
name='x',
append_batch_size=False,
stop_gradient=False)
places = fluid.default_main_program().global_block().create_var()
pd = layers.ParallelDo(places=places)
with pd.do():
data = pd.read_input(x)
hidden = layers.fc(input=data, size=7)
pd.write_output(hidden)
data = pd()
loss = layers.mean(x=data)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.run(fluid.default_main_program(),
feed={
x.name: np.random.uniform(0.1, 0.6,
(20, 30, 40)).astype("float32")
})
def test_forward(self):
pass
if __name__ == '__main__':
unittest.main()