提交 8a136d14 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_add_doc

......@@ -61,6 +61,7 @@ option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
# CMAKE_BUILD_TYPE
......@@ -193,7 +194,10 @@ set(EXTERNAL_LIBS
if(WITH_GPU)
include(cuda)
include(tensorrt)
endif(WITH_GPU)
include(external/anakin)
else()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
endif()
if(WITH_AMD_GPU)
find_package(HIP)
......
if (NOT WITH_ANAKIN)
return()
endif()
set(ANAKIN_INSTALL_DIR "${THIRD_PARTY_PATH}/install/anakin" CACHE PATH
"Anakin install path." FORCE)
set(ANAKIN_INCLUDE "${ANAKIN_INSTALL_DIR}" CACHE STRING "root of Anakin header files")
set(ANAKIN_LIBRARY "${ANAKIN_INSTALL_DIR}" CACHE STRING "path of Anakin library")
set(ANAKIN_COMPILE_EXTRA_FLAGS -Wno-error=unused-variable -Wno-error=format-extra-args -Wno-error=comment -Wno-error=format -Wno-error=switch -Wno-error=return-type -Wno-error=non-virtual-dtor -Wno-reorder -Wno-error=cpp)
set(ANAKIN_LIBRARY_URL "https://github.com/pangge/Anakin/releases/download/3.0/anakin_release_simple.tar.gz")
# A helper function used in Anakin, currently, to use it, one need to recursively include
# nearly all the header files.
function(fetch_include_recursively root_dir)
if (IS_DIRECTORY ${root_dir})
include_directories(${root_dir})
endif()
file(GLOB ALL_SUB RELATIVE ${root_dir} ${root_dir}/*)
foreach(sub ${ALL_SUB})
if (IS_DIRECTORY ${root_dir}/${sub})
fetch_include_recursively(${root_dir}/${sub})
endif()
endforeach()
endfunction()
# download library
message(STATUS "Download Anakin library from ${ANAKIN_LIBRARY_URL}")
execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_INSTALL_DIR}")
execute_process(COMMAND bash -c "rm -rf ${ANAKIN_INSTALL_DIR}/*")
execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; wget -q ${ANAKIN_LIBRARY_URL}")
execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_INSTALL_DIR}")
execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; tar xzf anakin_release_simple.tar.gz")
if (WITH_ANAKIN)
message(STATUS "Anakin for inference is enabled")
message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}")
fetch_include_recursively(${ANAKIN_INCLUDE})
link_directories(${ANAKIN_LIBRARY})
endif()
......@@ -29,6 +29,8 @@ IF(NOT ${CBLAS_FOUND})
"${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}"
CACHE FILEPATH "openblas library." FORCE)
ADD_DEFINITIONS(-DPADDLE_USE_OPENBLAS)
SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable")
SET(OPENBLAS_COMMIT "v0.2.20")
......
#!/bin/bash
python gen_doc.py layers --submodules control_flow device io nn ops tensor detection > layers.rst
python gen_doc.py layers --submodules control_flow device io nn ops tensor detection learning_rate_scheduler > layers.rst
for module in data_feeder clip metrics executor initializer io nets optimizer param_attr profiler regularizer
do
......
......@@ -1041,3 +1041,42 @@ box_coder
.. autofunction:: paddle.fluid.layers.box_coder
:noindex:
learning_rate_scheduler
=======================
exponential_decay
-----------------
.. autofunction:: paddle.fluid.layers.exponential_decay
:noindex:
natural_exp_decay
-----------------
.. autofunction:: paddle.fluid.layers.natural_exp_decay
:noindex:
inverse_time_decay
------------------
.. autofunction:: paddle.fluid.layers.inverse_time_decay
:noindex:
polynomial_decay
----------------
.. autofunction:: paddle.fluid.layers.polynomial_decay
:noindex:
piecewise_decay
---------------
.. autofunction:: paddle.fluid.layers.piecewise_decay
:noindex:
noam_decay
----------
.. autofunction:: paddle.fluid.layers.noam_decay
:noindex:
......@@ -171,7 +171,7 @@ Pytorch chooses immediate evaluation. It avoids ever materializing a "forward gr
## What can fluid learn from them?
TBD
Please refer to `paddle/contrib/dynamic/`.
# Appendix
......
......@@ -14,3 +14,4 @@
#
add_subdirectory(inference)
add_subdirectory(tape)
......@@ -17,48 +17,9 @@ if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE)
set(ANAKIN_INCLUDE "" CACHE STRING "root of Anakin header files")
set(ANAKIN_LIBRARY "" CACHE STRING "path of Anakin library")
set(inference_deps paddle_inference_api paddle_fluid_api)
# if anakin is set enable anakin api implementation
if(ANAKIN_INCLUDE AND ANAKIN_LIBRARY)
set(ANAKIN_FOUND ON)
else()
set(ANAKIN_FOUND OFF)
endif()
function(fetch_include_recursively root_dir)
if (IS_DIRECTORY ${root_dir})
include_directories(${root_dir})
endif()
file(GLOB ALL_SUB RELATIVE ${root_dir} ${root_dir}/*)
foreach(sub ${ALL_SUB})
if (IS_DIRECTORY ${root_dir}/${sub})
fetch_include_recursively(${root_dir}/${sub})
endif()
endforeach()
endfunction()
if (ANAKIN_FOUND)
# Anakin's code style doesn't follow google c style.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=unused-variable -Wno-error=format-extra-args -Wno-error=comment -Wno-error=format -Wno-error=switch -Wno-error=return-type -Wno-error=non-virtual-dtor -Wno-reorder -Wno-error=cpp")
message(STATUS "Anakin for inference is enabled")
message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}")
fetch_include_recursively(${ANAKIN_INCLUDE})
link_directories(${ANAKIN_LIBRARY})
nv_library(inference_anakin_api SHARED SRCS paddle_inference_api.cc paddle_inference_api_anakin_engine.cc)
target_link_libraries(inference_anakin_api anakin anakin_saber_common)
list(APPEND inference_deps inference_anakin_api)
endif()
function(inference_api_test TARGET_NAME)
if (WITH_TESTING)
set(options "")
......@@ -79,7 +40,7 @@ function(inference_api_test TARGET_NAME)
endfunction(inference_api_test)
cc_library(paddle_inference_api
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
cc_test(test_paddle_inference_api
......@@ -89,9 +50,17 @@ cc_test(test_paddle_inference_api
inference_api_test(test_paddle_inference_api_impl
ARGS test_word2vec test_image_classification)
if (ANAKIN_FOUND)
if (WITH_ANAKIN)
# Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's,
# so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to
# compile the libinference_anakin_api.a and compile with anakin.so.
nv_library(inference_anakin_api SHARED SRCS paddle_inference_api.cc paddle_inference_api_anakin_engine.cc)
target_compile_options(inference_anakin_api BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
target_link_libraries(inference_anakin_api anakin anakin_saber_common)
cc_test(inference_anakin_test SRCS paddle_inference_api_anakin_engine_tester.cc
DEPS ${inference_deps})
ARGS --model=${ANAKIN_INSTALL_DIR}/mobilenet_v2.anakin.bin
DEPS inference_anakin_api)
target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
endif()
if(WITH_TESTING)
......
......@@ -12,9 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include "paddle/contrib/inference/paddle_inference_api_anakin_engine.h"
#include <cuda.h>
namespace paddle {
......
......@@ -19,10 +19,9 @@ limitations under the License. */
#pragma once
// NOTE This header file do not have namespace.
//#include <test/framework/net/paddle_api.h>
#include "paddle/contrib/inference/paddle_inference_api.h"
// from anakin
#include "framework/core/net/net.h"
#include "saber/saber_types.h"
......
......@@ -12,17 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "gflags/gflags.h"
#include "paddle/contrib/inference/paddle_inference_api.h"
DEFINE_string(model, "", "Directory of the inference model.");
namespace paddle {
AnakinConfig GetConfig() {
AnakinConfig config;
config.model_file = "./mobilenet_v2.anakin.bin";
config.model_file = FLAGS_model;
config.device = 0;
config.max_batch_size = 1;
return config;
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE)
cc_library(tape_variable SRCS variable.cc DEPS ${FLUID_CORE_MODULES})
cc_library(tape SRCS tape.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} tape_variable)
cc_test(test_tape
SRCS test_tape.cc
DEPS tape tape_variable)
# Dynamic Graph on Fluid
PaddlePaddle Fluid is targeting the autodiff without tape, which, however, is very
challenging and we are still way from there. DyNet and PyTorch provide a good design
idea, the *tape*, that significantly eases the challenge. Also, DyNet provides
a C++ API that is as convenient as Python but with higher efficiency and could
conveniently integrate with industrial/production systems. This package, `tape`,
combines the good of
1. tape from PyTorch and DyNet
2. C++ API and core from DyNet
3. rich set of operators from PaddlePaddle
## Overview
We can implement Dynet-like Tape(See this [survey](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/survey/dynamic_graph.md))
by wrapping Paddle Fluid's `Operator` and `Variable`.
The user API is straight forward since
1. it is imperative. And it uses host language's control flow logic.
1. it avoids extra concepts such as `Scope` and `Executor`.
All of these benefits come at the cost of just adding one line `reset_global_tape`
at every iteration.
## Code Structure
In short, the `Tape` contains a vector of `OpHandle`s. And an `OpHandle` contains its
`type`, the pointers to the `Variable`s, and necessary attributes.
```c++
class Variable {
public:
VriableHandle Grad(); // returns its gradient variable
private:
framework::VarDesc desc_; // compile time infershape, necessary for lazy execution
framework::Variable var_; // run time variable, holds data memory
};
using VariableHandle = shared_ptr<Variable>;
struct OpHandle {
string type_;
map<string, vector<VariableHandle>> inputs_;
map<string, vector<VariableHandle>> outputs_;
AttributeMap attrs_;
};
class Tape {
public:
void AddOp(OpHandle); // add op
void Forward(); // execute the tape_
void Backward(); // execute the backward of the tape_
private:
vector<OpHandle> tape_;
};
```
We uses `Function` to indicate layers. It takes care of parameter
initialization and `AddOp` to the Tape when it is called.
```c++
class Linear {
public:
Linear(int in_dim, int out_dim, const std::string &act)
: w_(new Variable("LinearWeight")),
b_(new Variable("LinearBias")),
act_(act) {
Tape init_tape;
std::string initializer = "fill_constant";
framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{in_dim, out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {w_}}}, attrs);
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {b_}}}, attrs);
init_tape.Forward();
}
VariableHandle operator()(VariableHandle input) {
VariableHandle pre_bias(new Variable("linear"));
get_global_tape().AddOp("mul",
{{"X", {input}}, {"Y", {w_}}},
{{"Out", {pre_bias}}},
{{"x_num_col_dims", 1}, {"y_num_col_dims", 1}});
VariableHandle pre_act(new Variable("linear"));
get_global_tape().AddOp("elementwise_add",
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
{{"axis", 1}});
VariableHandle post_act(new Variable("linear"));
get_global_tape().AddOp(act_,
{{"X", {pre_act}}},
{{"Out", {post_act}}},
{});
return post_act;
}
std::vector<VariableHandle> Params() { return {w_, b_}; }
private:
VariableHandle w_;
VariableHandle b_;
std::string act_;
};
```
## User API
```c++
// Model function
paddle::tape::Linear linear1(3, 3, "relu"); // init weight and bias
paddle::tape::Linear linear2(3, 3, "relu"); // init weight and bias
paddle::tape::Mean mean;
// Optimizer
paddle::tape::SGD sgd(0.001);
// Data Feeder
paddle::tape::Fill data_feeder(...);
VariableHandle input(new paddle::tape::Variable("input"));
VariableHandle label(new paddle::tape::Variable("label"));
for (int i = 0; i < 2; ++i) {
reset_global_tape();
data_feeder(input, label);
auto loss = softmax(linear2(linear1(input)), label); // compile time InferShape & InferVarType
LOG(INFO) << loss.value(); // Run forward up to loss
// Run backward, store gradient of w at w->Grad()
get_global_tape.Backward(loss);
// Update w
sgd(linear1.Params());
sgd(linear2.Params());
}
```
<details>
<summary></summary>
digraph G {
subgraph cluster_0 {
node [shape=record,style=filled];
style=filled;
color=lightgrey;
linear1 [label="{type: mul | {input | {<before_mul1>X: before_mul1 |<weight1> Y: weight1}} | {output |<before_bias1> Out: before_bias1}}"];
elementwise_add1 [label="{type: elementwise_add | {input | {<before_bias1>X: before_bias1 |<bias1> Y: bias1}} | {output |<before_act1> Out: before_act1}}"];
relu1 [label="{type: relu | {input | {<before_act1>X: before_act1 }} | {output |<after_act1> Out: after_act1}}"];
linear1 -> elementwise_add1->relu1;
label = "forward tape";
}
linear1:before_mul1->before_mul1
linear1:weight1->weight1
linear1:before_bias1->before_bias1
elementwise_add1:bias1->bias1
elementwise_add1:before_bias1->before_bias1
elementwise_add1:before_act1->before_act1
relu1:before_act1->before_act1
relu1:after_act1->after_act1
subgraph cluster_1 {
node [shape=record,style=filled];
style=filled;
color=lightgrey;
linear1_grad [label="{type: mul_grad | {input | {<before_mul1>X: before_mul1 |<weight1> Y: weight1|<before_bias1_grad> Out_grad: before_bias1_grad}} | {output |{<before_mul1_grad>X_grad: before_mul1_grad |<weight1_grad> Y_grad: weight1_grad}}}"];
elementwise_add1_grad [label="{type: elementwise_add_grad | {input | <before_act1_grad> Out_grad: before_act1_grad} | {output |{<before_bias1_grad>X_grad: before_bias1_grad |<bias1_grad> Y_grad: bias1_grad}}}"];
relu1_grad [label="{type: relu_grad | {input |<after_act1_grad> Out_grad: after_act1_grad} | {ouput | {<before_act1_grad>X_grad: before_act1_grad }}}"];
linear1_grad -> elementwise_add1_grad ->relu1_grad [dir=back];
label = "backward tape";
}
relu1_grad:after_act1_grad->after_act1_grad
relu1_grad:before_act1_grad->before_act1_grad
elementwise_add1_grad:before_act1_grad->before_act1_grad
elementwise_add1_grad:before_bias1_grad->before_bias1_grad
elementwise_add1_grad:bias1_grad->bias1_grad
linear1_grad:before_mul1->before_mul1
linear1_grad:weight1->weight1
linear1_grad:before_bias1_grad->before_bias1_grad
linear1_grad:before_mul1_grad->before_mul1_grad
linear1_grad:weight1_grad->weight1_grad
subgraph cluster_2 {
node [shape=record];
label = "Linear1";
weight1
bias1
}
weight1 -> weight1_grad [ label="Grad()", style="dashed" ];
bias1 -> bias1_grad [ label="Grad()", style="dashed"];
}
</details>
![Image](https://github.com/tonyyang-svail/Paddle/blob/cpp_tap/paddle/contrib/tape/computation_graph.png)
## Code Reuse
We want to stay close to Paddle Fluid as much as possible.
### Reuse All Operators
As all Ops are registered at `OpInfoMap`, the effort of adding a new `Function`
is about 10 lines of code, similar to expose an operator to Python.
### Reuse Compile Time InferShape and InferVarType
Note that all the symbolic information is stored at `tape::Varaible::desc_`, instead
of `ProgramDesc.block.vars`, we create a temporary `BlockDesc` to do `InferShape` and
`InferVarType` every time we `AddOp` to the tape.
### Reuse Operator::Run
We use smart pointer, instead of `Scope`, to manage memory. So we create a temporary
`Scope` for every `Operator::Run()`.
## Possible Feature
### Release Memory on Backward
We can release memory aggressively. During backward, we can delete the OpHandle once
we have finished its backward. Since all the variable is managed by smart pointer, the
memory is automatically released when its `ref_count` goes to 0.
### Kernel Fusion
As a symbolic representation of the Tape is constructed first before the actual
execution, it would be possible to perform graph optimization. One use case is kernel
fusion.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/contrib/tape/tape.h"
#include "paddle/contrib/tape/variable.h"
#include "paddle/fluid/framework/type_defs.h"
namespace paddle {
namespace tape {
class Function {};
class Fill {
public:
Fill(const std::string &initializer, const framework::AttributeMap &attrs)
: initializer_(initializer), attrs_(attrs) {}
void operator()(VariableHandle var) {
get_global_tape().AddOp(initializer_, {}, {{"Out", {var}}}, attrs_);
}
private:
const std::string initializer_;
const framework::AttributeMap attrs_;
};
class Mean {
public:
VariableHandle operator()(VariableHandle var) {
VariableHandle out(new Variable("mean"));
get_global_tape().AddOp("mean", {{"X", {var}}}, {{"Out", {out}}}, {});
return out;
}
};
class Linear {
public:
Linear(int in_dim, int out_dim, const std::string &act)
: w_(new Variable("LinearWeight")),
b_(new Variable("LinearBias")),
act_(act) {
Tape init_tape;
std::string initializer = "fill_constant";
framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{in_dim, out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {w_}}}, attrs);
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{out_dim};
attrs["value"] = 1.0f;
init_tape.AddOp(initializer, {}, {{"Out", {b_}}}, attrs);
init_tape.Forward();
}
VariableHandle operator()(VariableHandle input) {
VariableHandle pre_bias(new Variable("linear"));
get_global_tape().AddOp("mul",
{{"X", {input}}, {"Y", {w_}}},
{{"Out", {pre_bias}}},
{{"x_num_col_dims", 1}, {"y_num_col_dims", 1}});
VariableHandle pre_act(new Variable("linear"));
get_global_tape().AddOp("elementwise_add",
{{"X", {pre_bias}}, {"Y", {b_}}},
{{"Out", {pre_act}}},
{{"axis", 1}});
VariableHandle post_act(new Variable("linear"));
get_global_tape().AddOp(
act_, {{"X", {pre_act}}}, {{"Out", {post_act}}}, {});
return post_act;
}
std::vector<VariableHandle> Params() { return {w_, b_}; }
private:
VariableHandle w_;
VariableHandle b_;
std::string act_;
};
class SGD {
public:
SGD(float learning_rate) : learning_rate_(new Variable("sgd")) {
Tape init_tape;
std::string initializer = "fill_constant";
framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{1};
attrs["value"] = learning_rate;
init_tape.AddOp(initializer, {}, {{"Out", {learning_rate_}}}, attrs);
init_tape.Forward();
}
void operator()(VariableHandle input) {
PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(),
"optimization must happen after the backward");
Tape temp_tape;
temp_tape.AddOp("sgd",
{{"Param", {input}},
{"LearningRate", {learning_rate_}},
{"Grad", {input->Grad()}}},
{{"ParamOut", {input}}},
{});
temp_tape.Forward();
}
private:
VariableHandle learning_rate_;
};
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/contrib/tape/tape.h"
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace tape {
// borrowed from
// https://stackoverflow.com/questions/874134/find-if-string-ends-with-another-string-in-c
inline bool ends_with(std::string const &value, std::string const &ending) {
if (ending.size() > value.size()) return false;
return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
}
std::ostream &operator<<(std::ostream &os, const framework::VarDesc &var_desc) {
os << var_desc.Name();
os << "[" << var_desc.GetType() << "]";
os << "[" << var_desc.GetDataType() << "]";
os << "{";
for (auto &i : var_desc.GetShape()) {
os << i << ",";
}
os << "}";
return os;
}
std::string to_string(const std::string &type,
const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars,
const framework::AttributeMap &attrs) {
std::stringstream ss;
ss << type << " ";
for (auto &param_name : in_vars) {
for (auto &var : param_name.second) {
ss << param_name.first << ":(" << var->Desc() << ") ";
}
}
for (auto &param_name : out_vars) {
for (auto &var : param_name.second) {
ss << param_name.first << ":(" << var->Desc() << ") ";
}
}
return ss.str();
}
framework::OpDesc CreateOpDesc(const std::string &type,
const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars,
const framework::AttributeMap &attrs) {
framework::VariableNameMap inputs;
for (auto &param_name : in_vars) {
for (auto &var : param_name.second) {
inputs[param_name.first].emplace_back(var->Name());
}
}
framework::VariableNameMap outputs;
for (auto &param_name : out_vars) {
for (auto &var : param_name.second) {
outputs[param_name.first].emplace_back(var->Name());
}
}
return framework::OpDesc(type, inputs, outputs, attrs);
}
void InferShapeAndVarType(const std::string &type,
const VariableHandleMap &in_vars,
VariableHandleMap *out_vars,
const framework::AttributeMap &attrs) {
framework::OpDesc op_desc = CreateOpDesc(type, in_vars, *out_vars, attrs);
// Create a temporary block for compile-time
framework::ProgramDesc program_desc;
framework::BlockDesc *block_desc = program_desc.MutableBlock(0);
PADDLE_ENFORCE(block_desc);
for (auto &param_name : in_vars) {
for (auto &var : param_name.second) {
*block_desc->Var(var->Name())->Proto() = *var->MutableDesc()->Proto();
}
}
for (auto &param_name : *out_vars) {
for (auto &var : param_name.second) {
*block_desc->Var(var->Name())->Proto() = *var->MutableDesc()->Proto();
}
}
LOG(INFO) << "- " << to_string(type, in_vars, *out_vars, attrs);
op_desc.InferShape(*block_desc);
op_desc.InferVarType(block_desc);
for (auto &param_name : *out_vars) {
for (auto &var : param_name.second) {
*var->MutableDesc()->Proto() = *block_desc->Var(var->Name())->Proto();
}
}
LOG(INFO) << "+ " << to_string(type, in_vars, *out_vars, attrs);
}
void Tape::AddOp(const std::string &type,
const VariableHandleMap &in_vars,
VariableHandleMap out_vars,
const framework::AttributeMap &attrs) {
InferShapeAndVarType(type, in_vars, &out_vars, attrs);
tape_.emplace_back(type, in_vars, out_vars, attrs);
}
// Temporary Scope for Operator::Run()
class ScopeWrapper : public framework::Scope {
public:
ScopeWrapper(const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars) {
for (auto &v : in_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
}
}
}
for (auto &v : out_vars) {
for (auto &vv : v.second) {
if (!vars_.count(vv->Name())) {
vars_[vv->Name()].reset(vv->Var());
}
}
}
}
~ScopeWrapper() {
for (auto &pair : vars_) {
pair.second.release();
}
}
};
void Tape::Forward() {
LOG(INFO) << "Starting forward -------------------------";
PADDLE_ENFORCE(!has_been_backwarded_);
while (current_position_ < tape_.size()) {
OpHandle &op = tape_[current_position_];
// Create Output Tensor, this is only necessary for OpWithKernel
for (auto &param2var : op.outputs_) {
for (auto &var : param2var.second) {
var->InitializeVariable();
}
}
framework::OpDesc op_desc =
CreateOpDesc(op.type_, op.inputs_, op.outputs_, op.attrs_);
ScopeWrapper scope(op.inputs_, op.outputs_);
framework::OpRegistry::CreateOp(op_desc)->Run(scope, platform::CPUPlace());
current_position_++;
}
LOG(INFO) << "Finishing forward -------------------------";
}
void Tape::Backward(VariableHandle target) {
PADDLE_ENFORCE(!has_been_backwarded_);
Forward();
// TODO(tonyyang-svail): check output of last op is target
backward_tape_.reset(new Tape());
framework::AttributeMap attrs;
// FIXME(tonyyang-svail): Need to infer_data_type
attrs["dtype"] = framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{1};
attrs["value"] = 1.0f;
backward_tape_->AddOp(
"fill_constant", {}, {{"Out", {target->Grad()}}}, attrs);
for (auto it = tape_.rbegin(); it != tape_.rend(); ++it) {
framework::OpDesc op_desc =
CreateOpDesc(it->type_, it->inputs_, it->outputs_, it->attrs_);
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, {}, &grad_to_var, {});
for (auto &op_desc : grad_op_descs) {
std::unordered_map<std::string, VariableHandle> name2var;
for (auto &param2vars : it->inputs_) {
for (auto &a : param2vars.second) {
name2var[a->Name()] = a;
}
}
for (auto &param2vars : it->outputs_) {
for (auto &a : param2vars.second) {
name2var[a->Name()] = a;
}
}
VariableHandleMap in_vars;
VariableHandleMap out_vars;
std::map<const framework::VariableNameMap *, VariableHandleMap *>
loop_over{{&op_desc->Inputs(), &in_vars},
{&op_desc->Outputs(), &out_vars}};
for (auto &each : loop_over) {
auto &vmp = *each.first;
auto &vhm = *each.second;
for (auto &p2a : vmp) {
for (auto &argu : p2a.second) {
if (name2var.count(argu)) {
vhm[p2a.first].push_back(name2var[argu]);
} else {
PADDLE_ENFORCE(ends_with(argu, framework::kGradVarSuffix),
argu.c_str());
std::string name = argu.substr(
0, argu.size() - std::strlen(framework::kGradVarSuffix));
PADDLE_ENFORCE(name2var.count(name), name.c_str());
vhm[p2a.first].push_back(name2var[name]->Grad());
}
}
}
}
backward_tape_->AddOp(
op_desc->Type(), in_vars, out_vars, op_desc->GetAttrMap());
}
// TODO(tonyyang-svail): how to fill empty grad?
// TODO(tonyyang-svail): Sum var grad is necessary
}
backward_tape_->Forward();
has_been_backwarded_ = true;
}
Tape &get_global_tape() {
static Tape T;
return T;
}
void reset_global_tape() { get_global_tape() = Tape(); }
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/contrib/tape/variable.h"
namespace paddle {
namespace tape {
using VariableHandleMap = std::map<std::string, std::vector<VariableHandle>>;
struct OpHandle {
OpHandle(const std::string &type,
const VariableHandleMap &in_vars,
const VariableHandleMap &out_vars,
const framework::AttributeMap &attrs)
: type_(type), inputs_(in_vars), outputs_(out_vars), attrs_(attrs) {}
std::string type_;
VariableHandleMap inputs_;
VariableHandleMap outputs_;
framework::AttributeMap attrs_;
};
class Tape {
public:
void AddOp(const std::string &type,
const VariableHandleMap &in_vars,
VariableHandleMap out_vars,
const framework::AttributeMap &attrs);
void Forward();
void Backward(VariableHandle target);
bool HasBeenBackwarded() { return has_been_backwarded_; }
private:
bool has_been_backwarded_ = false;
size_t current_position_ = 0;
std::vector<OpHandle> tape_;
std::shared_ptr<Tape> backward_tape_;
};
Tape &get_global_tape();
void reset_global_tape();
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/contrib/tape/function.h"
using namespace paddle::tape;
TEST(Tape, TestMLP) {
LOG(INFO) << "TestMLP";
Linear linear1(3, 3, "relu");
Linear linear2(3, 3, "relu");
Mean mean;
SGD sgd(0.001);
std::string initializer = "fill_constant";
paddle::framework::AttributeMap attrs;
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
attrs["shape"] = std::vector<int>{3, 3};
attrs["value"] = 1.0f;
Fill filler(initializer, attrs);
for (int i = 0; i < 2; ++i) {
reset_global_tape();
VariableHandle input(new Variable("input"));
filler(input);
auto loss = mean(linear2(linear1(input)));
get_global_tape().Backward(loss);
for (auto w : linear1.Params()) {
sgd(w);
}
for (auto w : linear2.Params()) {
sgd(w);
}
}
}
int main(int argc, char** argv) {
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/contrib/tape/variable.h"
namespace paddle {
namespace tape {
void Variable::InitializeVariable() {
LOG(INFO) << "Initialzing " << desc_.Name() << " as " << desc_.GetType();
framework::proto::VarType::Type var_type = desc_.GetType();
if (var_type == framework::proto::VarType::LOD_TENSOR) {
var_.GetMutable<framework::LoDTensor>();
} else if (var_type == framework::proto::VarType::SELECTED_ROWS) {
var_.GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW("Variable type %d is not in [LOD_TENSOR, SELECTED_ROWS]",
var_type);
}
}
}
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/operator.h" // framework::kGradVarSuffix
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace tape {
class Variable;
using VariableHandle = std::shared_ptr<Variable>;
/*
* Combination of
* framework::VarDesc desc_;
* framework::Variable var_;
*/
class Variable {
public:
Variable(const std::string pre_fix)
: desc_(pre_fix + std::to_string(count())) {}
Variable(const std::string pre_fix, bool is_grad)
: desc_(pre_fix + (is_grad ? framework::kGradVarSuffix
: std::to_string(count()))) {}
~Variable() { LOG(INFO) << "Deleting " << Name(); }
// Instantiate LoDTensor/SelectedRow
void InitializeVariable();
VariableHandle Grad() {
if (grad_.expired()) {
VariableHandle new_grad(new Variable(desc_.Name(), true));
grad_ = new_grad;
return new_grad;
} else {
return VariableHandle(grad_);
}
}
// Stochastic Gradient Descent with Momentum
// VariableHandle Momentum ();
// void init(const std::string& initializer,
// const framework::AttributeMap& attrs);
// void value() {};
const framework::VarDesc& Desc() const { return desc_; }
framework::VarDesc* MutableDesc() { return &desc_; }
// TODO(tonyyang-svail): No need to expose name
std::string Name() const { return desc_.Name(); }
framework::Variable* Var() { return &var_; }
private:
int count() {
static int counter = 0;
return counter++;
}
framework::VarDesc desc_;
framework::Variable var_;
std::weak_ptr<Variable> grad_;
};
}
}
......@@ -330,8 +330,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
// NOTE! Please do not delete this line, it's usefull because the debug
// string before and after op.run are different, after run the output
// will have right shape which is usefull for debug.
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
......@@ -69,6 +69,19 @@ static DDim GetDims(const Scope& scope, const std::string& name,
}
}
static int GetRowSize(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return -1;
}
if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().rows().size();
}
return -1;
}
static LoD GetLoD(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}});
......@@ -85,6 +98,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(10) << "- " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
......@@ -94,6 +108,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
#endif
}
RunImpl(scope, place);
VLOG(10) << "+ " << DebugStringEx(&scope);
}
bool OperatorBase::HasInputs(const std::string& name) const {
......@@ -153,6 +168,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (scope) {
int row_size = GetRowSize(*scope, input.second[i]);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
ss << "[" << GetDims(*scope, input.second[i], true) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")";
}
......@@ -173,6 +192,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (scope) {
int row_size = GetRowSize(*scope, output.second[i]);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
ss << "[" << GetDims(*scope, output.second[i], true) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")";
}
......
......@@ -145,9 +145,9 @@ void ParallelExecutor::BCastParamsToGPUs(
auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) {
#ifdef PADDLE_WITH_CUDA
std::vector<void *> buffers;
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i];
void *buffer;
......@@ -159,11 +159,21 @@ void ParallelExecutor::BCastParamsToGPUs(
t->Resize(dims);
buffer = t->mutable_data(place, main_tensor.type());
}
auto &nccl_ctx = member_->nccl_ctxs_->at(place);
platform::dynload::ncclBcast(buffer, numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
buffers.push_back(buffer);
}
member_->nccl_ctxs_->WaitAll();
PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
"variables' buffer size to bcast NOT equal to places");
{
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream());
}
member_->nccl_ctxs_->WaitAll();
}
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
......
......@@ -81,6 +81,9 @@ class Scope {
// Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const;
protected:
mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
private:
// Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {}
......@@ -93,8 +96,6 @@ class Scope {
// Caller doesn't own the returned Variable.
Variable* FindVarLocally(const std::string& name) const;
mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
// Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_;
Scope const* parent_{nullptr};
......
......@@ -20,16 +20,20 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/pybind/pybind.h"
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
DEFINE_bool(init_p2p, false, "Whether to init p2p.");
DEFINE_int32(math_num_threads, 1,
"Number of threads used to run math functions.");
namespace paddle {
namespace inference {
void Init(const std::vector<std::string> argv) {
framework::InitGflags(argv);
operators::math::SetNumThreads(FLAGS_math_num_threads);
// init devices
std::vector<int> devices;
std::string token;
......
......@@ -169,7 +169,8 @@ class RequestPrefetch final : public RequestBase {
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = scope->FindVar(out_var_name);
// out var must be created in local scope!
framework::Variable* outvar = scope->Var(out_var_name);
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
......
......@@ -20,13 +20,16 @@
#ifdef PADDLE_WITH_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_service.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#ifdef LAPACK_FOUND
#include <lapacke.h>
#endif
#endif
#ifndef LAPACK_FOUND
extern "C" {
......@@ -46,6 +49,18 @@ namespace paddle {
namespace operators {
namespace math {
static void SetNumThreads(int num_threads) {
#ifdef PADDLE_USE_OPENBLAS
int real_num_threads = num_threads > 1 ? num_threads : 1;
openblas_set_num_threads(real_num_threads);
#elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1;
mkl_set_num_threads(real_num_threads);
#else
PADDLE_ENFORCE(false, "To be implemented.");
#endif
}
/**
* Matrix Descriptor of a memory buffer.
*
......
......@@ -21,8 +21,10 @@ limitations under the License. */
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#ifdef LAPACK_FOUND
#include <lapacke.h>
#endif
#endif
#ifndef LAPACK_FOUND
extern "C" {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mean_iou_op.h"
namespace paddle {
namespace operators {
class MeanIoUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predictions"),
"Input (Predictions) of MeanIoU op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input (labels) of MeanIoU op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutMeanIou"),
"Output (OutMeanIou) of MeanIoU op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutWrong"),
"Output (OutWrong) of MeanIoU op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutCorrect"),
"Output (OutWrong) of MeanIoU op should not be null.");
int64_t num_classes =
static_cast<int64_t>(ctx->Attrs().Get<int>("num_classes"));
ctx->SetOutputDim("OutMeanIou", {1});
ctx->SetOutputDim("OutWrong", {num_classes});
ctx->SetOutputDim("OutCorrect", {num_classes});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()),
ctx.GetPlace());
}
};
class MeanIoUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Predictions",
"(Tensor), A Tensor of prediction results for semantic labels"
" with type int32 or int64. The rank should be greater than 1.");
AddInput(
"Labels",
"(Tensor), A Tensor of ground truth labels with type int32 or int64."
"Its shape should be the same as Input(Predictions).");
AddInput("InWrongs",
"(vector<Tensor>), A list of Tensor with shape "
"[num_classes]. They are used to collect wrong number among "
"batches. Empty list is also valid here.")
.AsDuplicable()
.AsDispensable();
AddInput(
"InCorrects",
"(vector<Tensor>), A list of Tensor with shape "
"[num_classes]. They are used to collect correct number among batches. "
"Empty list is also valid here.")
.AsDuplicable()
.AsDispensable();
AddInput("InMeanIou",
"(vector<Tensor>), A list of Tensor that Output(mean_iou) should "
"be added to. Empty list is also valid here.")
.AsDuplicable()
.AsDispensable();
AddOutput("OutMeanIou",
"(vector<Tensor>), A Tensor representing the"
" mean intersection-over-union with shape [1].");
AddOutput("OutWrong", "(Tensor), A Tensor with shape [num_classes]. ");
AddOutput("OutCorrect", "(Tensor), A Tensor with shape [num_classes]. ");
AddAttr<int>("num_classes", "(int), The possible number of labels.");
AddComment(R"DOC(
mean-IOU Operator.
Mean Intersection-Over-Union is a common evaluation metric for
semantic image segmentation, which first computes the IOU for each
semantic class and then computes the average over classes.
IOU is defined as follows:
IOU = true_positive / (true_positive + false_positive + false_negative).
It is based on pixel level area while "IOU Similarity Operator"
is based on area of rectangle.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mean_iou, ops::MeanIoUOp, ops::MeanIoUOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(mean_iou, ops::MeanIoUKernel<int>,
ops::MeanIoUKernel<int32_t>,
ops::MeanIoUKernel<int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/mean_iou_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void CountCUDAKernel(const int num_classes, const int count,
const T* predictions, const T* labels,
int* wrong, int* correct) {
extern __shared__ int blcok_cache[];
int* wrong_c = blcok_cache;
int* correct_c = blcok_cache + num_classes;
// init cache
for (int i = threadIdx.x; i < num_classes * 2; i += blockDim.x) {
blcok_cache[i] = 0;
}
__syncthreads();
T pred;
T label;
CUDA_1D_KERNEL_LOOP(i, count) {
pred = predictions[i];
label = labels[i];
if (pred == label) {
atomicAdd(correct_c + pred, 1);
} else {
atomicAdd(wrong_c + pred, 1);
atomicAdd(wrong_c + label, 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < num_classes; i += blockDim.x) {
atomicAdd(wrong + i, wrong_c[i]);
atomicAdd(correct + i, correct_c[i]);
}
}
__global__ void ComputeIoUCUDAKernel(const int num_classes, int* wrong,
int* correct, float* ious, float* iou) {
__shared__ int valid_count_c;
if (threadIdx.x == 0) {
valid_count_c = 0;
}
__syncthreads();
CUDA_1D_KERNEL_LOOP(i, num_classes) {
int wrong_n = wrong[i];
int correct_n = correct[i];
int denominator = wrong_n + correct_n;
if (denominator > 0) {
atomicAdd(&valid_count_c, 1);
ious[i] = static_cast<float>(correct_n) / denominator;
} else {
ious[i] = 0;
}
}
__syncthreads();
if (threadIdx.x == 0) {
float iou_sum = 0;
for (int i = 0; i < num_classes; ++i) {
iou_sum += ious[i];
}
iou[0] += iou_sum / valid_count_c;
}
}
template <typename T>
class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
// get input and output tensor
auto* predictions = ctx.Input<Tensor>("Predictions");
auto* labels = ctx.Input<Tensor>("Labels");
auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou");
auto* out_wrong = ctx.Output<Tensor>("OutWrong");
auto* out_correct = ctx.Output<Tensor>("OutCorrect");
int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));
// Get data ptr
const T* predictions_data = predictions->data<T>();
const T* labels_data = labels->data<T>();
int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
float* out_mean_iou_data =
out_mean_iou->mutable_data<float>(ctx.GetPlace());
// Get Eigen tensor
auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
// Temporary tensor
Tensor ious;
float* ious_data = ious.mutable_data<float>(
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
auto ious_t = EigenTensor<float, 1>::From(ious);
// Init out_wrong, out_correct and out_mean_iou
out_wrong_t.device(place) = out_wrong_t.constant(0);
out_correct_t.device(place) = out_correct_t.constant(0);
out_mean_iou_t.device(place) = out_mean_iou_t.constant(0.0f);
// collect pre wrong, correct and mean_iou
auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou");
for (int i = 0; i < in_mean_ious.size(); ++i) {
out_mean_iou_t.device(place) +=
EigenTensor<float, 1>::From(*in_mean_ious[i]);
}
auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs");
for (int i = 0; i < in_wrongs.size(); ++i) {
out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]);
}
auto in_corrects = ctx.MultiInput<Tensor>("InCorrects");
for (int i = 0; i < in_corrects.size(); ++i) {
out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]);
}
// compute
auto stream = ctx.cuda_device_context().stream();
int block = PADDLE_CUDA_NUM_THREADS;
int grid = (predictions->numel() + block - 1) / block;
int cache_size = (num_classes * 2 + 1) * sizeof(int);
CountCUDAKernel<T><<<grid, block, cache_size, stream>>>(
num_classes, predictions->numel(), predictions_data, labels_data,
out_wrong_data, out_correct_data);
ctx.device_context().Wait();
ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes, out_wrong_data,
out_correct_data, ious_data,
out_mean_iou_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>,
ops::MeanIoUCUDAOpKernel<int64_t>,
ops::MeanIoUCUDAOpKernel<int32_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T>
class MeanIoUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
// get input and output tensor
auto* predictions = ctx.Input<Tensor>("Predictions");
auto* labels = ctx.Input<Tensor>("Labels");
auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou");
auto* out_wrong = ctx.Output<Tensor>("OutWrong");
auto* out_correct = ctx.Output<Tensor>("OutCorrect");
int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));
// get data ptr
const T* predictions_data = predictions->data<T>();
const T* labels_data = labels->data<T>();
float* out_mean_iou_data =
out_mean_iou->mutable_data<float>(ctx.GetPlace());
int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
// get eigen tensor
auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
// Tmp tensor
Tensor denominator;
Tensor valid_count;
Tensor iou_sum;
// get data ptr of tmp tensor
int* denominator_data = denominator.mutable_data<int>(
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
int* valid_count_data = valid_count.mutable_data<int>({1}, ctx.GetPlace());
float* iou_sum_data = iou_sum.mutable_data<float>({1}, ctx.GetPlace());
// get eigen tensor of tmp tensor
auto denominator_t = EigenTensor<int, 1>::From(denominator);
auto valid_count_t = EigenTensor<int, 1>::From(valid_count);
auto iou_sum_t = EigenTensor<float, 1>::From(iou_sum);
// init out_wrong, out_correct and out_mean_iou
out_wrong_t = out_wrong_t.constant(0);
out_correct_t = out_correct_t.constant(0);
out_mean_iou_t = out_mean_iou_t.constant(0);
// collect pre wrong, correct and mean_iou
auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou");
for (size_t i = 0; i < in_mean_ious.size(); ++i) {
out_mean_iou_t.device(place) +=
EigenTensor<float, 1>::From(*in_mean_ious[i]);
}
auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs");
for (size_t i = 0; i < in_wrongs.size(); ++i) {
out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]);
}
auto in_corrects = ctx.MultiInput<Tensor>("InCorrects");
for (size_t i = 0; i < in_corrects.size(); ++i) {
out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]);
}
// compute
for (int64_t i = 0; i < predictions->numel(); ++i) {
if (predictions_data[i] == labels_data[i]) {
out_correct_data[predictions_data[i]] += 1;
} else {
out_wrong_data[labels_data[i]] += 1;
out_wrong_data[predictions_data[i]] += 1;
}
}
denominator_t = out_wrong_t + out_correct_t;
valid_count_t =
(denominator_t > denominator_t.constant(0.0f)).cast<int>().sum();
for (int i = 0; i < num_classes; ++i) {
if (denominator_data[i] == 0) {
denominator_data[i] = 1;
}
}
iou_sum_t =
(out_correct_t.cast<float>() / denominator_t.cast<float>()).sum();
out_mean_iou_data[0] += (iou_sum_data[0] / valid_count_data[0]);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/merge_ids_op.h"
namespace paddle {
namespace operators {
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddInput(
"X",
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
"size of embedding table")
.AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num.
split_ids_op -> prefetch_op -> merge_ids_op
merge_ids_op should be used after split_ids_op and prefetch_op, split_ids_op
will split input Ids into multiple tensors according to Id's shard number.
prefetch_op will send them to parameter server to prefetch embedding value
back. During split, the order of ids is disordered. In merge_ids_op we use
the original Ids to restore the order of the fetched embedding value and
also pass the lod information to the merged output.
Example:
Ids = [1,2,3,4,5,6] # 3 shared
split_ids_op ->
Id0 = [3, 6] # id % 3 == 0
Id1 = [1, 4] # id % 3 == 1
Id2 = [2, 5] # id % 3 == 2
prefetch_op ->
X0 = [[0.3 0.3] # 3
[0.6 0.6]] # 6
X1 = [[0.1 0.1] # 1
[0.4 0.4]] # 4
X2 = [[0.2 0.2] # 2
[0.5 0.5]] # 5
merge_ids_op ->
Out = [[0.1 0.1] # 1
[0.2 0.2] # 2
[0.3 0.3] # 3
[0.4 0.4] # 4
[0.5 0.5] # 5
[0.6 0.6]] # 6
)DOC");
}
};
class MergeIdsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) {
PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
"input X only support lod tensors");
}
ctx->ShareLoD("Ids", "Out");
}
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X").front()->type()),
ctx.GetPlace());
}
};
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(input_var->GetType());
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
ops::MergeIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MergeIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("MergeIds do not support GPU kernel");
}
VLOG(3) << "run in MergeIdsOpKernel";
const auto *ids_var = ctx.InputVar("Ids");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"only support to merge Ids of LoDTensor");
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
const auto &ids_dims = ids_tensor.dims();
const int64_t *ids = ids_tensor.data<int64_t>();
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
auto *out = ctx.Output<framework::LoDTensor>("Out");
int batch_size = 0;
int embedding_size = 0;
for (auto &input : x_tensors) {
if (framework::product(input->dims()) != 0) {
if (embedding_size == 0) {
embedding_size = input->dims()[1];
}
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
"embedding size of all input should be the same");
batch_size += input->dims()[0];
}
}
PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0],
"the batch size of ids and merged embedding value should be the same");
const size_t shard_num = x_tensors.size();
if (shard_num == 1) {
VLOG(3) << "only one shard, we can copy the data directly";
TensorCopy(*x_tensors[0], place, out);
} else {
std::vector<int> in_indexs(shard_num, 0);
auto *out_data = out->mutable_data<T>(
framework::make_ddim({batch_size, embedding_size}), place);
// copy data from ins[shard_num] to out.
for (int i = 0; i < ids_dims[0]; ++i) {
int64_t id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
int index = in_indexs[shard_id];
memcpy(out_data + embedding_size * i,
x_tensors[shard_id]->data<T>() + index * embedding_size,
sizeof(T) * embedding_size);
in_indexs[shard_id] += 1;
}
for (size_t i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -322,7 +322,6 @@ class DeviceTracerImpl : public DeviceTracer {
DisableActivity();
dynload::cuptiUnsubscribe(subscriber_);
CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
PADDLE_ENFORCE(dynload::cuptiFinalize());
enabled_ = false;
}
......
......@@ -72,7 +72,6 @@ extern void *cupti_dso_handle;
__macro(cuptiGetResultString); \
__macro(cuptiActivityGetNumDroppedRecords); \
__macro(cuptiActivityFlushAll); \
__macro(cuptiFinalize); \
__macro(cuptiSubscribe); \
__macro(cuptiUnsubscribe); \
__macro(cuptiEnableCallback); \
......
......@@ -41,6 +41,11 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
}
}
// NOTE(minqiyang): according to the ncclGroupEnd documentations:
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd will wait for all communicators to be initialized, which will
// cause blocking problem when a runtime_error was thrown, so try only guard
// NCCL actions when use it.
class NCCLGroupGuard {
public:
static std::mutex &NCCLMutex() {
......
......@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef MATHFUNCTIONS_H_
#define MATHFUNCTIONS_H_
#pragma once
#ifdef PADDLE_WITH_MKLML
#include <mkl_cblas.h>
......@@ -21,7 +20,7 @@ limitations under the License. */
#include <mkl_vml_functions.h>
#endif
#if defined(PADDLE_USE_VECLIB)
#ifdef PADDLE_USE_VECLIB
extern "C" {
#include <cblas.h>
#include <clapack.h>
......@@ -30,8 +29,10 @@ extern "C" {
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#ifdef LAPACK_FOUND
#include <lapacke.h>
#endif
#endif
#ifndef LAPACK_FOUND
extern "C" {
......@@ -126,5 +127,3 @@ template <class T>
void vTanh(const int n, const T* a, T* r);
} // namespace paddle
#endif // MATHFUNCTIONS_H_
......@@ -132,7 +132,8 @@ EOF
-DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake \
-DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF} \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DWITH_CONTRIB=${WITH_CONTRIB:-ON}
-DWITH_CONTRIB=${WITH_CONTRIB:-ON} \
-DWITH_ANAKIN=ON
}
function abort(){
......
......@@ -25,68 +25,20 @@ import utils
import random
__all__ = [
'fc',
'embedding',
'dynamic_lstm',
'dynamic_lstmp',
'dynamic_gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'cos_sim',
'cross_entropy',
'square_error_cost',
'chunk_eval',
'sequence_conv',
'conv2d',
'sequence_pool',
'sequence_softmax',
'softmax',
'pool2d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
'sequence_expand',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'sequence_first_step',
'sequence_last_step',
'dropout',
'split',
'ctc_greedy_decoder',
'edit_distance',
'l2_normalize',
'matmul',
'topk',
'warpctc',
'sequence_reshape',
'transpose',
'im2sequence',
'nce',
'beam_search',
'row_conv',
'multiplex',
'layer_norm',
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'reshape',
'lod_reset',
'lrn',
'pad',
'label_smooth',
'roi_pool',
'dice_loss',
'image_resize',
'image_resize_short',
'resize_bilinear',
'gather',
'random_crop',
'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru',
'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy',
'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d',
'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'batch_norm',
'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit',
'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'reduce_prod',
'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'topk',
'warpctc', 'sequence_reshape', 'transpose', 'im2sequence', 'nce',
'beam_search', 'row_conv', 'multiplex', 'layer_norm',
'softmax_with_cross_entropy', 'smooth_l1', 'one_hot',
'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad',
'label_smooth', 'roi_pool', 'dice_loss', 'image_resize',
'image_resize_short', 'resize_bilinear', 'gather', 'random_crop', 'mean_iou'
]
......@@ -4279,6 +4231,7 @@ def gather(input, index):
output (Variable): The output is a tensor with the same rank as input.
Examples:
.. code-block:: python
output = fluid.layers.gather(x, index)
......@@ -4337,9 +4290,59 @@ def random_crop(x, shape, seed=None):
seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="random_crop",
inputs={"X": input,
inputs={"X": x,
"Seed": seed},
outputs={"Out": out,
"SeedOut": seed_out},
attrs={"shape": shape})
return out
def mean_iou(input, label, num_classes):
"""
Mean Intersection-Over-Union is a common evaluation metric for
semantic image segmentation, which first computes the IOU for each
semantic class and then computes the average over classes.
IOU is defined as follows:
.. math::
IOU = true_positive / (true_positive + false_positive + false_negative).
The predictions are accumulated in a confusion matrix and mean-IOU
is then calculated from it.
Args:
input (Variable): A Tensor of prediction results for semantic labels with type int32 or int64.
label (Variable): A Tensor of ground truth labels with type int32 or int64.
Its shape should be the same as input.
Returns:
mean_iou (Variable): A Tensor representing the mean intersection-over-union with shape [1].
out_wrong(Variable): A Tensor with shape [num_classes]. The wrong numbers of each class.
out_correct(Variable): A Tensor with shape [num_classes]. The correct numbers of each class.
Examples:
.. code-block:: python
iou, wrongs, corrects = fluid.layers.mean_iou(predict, label, num_classes)
"""
helper = LayerHelper('mean_iou', **locals())
dtype = helper.input_dtype()
out_mean_iou = helper.create_tmp_variable(dtype='float32')
out_wrong = helper.create_tmp_variable(dtype='int32')
out_correct = helper.create_tmp_variable(dtype='int32')
helper.append_op(
type="mean_iou",
inputs={"predictions": input,
"labels": label},
outputs={
"out_mean_iou": out_mean_iou,
"out_wrong": out_wrong,
"out_correct": out_correct
},
attrs={"num_classes": num_classes})
return out_mean_iou, out_wrong, out_correct
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
def compute_mean_iou(predictions, labels, num_classes, in_wrongs, in_corrects,
in_mean_ious):
assert predictions.shape == labels.shape
predictions = predictions.flatten()
labels = labels.flatten()
out_wrong = np.zeros([num_classes]).astype("int32")
for _, wrong in in_wrongs:
out_wrong += wrong
out_correct = np.zeros([num_classes]).astype("int32")
for _, correct in in_corrects:
out_correct += correct
for pred, label in zip(predictions, labels):
if pred == label:
out_correct[pred] += 1
else:
out_wrong[pred] += 1
out_wrong[label] += 1
denominator = out_wrong + out_correct
valid_count = (denominator != 0).sum()
denominator = np.where(denominator > 0, denominator,
np.ones(denominator.shape))
mean_iou = (out_correct / denominator).sum() / valid_count
for _, in_mean_iou in in_mean_ious:
mean_iou += in_mean_iou
return mean_iou, out_wrong, out_correct
class TestMeanIOUOp(OpTest):
def setUp(self):
self.config()
self.op_type = "mean_iou"
predictions = np.random.randint(0, self.num_classes,
self.image_size).astype("int32")
labels = np.random.randint(0, self.num_classes,
self.image_size).astype("int32")
in_wrongs = []
for i in range(self.in_wrong_num):
in_wrongs.append(("in_wrong_%d" % i, np.random.randint(
0, 10, [self.num_classes]).astype("int32")))
in_corrects = []
for i in range(self.in_correct_num):
in_corrects.append(("in_correct_%d" % i, np.random.randint(
0, 10, [self.num_classes]).astype("int32")))
in_mean_ious = []
for i in range(self.in_mean_iou_num):
in_mean_ious.append(("in_mean_iou_%d" % i, np.random.uniform(
0, 1, [1]).astype("float32")))
self.inputs = {
'Predictions': predictions,
'Labels': labels,
'InWrongs': in_wrongs,
'InCorrects': in_corrects,
'InMeanIou': in_mean_ious
}
self.attrs = {'num_classes': long(self.num_classes)}
mean_iou, out_wrong, out_correct = compute_mean_iou(
predictions, labels, self.num_classes, in_wrongs, in_corrects,
in_mean_ious)
self.outputs = {
'OutMeanIou': mean_iou,
'OutWrong': out_wrong,
'OutCorrect': out_correct
}
def config(self):
self.num_classes = 10
self.image_size = [128, 128]
self.in_wrong_num = 0
self.in_correct_num = 0
self.in_mean_iou_num = 0
def test_check_output(self):
self.check_output()
class TestCase1(TestMeanIOUOp):
def config(self):
self.num_classes = 5
self.image_size = [100, 128]
self.in_wrong_num = 2
self.in_correct_num = 2
self.in_mean_iou_num = 2
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestMergeIdsOp(OpTest):
def setUp(self):
self.op_type = "merge_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6],
[0.5, 0.6]]).astype('float32')
out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
[0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
......@@ -629,7 +629,7 @@ class DistributeTranspiler:
if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True
op_index = list(all_ops).index(op)
lookup_table_op_index = list(all_ops).index(op)
ids_name = op.input("Ids")
out_name = op.output("Out")
......@@ -649,7 +649,7 @@ class DistributeTranspiler:
# insert split_ids_op
program.global_block().insert_op(
index=op_index,
index=lookup_table_op_index,
type="split_ids",
inputs={
'Ids': [
......@@ -661,7 +661,7 @@ class DistributeTranspiler:
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
index=lookup_table_op_index + 1,
type="prefetch",
inputs={'X': prefetch_input_vars},
outputs={"Out": prefetch_output_vars},
......@@ -672,16 +672,21 @@ class DistributeTranspiler:
# insert concat_op
program.global_block().insert_op(
index=op_index + 2,
type="concat",
inputs={'X': prefetch_output_vars},
index=lookup_table_op_index + 2,
type="merge_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
],
'X': prefetch_output_vars
},
outputs={
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
},
attrs={"axis": 0})
})
# delete lookup_table_op
delete_ops(program.global_block(), [op])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册