提交 d76bda50 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into multithread-sparse-adam

test=develop
......@@ -208,10 +208,10 @@ include(external/xxhash) # download xxhash
include(external/dlpack)
include(external/snappy) # download snappy
include(external/snappystream) # download snappystream
include(external/warpctc) # download, build, install warpctc
if (NOT WIN32)
# there is no official support of warpctc, nccl, cupti in windows
include(external/warpctc) # download, build, install warpctc
# there is no official support of nccl, cupti in windows
include(cupti)
include(external/gzstream)
endif (NOT WIN32)
......
......@@ -26,25 +26,33 @@ SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
# Used in unit test test_WarpCTCLayer
SET(WARPCTC_LIB_DIR "${WARPCTC_INSTALL_DIR}/lib"
CACHE PATH "Warp-ctc Library Directory" FORCE)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
IF(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" )
IF(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR WIN32)
SET(USE_OMP OFF)
ELSE()
SET(USE_OMP ON)
ENDIF()
IF(WIN32)
SET(WARPCTC_REPOSITORY "https://github.com/wopeizl/warp-ctc.git")
ELSE()
SET(WARPCTC_REPOSITORY "https://github.com/dzhwinter/warp-ctc.git")
ENDIF()
ExternalProject_Add(
extern_warpctc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/dzhwinter/warp-ctc.git"
GIT_REPOSITORY ${WARPCTC_REPOSITORY}
PREFIX ${WARPCTC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DWITH_OMP=${USE_OMP}
......@@ -59,6 +67,18 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
)
IF(WIN32)
IF(NOT EXISTS "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}")
add_custom_command(TARGET extern_warpctc POST_BUILD
COMMAND cmake -E copy ${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX} ${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}
)
ENDIF()
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
else(WIN32)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
ENDIF(WIN32)
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its headers.
......
......@@ -84,7 +84,7 @@ function(op_library TARGET)
endif()
if (WIN32)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op")
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
......
......@@ -350,6 +350,22 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b
paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.contrib.load_persistables_for_increment ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.load_persistables_for_inference ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.convert_dist_to_sparse_program ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.__init__ ArgSpec(args=['self', 'hadoop_home', 'configs'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.delete ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.download ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'unzip'], varargs=None, keywords=None, defaults=(False, False))
paddle.fluid.contrib.HDFSClient.is_dir ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.HDFSClient.is_exist ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.HDFSClient.ls ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.lsr ArgSpec(args=['self', 'hdfs_path', 'only_file', 'sort'], varargs=None, keywords=None, defaults=(True, True))
paddle.fluid.contrib.HDFSClient.make_local_dirs ArgSpec(args=['local_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.makedirs ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.rename ArgSpec(args=['self', 'hdfs_src_path', 'hdfs_dst_path', 'overwrite'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.contrib.HDFSClient.upload ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'retry_times'], varargs=None, keywords=None, defaults=(False, 5))
paddle.fluid.contrib.multi_download ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,))
paddle.fluid.contrib.multi_upload ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True))
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
......
......@@ -359,7 +359,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
out_dtype);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
......@@ -662,13 +664,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node) const {
ir::Node *out_var_node, proto::VarType::Type dtype) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx, dtype);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
......
......@@ -68,7 +68,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name,
ir::Node *out_var_node) const;
ir::Node *out_var_node,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
......
......@@ -22,39 +22,66 @@ namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope,
platform::Place place,
platform::DeviceContext *dev_ctx)
platform::DeviceContext *dev_ctx,
proto::VarType::Type dtype)
: OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
place_(place),
out_dtype_(dtype) {
this->SetDeviceContext(place_, dev_ctx);
}
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
struct ScaleLossGradFunctor {
float coeff_;
Tensor *out_;
platform::Place place_;
OpHandleBase *op_handle_;
proto::VarType::Type out_dtype_;
platform::DeviceContext *ctx_;
ScaleLossGradFunctor(float coeff, Tensor *out, platform::Place place,
OpHandleBase *op_handle, proto::VarType::Type dtype,
platform::DeviceContext *ctx)
: coeff_(coeff), out_(out), place_(place), out_dtype_(dtype), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto *out_data = out_->mutable_data<OutT>(place_);
if (platform::is_cpu_place(place_)) {
*out_data = static_cast<OutT>(coeff_);
} else {
#ifdef PADDLE_WITH_CUDA
OutT cast_coeff = static_cast<OutT>(coeff_);
auto stream = static_cast<platform::CUDADeviceContext *>(ctx_)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), out_data,
platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_),
stream);
VLOG(10) << place_ << "RUN Scale loss grad op";
#endif
}
}
};
void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
float *tmp = local_scope.FindVar(var_name)
->GetMutable<LoDTensor>()
->mutable_data<float>(make_ddim({1}), place_);
auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
tensor->Resize(make_ddim({1}));
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
} else {
#ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] {
auto stream = static_cast<platform::CUDADeviceContext *>(
this->dev_ctxes_.at(place_))
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(10) << place_ << "RUN Scale loss grad op";
});
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_,
this->dev_ctxes_.at(place_));
this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
#else
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, nullptr);
framework::VisitDataType(out_dtype_, func);
#endif
}
}
std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
......
......@@ -26,8 +26,8 @@ namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place,
platform::DeviceContext *context);
platform::Place place, platform::DeviceContext *context,
proto::VarType::Type dtype);
~ScaleLossGradOpHandle() final;
......@@ -40,6 +40,7 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_;
Scope *scope_;
platform::Place place_;
proto::VarType::Type out_dtype_;
};
} // namespace details
......
......@@ -24,35 +24,6 @@ namespace paddle {
namespace framework {
namespace ir {
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto from_in_inputs =
std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type;
std::for_each(std::begin(inputs), std::end(inputs),
[from, to, &node](const input_type& i) -> void {
auto param_names = i.second;
auto pi = std::find(std::begin(param_names),
std::end(param_names), from->Name());
if (pi != std::end(param_names)) {
node.Op()->SetInput(i.first, {to->Name()});
}
});
}
}
}
bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
for (auto n : graph->Nodes()) {
......@@ -99,25 +70,12 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
return false;
}
boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
auto bias_input_names = op.Op()->Inputs();
auto bias_it = bias_input_names.find(bias_name);
if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty();
if (has_bias) {
auto bias_names = bias_it->second;
auto bias_names_it =
std::find_if(std::begin(op.inputs), std::end(op.inputs),
[&bias_names](Node* n) -> bool {
return n->Name() == bias_names[0];
});
return *bias_names_it;
}
}
return boost::none;
template <typename T>
boost::optional<T> HasAttribute(const Node& op, const std::string& attr) {
if (op.Op()->HasAttr(attr))
return boost::get<T>(op.Op()->GetAttr(attr));
else
return boost::none;
}
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
......@@ -151,40 +109,18 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
auto fuse_relu = HasAttribute<bool>(*conv_op, "fuse_relu");
if (fuse_relu && *fuse_relu) return;
auto conv_bias = HasBias(*conv_op, "Bias");
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
if (conv_bias) {
op_desc.SetInput("Bias", {(*conv_bias)->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(graph, {conv_output, elementwise_add_op});
auto fused_conv_op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (conv_bias) {
IR_NODE_LINK_TO((*conv_bias), fused_conv_op);
}
IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op});
(*fusion_stats)++;
}
......@@ -229,60 +165,33 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
Node* projection_node;
Node* residual_conv_op;
Node* residual_conv_input;
Node* residual_conv_filter;
Node* residual_conv_output;
if (IsReachable(graph, conv_x_input, conv_y_output)) {
projection_node = conv_x_output;
residual_conv_op = conv_y_op;
residual_conv_input = conv_y_input;
residual_conv_filter = conv_y_filter;
residual_conv_output = conv_y_output;
} else if (IsReachable(graph, conv_y_input, conv_x_output)) {
projection_node = conv_y_output;
residual_conv_op = conv_x_op;
residual_conv_input = conv_x_input;
residual_conv_filter = conv_x_filter;
residual_conv_output = conv_x_output;
} else {
return;
}
OpDesc op_desc;
op_desc.SetType("conv2d");
auto fuse_relu = HasAttribute<bool>(*residual_conv_op, "fuse_relu");
if (fuse_relu && *fuse_relu) return;
op_desc.SetInput("Input", {residual_conv_input->Name()});
op_desc.SetInput("Filter", {residual_conv_filter->Name()});
op_desc.SetInput("ResidualData", {projection_node->Name()});
op_desc.SetOutput("Output", {residual_conv_output->Name()});
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
auto residual_conv_bias = HasBias(*residual_conv_op, "Bias");
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
if (residual_conv_bias) {
op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()});
}
for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(graph, {residual_conv_output, elementwise_add_op});
auto fused_conv_op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(residual_conv_input, fused_conv_op);
IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op);
IR_NODE_LINK_TO(projection_node, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, residual_conv_output);
if (residual_conv_bias) {
IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op);
}
IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output);
GraphSafeRemoveNodes(
graph, {elementwise_add_out, residual_conv_op, elementwise_add_op});
(*fusion_stats)++;
}
......
......@@ -16,100 +16,25 @@ limitations under the License. */
#include <functional>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/platform/enforce.h"
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace framework {
static std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
const VariableNameMap& var_map,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = var_map.at(name);
PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s name %s expects one associated var", op->Type(),
name);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]];
} else {
return nullptr;
}
}
static std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, name, op->Inputs(), ngb_node_map);
}
static std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, name, op->Outputs(), ngb_node_map);
}
static void SetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<ngraph::Node> node,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = op->Outputs().at(name);
if (var_names.size() == 1) {
(*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node;
} else {
PADDLE_THROW("name %s has more than 1 var_names.", name);
}
}
static bool HasOutput(const std::shared_ptr<OperatorBase>& op,
const std::string name) {
auto& outputs = op->Outputs();
if (outputs.find(name) == outputs.end()) return false;
return outputs.at(name).size() > 0;
}
template <typename T>
static void BuildBinaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = GetInputNode(op, "X", ngb_node_map);
auto y = GetInputNode(op, "Y", ngb_node_map);
auto out = std::make_shared<T>(x, y);
SetOutputNode(op, "Out", out, ngb_node_map);
}
template <typename T>
static void BuildUnaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = GetInputNode(op, "X", ngb_node_map);
auto out = std::make_shared<T>(input);
SetOutputNode(op, "Out", out, ngb_node_map);
}
std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = {{"relu", BuildUnaryNode<ngraph::op::Relu>},
{"tanh", BuildUnaryNode<ngraph::op::Tanh>}};
NgraphBridge::NG_NODE_MAP = {
{"mul", paddle::operators::ngraphs::BuildMulNode},
{"mul_grad", paddle::operators::ngraphs::BuildMulGradNode},
{"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>},
{"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>}};
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
auto& op_type = op->Type();
......
......@@ -278,7 +278,8 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
op->RuntimeInferShape(scope_, place_);
RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
op->RuntimeInferShape(scope_, place_, ctx);
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
......
......@@ -139,6 +139,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames,
const Scope& scope) {
for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name));
}
}
for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name));
}
}
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
......@@ -414,11 +431,48 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
return var != nullptr;
}
const Variable* ExecutionContext::InputVar(const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's input %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
const Variable* ExecutionContext::LegacyInputVar(
const std::string& name) const {
auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
Variable* ExecutionContext::OutputVar(const std::string& name) const {
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's output %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
Variable* ExecutionContext::LegacyOutputVar(const std::string& name) const {
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
return Input<LoDTensor>(name);
}
template <>
const Tensor* ExecutionContext::LegacyInput<Tensor>(
const std::string& name) const {
return LegacyInput<LoDTensor>(name);
}
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const {
......@@ -443,6 +497,11 @@ Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
return Output<LoDTensor>(name);
}
template <>
Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const {
return LegacyOutput<LoDTensor>(name);
}
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const {
......@@ -479,23 +538,22 @@ bool OpSupportGPU(const std::string& op_type) {
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope,
const RuntimeContext& ctx)
: op_(op), scope_(scope), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = op_.Inputs();
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0 || in[0] == kEmptyVarName) {
return false;
}
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(in.size(), 1UL,
"Input %s should not have more than one inputs", name);
return scope_.FindVar(in[0]) != nullptr;
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
......@@ -680,6 +738,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
private:
const OperatorBase& op_;
const Scope& scope_;
const RuntimeContext& ctx_;
};
static void CheckTensorNANOrInf(const std::string& name,
......@@ -698,15 +757,15 @@ static void CheckTensorNANOrInf(const std::string& name,
}
void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
const platform::Place& place,
const RuntimeContext& ctx) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx);
this->InferShape(&infer_shape_ctx);
}
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -720,15 +779,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready.
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -750,7 +802,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope =
TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope =
......@@ -760,7 +812,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx);
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
......@@ -784,6 +840,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
}
void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& scope, const std::vector<std::string>& inplace_vars,
const Scope& transfer_scope) const {
......@@ -799,13 +856,19 @@ void OperatorWithKernel::TransferInplaceVarsBack(
}
}
Scope* OperatorWithKernel::TryTransferData(
Scope* OperatorWithKernel::PrepareData(
const Scope& scope, const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars) const {
std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const {
Scope* new_scope = nullptr;
for (auto& var_name_item : Inputs()) {
for (auto& var_name : var_name_item.second) {
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name);
input_vars[i] = var;
// Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) {
continue;
......@@ -853,6 +916,7 @@ Scope* OperatorWithKernel::TryTransferData(
}
auto* trans_var = new_scope->Var(var_name);
input_vars[i] = trans_var;
Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
......
......@@ -73,6 +73,15 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
class OperatorBase;
class ExecutionContext;
class RuntimeContext {
public:
RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames, const Scope& scope);
VariableValueMap inputs;
VariableValueMap outputs;
};
/**
* OperatorBase has the basic elements that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -132,7 +141,8 @@ class OperatorBase {
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {}
const platform::Place& place,
const RuntimeContext& ctx) const {}
protected:
std::string type_;
......@@ -159,8 +169,9 @@ class OperatorBase {
class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const platform::DeviceContext& device_context,
const RuntimeContext& ctx)
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
const OperatorBase& op() const { return op_; }
......@@ -183,15 +194,9 @@ class ExecutionContext {
return op_.Outputs(name).size();
}
const Variable* InputVar(const std::string& name) const {
auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
const Variable* InputVar(const std::string& name) const;
Variable* OutputVar(const std::string& name) const {
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
Variable* OutputVar(const std::string& name) const;
const std::vector<const Variable*> MultiInputVar(
const std::string& name) const {
......@@ -230,6 +235,22 @@ class ExecutionContext {
return var == nullptr ? nullptr : var->GetMutable<T>();
}
template <typename T>
const T* LegacyInput(const std::string& name) const {
auto* var = LegacyInputVar(name);
return var == nullptr ? nullptr : &var->Get<T>();
}
template <typename T>
T* LegacyOutput(const std::string& name) const {
auto var = LegacyOutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<T>();
}
const Variable* LegacyInputVar(const std::string& name) const;
Variable* LegacyOutputVar(const std::string& name) const;
template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const {
auto names = op_.Inputs(name);
......@@ -289,11 +310,16 @@ class ExecutionContext {
const OperatorBase& op_;
const Scope& scope_;
const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_;
};
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
template <>
const Tensor* ExecutionContext::LegacyInput<Tensor>(
const std::string& name) const;
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const;
......@@ -301,6 +327,9 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
......@@ -353,8 +382,8 @@ class OperatorWithKernel : public OperatorBase {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
}
void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const override;
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
......@@ -374,9 +403,10 @@ class OperatorWithKernel : public OperatorBase {
*
* * transfered_inplace_vars is a output vector.
*/
Scope* TryTransferData(
const Scope& scope, const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars) const;
Scope* PrepareData(const Scope& scope,
const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const;
void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars,
......
......@@ -28,8 +28,11 @@ class OperatorBase;
class OpDesc;
class InferShapeContext;
class BlockDesc;
class Variable;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// TODO(panyx0718): Replace vector with something like gtl::Vector.
using VariableValueMap = std::map<std::string, std::vector<Variable*>>;
// The order should be as same as framework.proto
using Attribute =
......
......@@ -188,11 +188,13 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
std::vector<Variable*> ret;
for (size_t i = 0; i < input_vars_->size(); ++i) {
bool found = false;
VarBase* origin_var = (*input_vars_)[i];
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
Variable* var = scope->FindVar(outvar);
VarBase* origin_var = (*input_vars_)[i];
std::string orig_var = grad_to_var_->at(outvar);
PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var);
if (origin_var->var_desc_->Name() != orig_var) {
continue;
}
VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
origin_var->ApplyGrad(scope, var);
found = true;
......
......@@ -43,9 +43,12 @@ void CreateGradOp(const framework::OpDesc& op_desc,
class Tracer {
public:
explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
explicit Tracer(framework::BlockDesc* root_block,
framework::BlockDesc* startup_block)
: root_block_(root_block), startup_block_(startup_block) {
root_scope_ = new framework::Scope();
scopes_[root_block_] = root_scope_;
scopes_[startup_block_] = root_scope_;
}
virtual ~Tracer() { delete root_scope_; }
......@@ -80,6 +83,8 @@ class Tracer {
} else {
op->pre_ops_->push_back(nullptr);
}
VLOG(3) << "input vname " << vname << " "
<< var->Get<framework::LoDTensor>().dims().size();
}
*op->output_vars_ = outputs;
......@@ -98,12 +103,19 @@ class Tracer {
outputs[i]->pre_op_ = op;
outputs[i]->pre_op_out_idx_ = i;
}
VLOG(3) << "tracer running " << op_desc->Type();
op_base->Run(*scope, platform::CPUPlace());
framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
op->grad_op_desc_ = grad_op_desc;
op->grad_to_var_ = grad_to_var;
if (block == startup_block_) {
op->grad_op_desc_ = nullptr;
op->grad_to_var_ = nullptr;
} else {
framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
op->grad_op_desc_ = grad_op_desc;
op->grad_to_var_ = grad_to_var;
}
op->block_ = block;
}
......@@ -121,6 +133,7 @@ class Tracer {
private:
std::map<framework::BlockDesc*, framework::Scope*> scopes_;
framework::BlockDesc* root_block_;
framework::BlockDesc* startup_block_;
framework::Scope* root_scope_;
};
......
......@@ -64,9 +64,7 @@ endif()
set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
if (NOT WIN32)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
if (WITH_GPU)
......
......@@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx);
framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
static constexpr char kInputs[] = "inputs";
static constexpr char kParameters[] = "parameters";
static constexpr char kPlaces[] = "places";
static constexpr char kOutputs[] = "outputs";
static constexpr char kParallelScopes[] = "parallel_scopes";
static constexpr char kParallelBlock[] = "sub_block";
static constexpr char kUseNCCL[] = "use_nccl";
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
static void SplitTensorAndMoveTensorToScopes(
const framework::Scope &scope, std::vector<framework::Scope *> *sub_scopes,
const std::vector<platform::Place> &places,
const std::vector<std::string> &names) {
size_t num_sub_scopes = 0;
for (auto &argu : names) {
const auto &tensor =
detail::Ref(scope.FindVar(argu),
"Cannot find variable %s in the parent scope", argu)
.Get<LoDTensor>();
auto lod_tensors = tensor.SplitLoDTensor(places);
for (auto &lod : lod_tensors) {
VLOG(3) << lod.dims();
}
if (num_sub_scopes == 0) {
num_sub_scopes = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(num_sub_scopes, lod_tensors.size());
}
PADDLE_ENFORCE_NE(num_sub_scopes, 0);
if (sub_scopes->size() == 0) {
sub_scopes->reserve(num_sub_scopes);
for (size_t i = 0; i < num_sub_scopes; ++i) {
sub_scopes->emplace_back(&scope.NewScope());
}
}
for (size_t i = 0; i < lod_tensors.size(); ++i) {
*detail::Ref(sub_scopes->at(i)->Var(argu),
"Cannot find variable in the sub-scope", argu)
.GetMutable<LoDTensor>() = lod_tensors[i];
}
}
}
inline void CopyOrShare(const framework::Variable &src,
const platform::Place &dst_place,
framework::Variable *dst) {
if (src.IsType<LoDTensor>()) {
if (src.Get<LoDTensor>().place() == dst_place) {
dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>());
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else {
TensorCopy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
}
} else if (src.IsType<SelectedRows>()) {
auto &src_sr = src.Get<SelectedRows>();
auto *dst_sr = dst->GetMutable<SelectedRows>();
dst_sr->set_height(src_sr.height());
if (src_sr.value().place() == dst_place) {
dst_sr->mutable_value()->ShareDataWith(src_sr.value());
dst_sr->set_rows(src_sr.rows());
} else {
TensorCopy(src_sr.value(), dst_place, dst_sr->mutable_value());
}
} else {
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
}
}
void WaitOnPlace(const platform::Place place) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
dev_ctx.Wait();
}
void WaitOnPlaces(const std::vector<platform::Place> places) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
for (auto &place : places) {
auto &dev_ctx = *pool.Get(place);
dev_ctx.Wait();
}
}
class ParallelDoOp : public framework::OperatorBase {
public:
ParallelDoOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
auto *program = block->Program();
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
auto &sub_scopes = *scope.FindVar(Output(kParallelScopes))
->GetMutable<std::vector<framework::Scope *>>();
// split input
SplitTensorAndMoveTensorToScopes(scope, &sub_scopes, places,
Inputs(kInputs));
// copy parameter
for (auto &param : Inputs(kParameters)) {
PADDLE_ENFORCE(scope.FindVar(param)->IsType<LoDTensor>(),
"Only support parameter type as LoDTensor");
auto &src = scope.FindVar(param)->Get<LoDTensor>();
auto *sub_scope0 = sub_scopes[0];
auto *dst0 = sub_scope0->Var(param)->GetMutable<LoDTensor>();
dst0->ShareDataWith(src);
for (size_t i = 1; i < sub_scopes.size(); ++i) {
auto &place = places[i];
auto *sub_scope = sub_scopes[i];
auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
framework::TensorCopy(src, place, dst);
}
}
WaitOnPlaces(places);
std::vector<std::future<void>> workers;
workers.reserve(places.size());
for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) {
auto &place = places[place_idx];
auto *cur_scope = sub_scopes[place_idx];
workers.emplace_back(framework::Async([program, cur_scope, place, block] {
framework::Executor executor(place);
executor.Run(*program, cur_scope, block->ID(),
false /*create_local_scope*/);
}));
}
for (auto &worker : workers) {
worker.wait();
}
WaitOnPlaces(places);
// merge output
for (auto &o_name : Outputs(kOutputs)) {
std::vector<const framework::LoDTensor *> lod_tensors;
lod_tensors.reserve(sub_scopes.size());
for (auto *sub_scope : sub_scopes) {
lod_tensors.emplace_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
}
auto *lod_tensor_to_be_merged =
scope.FindVar(o_name)->GetMutable<LoDTensor>();
lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace());
}
WaitOnPlaces(places);
}
};
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(kInputs, "").AsDuplicable();
AddInput(kParameters, "").AsDuplicable();
AddInput(kPlaces, "");
AddOutput(kOutputs, "").AsDuplicable();
AddOutput(kParallelScopes, "");
AddAttr<framework::BlockDesc *>(kParallelBlock, "");
AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
.SetDefault(false);
AddComment(R"DOC(
ParallelDo Operator.
)DOC");
}
};
class ParallelDoGradOp : public framework::OperatorBase {
public:
ParallelDoGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto *block = Attr<framework::BlockDesc *>(kParallelBlock);
auto *program = block->Program();
auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
// feed output@grad
SplitTensorAndMoveTensorToScopes(
scope, const_cast<std::vector<framework::Scope *> *>(&sub_scopes),
places, Inputs(framework::GradVarName(kOutputs)));
WaitOnPlaces(places);
// exe run
std::vector<std::future<void>> workers;
for (size_t i = 0; i < sub_scopes.size(); ++i) {
auto &place = places[i];
auto *cur_scope = sub_scopes[i];
// execute
workers.emplace_back(framework::Async([program, cur_scope, place, block] {
framework::Executor executor(place);
executor.Run(*program, cur_scope, block->ID(),
false /*create_local_scope*/);
}));
}
for (auto &worker : workers) {
worker.wait();
}
WaitOnPlaces(places);
// NCCL allreduce op will be added by backward,
// so no need to explicitly accumulate grad
if (!(Attr<bool>(kUseNCCL))) {
AccumulateGrad(scope, place, sub_scopes, places);
} else {
for (auto &place : places) {
PADDLE_ENFORCE(platform::is_gpu_place(place),
"NCCL only supports cuda place");
}
}
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == framework::kEmptyVarName) {
continue;
}
VLOG(3) << "Moving " << s;
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
}
WaitOnPlaces(places);
}
void AccumulateGrad(const framework::Scope &scope,
const platform::Place &place,
const std::vector<framework::Scope *> &sub_scopes,
const platform::PlaceList &places) const {
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == framework::kEmptyVarName) {
continue;
}
VLOG(3) << "Accumulating " << s;
if (s == framework::kEmptyVarName) continue;
std::string tmp_name;
auto *tmp = sub_scopes[0]->Var(&tmp_name);
for (size_t i = 1; i < sub_scopes.size(); ++i) {
CopyOrShare(*sub_scopes[i]->FindVar(s), places[0], tmp);
WaitOnPlaces(places);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]);
sum_op->Run(*sub_scopes[0], places[0]);
WaitOnPlace(places[0]);
}
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
}
WaitOnPlaces(places);
}
};
std::ostream &operator<<(std::ostream &sout,
const std::vector<std::string> &strs) {
std::copy(strs.begin(), strs.end(),
std::ostream_iterator<std::string>(sout, ","));
return sout;
}
class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<framework::OpDesc> Apply() const {
auto *grad = new framework::OpDesc();
grad->SetType("parallel_do_grad");
for (auto &input_param : this->InputNames()) {
VLOG(3) << input_param;
grad->SetInput(input_param, this->Input(input_param));
if (input_param != kPlaces) {
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param, false));
}
}
auto *g_block = this->grad_block_[0];
// All variable name that needed by gradient operators
std::unordered_set<std::string> all_inputs_in_grad_blocks;
for (size_t i = 0; i < g_block->OpSize(); ++i) {
auto *op = g_block->Op(i);
for (auto &var_name : op->InputArgumentNames()) {
all_inputs_in_grad_blocks.insert(var_name);
}
}
for (auto &output_param : this->OutputNames()) {
if (output_param == kParallelScopes) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->Output(output_param));
} else {
grad->SetInput(output_param, this->Output(output_param));
std::vector<std::string> og_names;
for (auto &og_name : this->OutputGrad(output_param)) {
if (all_inputs_in_grad_blocks.count(og_name) != 0) {
// there are some gradient operators who need the OG. So make this
// OG as an input of parallel.do
og_names.push_back(og_name);
}
// else, there is no operator who need the OG. Do not use this OG as
// an input
}
grad->SetInput(framework::GradVarName(output_param), og_names);
}
}
grad->SetInput("Communicator", {"nccl_com__do_not_change_"});
grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kParallelBlock, grad_block_[0]);
return std::unique_ptr<framework::OpDesc>(grad);
}
};
class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs(kParameters));
PADDLE_ENFORCE(ctx->HasInputs(kInputs));
PADDLE_ENFORCE(ctx->HasInputs(kOutputs));
ctx->SetOutputsDim(framework::GradVarName(kParameters),
ctx->GetInputsDim(kParameters));
auto i_dims = ctx->GetInputsDim(kInputs);
auto ig_names = ctx->Outputs(framework::GradVarName(kInputs));
for (size_t i = 0; i < ig_names.size(); ++i) {
auto &ig_name = ig_names[i];
if (ig_name == framework::kEmptyVarName) {
continue;
}
ctx->SetDims({ig_name}, {i_dims[i]});
}
auto p_dims = ctx->GetInputsDim(kParameters);
auto pg_names = ctx->Outputs(framework::GradVarName(kParameters));
for (size_t i = 0; i < pg_names.size(); ++i) {
auto &pg_name = pg_names[i];
if (pg_name == framework::kEmptyVarName) {
continue;
}
ctx->SetDims({pg_name}, {p_dims[i]});
}
}
};
class ParallelDoGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
framework::BlockDesc *sub_block =
boost::get<framework::BlockDesc *>(op_desc.GetAttr(kParallelBlock));
for (auto &out_vars : op_desc.Outputs()) {
for (auto &out_var : out_vars.second) {
auto &var = block->FindRecursiveOrCreateVar(out_var);
auto sub_var = sub_block->FindRecursiveOrCreateVar(out_var);
if (sub_var.GetType() != var.GetType()) {
var.SetType(sub_var.GetType());
}
}
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
paddle::operators::ParallelDoOpProtoMaker,
paddle::operators::ParallelDoGradOpDescMaker);
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
paddle::operators::ParallelDoGradOpShapeInference,
paddle::operators::ParallelDoGradOpVarTypeInference);
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <nccl.h>
#endif
#include <sys/time.h>
#include <limits>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
......@@ -31,7 +32,12 @@ namespace distributed {
class IOBufWriter {
public:
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
static void Append(const std::string& varname, butil::IOBuf* iobuf, int k,
const char* v, int64_t vlen) {
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
}
iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen);
......@@ -87,6 +93,10 @@ class IOBufWriter {
int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) {
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
}
#ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data);
......@@ -134,7 +144,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(iobuf,
IOBufWriter::Append(name, iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES);
return;
......@@ -149,7 +159,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
// FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) {
IOBufWriter::Append(
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size());
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
......@@ -171,10 +181,11 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
IOBufWriter::Append(name, iobuf,
::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size));
}
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <limits>
#include "glog/logging.h" // For VLOG
......@@ -420,7 +421,15 @@ void GRPCClient::Proceed() {
sync_cond_.notify_all();
}
}
VLOG(3) << "GRPCClient Proceed end";
// Last log message
// Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
// static Mutex log_mutex is used for synchronization, which might have been
// destructed at this moment.
if (FLAGS_v >= 3) {
std::string msg("GRPCClient Proceed end");
fwrite(msg.c_str(), msg.length(), 1, stdout);
}
}
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <limits>
#include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
......@@ -102,6 +103,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->memory_size());
if (payload->memory_size() >= std::numeric_limits<int>::max()) {
LOG(FATAL) << "AppendZeroCopy varname:" << name
<< ", vlen:" << payload->memory_size();
}
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
......@@ -115,7 +120,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <typeindex>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
......@@ -23,9 +24,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace operators {
......@@ -83,6 +83,11 @@ inline framework::proto::VarType::Type ToVarType(
}
}
template <template <typename> class T, typename Elem>
std::string VectorElemName(const T<Elem>& arg) {
return typeid(Elem).name();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -118,7 +118,7 @@ bool VariableResponse::CopyLodTensorData(
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length;
PADDLE_ENFORCE_EQ(tensor->memory_size(), length);
PADDLE_ENFORCE_EQ(tensor->memory_size(), static_cast<unsigned int>(length));
return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
}
......
......@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
......
......@@ -12,19 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......@@ -22,4 +23,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/reduce.h>
#include "paddle/fluid/operators/metrics/accuracy_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
......@@ -94,6 +95,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): types of T is for inference data.
// label data is always int64
REGISTER_OP_CUDA_KERNEL(accuracy,
paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>,
paddle::operators::AccuracyOpCUDAKernel<paddle::platform::float16>);
......@@ -49,7 +49,8 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims,
"The input tensor Y's rank of MulOp should be larger than "
"y_num_col_dims.");
"y_num_col_dims: %ld vs %ld",
y_dims.size(), y_num_col_dims);
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains the list of the ngraph operators for Paddle.
*
* ATTENTION: It requires some C++11 features, for lower version C++ or C, we
* might release another API.
*/
#pragma once
#include "ops/binary_unnary_op.h"
#include "ops/mul_op.h"
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
template <typename T>
static void BuildBinaryNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto out = std::make_shared<T>(x, y);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
template <typename T>
static void BuildUnaryNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto out = std::make_shared<T>(input);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
#endif
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
static void BuildMulNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
int x_num_col_dims = op_attrs.Get<int>("x_num_col_dims");
int y_num_col_dims = op_attrs.Get<int>("y_num_col_dims");
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto x_reshape = x;
auto y_reshape = y;
if (x->get_shape().size() > 2) {
auto x_2d = paddle::platform::FlattenTo2d(x->get_shape(), x_num_col_dims);
x_reshape = paddle::platform::NgReshaper(x, x_2d);
}
if (y->get_shape().size() > 2) {
auto y_2d = paddle::platform::FlattenTo2d(y->get_shape(), y_num_col_dims);
y_reshape = paddle::platform::NgReshaper(y, y_2d);
}
std::shared_ptr<ngraph::Node> out =
std::make_shared<ngraph::op::Dot>(x_reshape, y_reshape);
auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
if (dummy_out && dummy_out->get_shape() != out->get_shape()) {
out = paddle::platform::NgReshaper(out, dummy_out->get_shape());
}
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
static void BuildMulGradNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
int x_num_col_dims = op_attrs.Get<int>("x_num_col_dims");
int y_num_col_dims = op_attrs.Get<int>("y_num_col_dims");
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
bool is_dx = paddle::platform::HasOutput(op, "X@GRAD") ? true : false;
bool is_dy = paddle::platform::HasOutput(op, "Y@GRAD") ? true : false;
auto x_shape = x->get_shape();
auto y_shape = y->get_shape();
auto x_reshape = x;
auto y_reshape = y;
if (x_shape.size() > 2) {
auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, x_num_col_dims);
x_reshape = paddle::platform::NgReshaper(x, x_2d_shape);
}
if (y_shape.size() > 2) {
auto y_2d_shape = paddle::platform::FlattenTo2d(y_shape, y_num_col_dims);
y_reshape = paddle::platform::NgReshaper(y, y_2d_shape);
}
auto x_reshape_shape = x_reshape->get_shape();
std::reverse(x_reshape_shape.begin(), x_reshape_shape.end());
auto x_transpose = std::make_shared<ngraph::op::Reshape>(
x_reshape, ngraph::AxisVector{1, 0}, x_reshape_shape);
auto y_reshape_shape = y_reshape->get_shape();
std::reverse(y_reshape_shape.begin(), y_reshape_shape.end());
auto y_transpose = std::make_shared<ngraph::op::Reshape>(
y_reshape, ngraph::AxisVector{1, 0}, y_reshape_shape);
if (is_dx) {
if (dout->get_shape().size() > 2) {
auto dout_2d_shape = paddle::platform::FlattenTo2d(dout->get_shape(), 2);
dout = paddle::platform::NgReshaper(dout, dout_2d_shape);
}
auto dx = std::make_shared<ngraph::op::Dot>(dout, y_transpose);
if (dx->get_shape() == x_shape) {
paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
} else {
auto dx_reshape = paddle::platform::NgReshaper(dx, x_shape);
paddle::platform::SetOutputNode(op, "X@GRAD", dx_reshape, ngb_node_map);
}
}
if (is_dy) {
if (dout->get_shape().size() > 2) {
auto dout_2d_shape = paddle::platform::FlattenTo2d(dout->get_shape(), 2);
dout = paddle::platform::NgReshaper(dout, dout_2d_shape);
}
auto dy = std::make_shared<ngraph::op::Dot>(x_transpose, dout);
if (dy->get_shape() == y_shape) {
paddle::platform::SetOutputNode(op, "Y@GRAD", dy, ngb_node_map);
} else {
auto dy_reshape = paddle::platform::NgReshaper(dy, y_shape);
paddle::platform::SetOutputNode(op, "Y@GRAD", dy_reshape, ngb_node_map);
}
}
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
#endif
......@@ -14,8 +14,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
......@@ -237,7 +237,8 @@ class SparseMomentumFunctor<T, UseNesterov> {
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
......@@ -282,7 +283,8 @@ class SparseMomentumFunctor<T, NoNesterov> {
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -150,7 +151,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(-INFINITY, -1);
topk[k].set(-static_cast<T>(INFINITY), -1);
}
}
if (!(*is_empty)) {
......@@ -160,7 +161,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
}
*max = topk[MaxLength - 1];
if ((*max).v == -1) *is_empty = true;
if ((*max).v == -static_cast<T>(1)) *is_empty = true;
*beam = 0;
}
}
......@@ -181,7 +182,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
if (k < MaxLength - *beam) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(-INFINITY, -1);
topk[k].set(-static_cast<T>(INFINITY), -1);
}
}
if (!(*is_empty)) {
......@@ -278,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
bool firststep = true;
for (int j = 0; j < MaxLength; j++) {
topk[j].set(-INFINITY, -1);
topk[j].set(-static_cast<T>(INFINITY), -1);
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(
......@@ -362,5 +363,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>,
paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::DataLayout;
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE(
is_test == true,
"ConvTransposeMKLDNN works only for inference!. Set is_test = True");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
int ndims = axis.size();
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
const T* input_data = input->data<T>();
if (ndims == 1) {
output->ShareDataWith(*input);
return;
}
std::vector<int> nchw_axis(ndims, 0);
for (size_t i = 0; i < nchw_axis.size(); ++i) {
nchw_axis[i] = i;
}
std::vector<int> nchw_tz = paddle::framework::vectorize2int(input->dims());
std::string data_format = ctx.Attr<std::string>("data_format");
auto src_md =
input->format() != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(nchw_tz, platform::MKLDNNGetDataType<T>(),
input->format())
: Axis2MemoryDesc(nchw_tz, nchw_axis);
this->TransposeKernel(ctx.GetPlace(), Axis2MemoryDesc(nchw_tz, axis),
src_md, output, input_data, nchw_tz, mkldnn_engine);
}
protected:
mkldnn::memory::desc Axis2MemoryDesc(std::vector<int>& nchw_tz,
std::vector<int>& axis) const {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
unsigned int total_stride = 1;
for (int i = nchw_tz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
total_stride *= nchw_tz[axis[i]];
}
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset
return mem_fmt;
}
void TransposeKernel(platform::Place place, mkldnn::memory::desc md_o,
mkldnn::memory::desc md_i, Tensor* output,
const T* data_i, std::vector<int>& nchw_dims,
const mkldnn::engine& eng) const {
// Make Memory primitive descriptors
auto mpd_o = mkldnn::memory::primitive_desc(md_o, eng);
auto mpd_i = mkldnn::memory::primitive_desc(md_i, eng);
auto data_o = output->mutable_data<T>(
place, paddle::memory::Allocator::kDefault, mpd_o.get_size());
auto src = mkldnn::memory(mpd_i, (T*)(data_i));
auto dst = mkldnn::memory(mpd_o, data_o);
auto r = mkldnn::reorder(src, dst);
mkldnn::stream(mkldnn::stream::kind::eager).submit({r}).wait();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>);
......@@ -16,6 +16,10 @@ limitations under the License. */
#include <string>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -53,11 +57,32 @@ class TransposeOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim("Out", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout_, library_);
}
};
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput(
"X",
"(Tensor) The input tensor, tensors with rank up to 6 are supported.");
......@@ -67,6 +92,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) A list of values, and the size of the list should be "
"the same with the input tensor rank. This operator permutes the input "
"tensor's axes according to the values given.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC(
Transpose Operator.
......@@ -144,8 +179,18 @@ class Transpose2Op : public TransposeOp {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout_, library_);
}
};
......
......@@ -16,9 +16,7 @@ if (CUPTI_FOUND)
list(APPEND CUDA_SRCS cupti.cc)
endif(CUPTI_FOUND)
nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader)
if (NOT WIN32)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
endif(NOT WIN32)
if (WITH_MKLML)
cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml)
endif()
......
......@@ -34,7 +34,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using cudnn_func = decltype(&::__name); \
std::call_once(cudnn_dso_flag, []() { \
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
......
......@@ -201,6 +201,8 @@ void* GetCurandDsoHandle() {
void* GetWarpCTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "warpctc.dll");
#else
return GetDsoHandleFromSearchPath(FLAGS_warpctc_dir, "libwarpctc.so");
#endif
......
......@@ -18,6 +18,12 @@ namespace paddle {
namespace platform {
namespace dynload {
#ifndef _WIN32
#define DECLARE_TYPE(__name, ...) decltype(__name(__VA_ARGS__))
#else
#define DECLARE_TYPE(__name, ...) decltype(auto)
#endif
void* GetCublasDsoHandle();
void* GetCUDNNDsoHandle();
void* GetCUPTIDsoHandle();
......
......@@ -34,7 +34,7 @@ extern void* mklml_dso_handle;
#define DYNAMIC_LOAD_MKLML_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using mklmlFunc = decltype(&::__name); \
std::call_once(mklml_dso_flag, []() { \
mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
......
......@@ -33,7 +33,7 @@ extern void* tensorrt_dso_handle;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = \
......
......@@ -34,7 +34,7 @@ extern void* warpctc_dso_handle;
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using warpctcFunc = decltype(&::__name); \
std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
......
......@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#define NCCL_ID_VARNAME "NCCLID"
......@@ -38,6 +39,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclInt;
} else if (type == framework::proto::VarType::INT64) {
return ncclInt64;
} else if (type == framework::proto::VarType::FP16) {
return ncclFloat16;
} else {
PADDLE_THROW("Not supported");
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <functional>
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace platform {
static ngraph::Shape FlattenTo2d(ngraph::Shape sh, int num) {
auto x1 = std::accumulate(std::begin(sh), std::begin(sh) + num, 1,
std::multiplies<size_t>());
auto x2 = std::accumulate(std::begin(sh) + num, std::end(sh), 1,
std::multiplies<size_t>());
size_t x1_l = static_cast<size_t>(x1);
size_t x2_l = static_cast<size_t>(x2);
return ngraph::Shape{x1_l, x2_l};
}
static std::shared_ptr<ngraph::Node> NgReshaper(
std::shared_ptr<ngraph::Node> input, ngraph::Shape shape) {
std::vector<size_t> input_order(input->get_shape().size());
std::iota(std::begin(input_order), std::end(input_order), 0);
return std::make_shared<ngraph::op::Reshape>(
input, ngraph::AxisVector(input_order), shape);
}
static std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const paddle::framework::VariableNameMap& var_map,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = var_map.at(prm);
PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s prm %s expects one associated var", op->Type(), prm);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]];
} else {
return nullptr;
}
}
static std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, prm, op->Inputs(), ngb_node_map);
}
static std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, prm, op->Outputs(), ngb_node_map);
}
static void SetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, std::shared_ptr<ngraph::Node> node,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = op->Outputs().at(prm);
if (var_names.size() == 1) {
(*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node;
} else {
PADDLE_THROW("prm %s has more than 1 var_names.", prm);
}
}
static bool HasOutput(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm) {
auto& outputs = op->Outputs();
if (outputs.find(prm) == outputs.end()) return false;
return outputs.at(prm).size() > 0;
}
} // namespace platform
} // namespace paddle
#endif
......@@ -55,7 +55,6 @@ static void *dlsym(void *handle, const char *symbol_name) {
static void *dlopen(const char *filename, int flag) {
std::string file_name(filename);
file_name.replace(0, file_name.size() - 1, '/', '\\');
HMODULE hModule = LoadLibrary(file_name.c_str());
if (!hModule) {
throw std::runtime_error(file_name + " not found.");
......
......@@ -24,8 +24,9 @@ namespace pybind {
void BindTracer(pybind11::module *m) {
pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
.def("__init__",
[](imperative::Tracer &self, framework::BlockDesc *root_block) {
new (&self) imperative::Tracer(root_block);
[](imperative::Tracer &self, framework::BlockDesc *root_block,
framework::BlockDesc *startup_block) {
new (&self) imperative::Tracer(root_block, startup_block);
})
.def("trace", &imperative::Tracer::Trace)
.def("get_scope", &imperative::Tracer::GetScope,
......
......@@ -509,11 +509,11 @@ function assert_api_spec_approvals() {
if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then
# NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable.
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433`
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have at least 2 approvals for the api change! ${API_FILE}"
exit 1
echo "You must have panyx0718 approval for the api change! ${API_FILE}"
exit 1
fi
fi
done
......@@ -521,11 +521,11 @@ function assert_api_spec_approvals() {
HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true`
if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433`
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have at least 2 approvals for the const_cast"
exit 1
echo "You must have panyx0718 approval for the const_cast"
exit 1
fi
fi
......
......@@ -102,6 +102,13 @@ def __bootstrap__():
import sys
import os
import platform
if os.name == 'nt':
third_lib_path = os.path.abspath(os.path.dirname(
__file__)) + os.sep + '..' + os.sep + 'libs'
os.environ['path'] += ';' + third_lib_path
sys.path.append(third_lib_path)
from . import core
in_test = 'unittest' in sys.modules
......@@ -128,14 +135,14 @@ def __bootstrap__():
'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size",
'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
'allocator_strategy', 'reader_queue_speed_test_mode',
'print_sub_graph_dir', 'pe_profile_fname', 'inner_op_parallelism',
'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir',
'inner_op_parallelism'
'min_param_size_to_use_multithread'
]
if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory')
if os.name != 'nt':
read_env_flags.append('warpctc_dir')
read_env_flags.append('cpu_deterministic')
if core.is_compiled_with_dist():
......
......@@ -249,69 +249,6 @@ def serialize_op_decs(op_desc):
return proto.__str__()
def _callback_lookup_(op):
"""
Only used in _append_backward_ops_
Build and returns a callback function for certain op. For example
parallel_do: AllReduce
:param op:
:return: callback function
"""
if op.type == 'parallel_do' and op.attr('use_nccl'):
all_vars = op.block.vars
param_names = set(op.input('parameters'))
param_names = [
name for name in param_names
if all_vars[name].stop_gradient is False
]
param_grad_names = [n + "@GRAD" for n in param_names]
class ParallelDoCallBack(object):
def __init__(self, param_grad_names, parallel_scopes_name):
self.has_inserted_nccl_init = False
self.param_grad_names = param_grad_names
self.parallel_scopes_name = parallel_scopes_name
def __call__(self, block, context):
if not self.has_inserted_nccl_init:
op_desc = _create_op_desc_(
"ncclInit",
{"parallel_scopes": self.parallel_scopes_name},
{"Communicator": ['nccl_com__do_not_change_']}, {})
block.program.global_block().desc.append_op().copy_from(
op_desc)
self.has_inserted_nccl_init = True
current_op_desc = context["__current_op_desc__"]
for o_param in current_op_desc.output_names():
for o_argu in current_op_desc.output(o_param):
if o_argu in self.param_grad_names:
allreduce_out_name = o_argu + "__nccl_all_reduce__"
op_desc = _create_op_desc_(
"ncclReduce",
{
"X": [o_argu],
"Communicator":
['nccl_com__do_not_change_']
},
{"Out": [allreduce_out_name]},
{"reduction": "ncclSum",
"root": 0}, )
block.desc.append_op().copy_from(op_desc)
op_desc = _create_op_desc_(
"assign", {"X": [allreduce_out_name]},
{"Out": [o_argu]}, {})
block.desc.append_op().copy_from(op_desc)
return ParallelDoCallBack(param_grad_names,
op.output("parallel_scopes"))
else:
return None
def _append_backward_ops_(block,
ops,
target_block,
......@@ -349,17 +286,8 @@ def _append_backward_ops_(block,
sub_block = program.block(op._block_attr_id("sub_block"))
grad_sub_block = program._create_block()
grad_sub_block._set_forward_block_idx(sub_block.idx)
cb = _callback_lookup_(op)
if cb is not None:
if callbacks is None:
new_callbacks = [cb]
else:
new_callbacks = callbacks + [_callback_lookup_(op)]
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, new_callbacks)
else:
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, callbacks)
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, callbacks)
program._rollback()
grad_sub_block_list.append(grad_sub_block.desc)
......@@ -424,9 +352,6 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
# infer_shape and infer_type
op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc)
# ncclInit dones't need to set data_type
if op_desc.type() == 'ncclInit':
continue
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_(arg, block)
......@@ -564,8 +489,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
grad_to_var = dict()
op_desc = _create_op_desc_(
"fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
"shape": [1],
"fill_constant",
{},
{"Out": [_append_grad_suffix_(loss.name)]},
{
"shape": [1], # TODO(panyx0718): This can be loss.shape.
"value": 1.0,
"dtype": loss.dtype,
"force_cpu": False,
......
......@@ -22,9 +22,12 @@ from . import op_frequence
from .op_frequence import *
from . import quantize
from .quantize import *
from . import utils
from .utils import *
__all__ = []
__all__ += decoder.__all__
__all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__
__all__ += quantize.__all__
__all__ += utils.__all__
......@@ -13,10 +13,11 @@
# limitations under the License.
from __future__ import print_function
#from . import lookup_table_utils
#from .lookup_table_utils import *
from . import lookup_table_utils
from .lookup_table_utils import *
from . import hdfs_utils
from .hdfs_utils import *
#__all__ = lookup_table_utils.__all__
__all__ = hdfs_utils.__all__
__all__ = []
__all__ += lookup_table_utils.__all__
__all__ += hdfs_utils.__all__
......@@ -14,6 +14,7 @@
"""HDFS Utils"""
import os
import sys
import subprocess
import multiprocessing
from datetime import datetime
......@@ -24,7 +25,7 @@ import errno
import logging
__all__ = ["HDFSClient", "multi_download"]
__all__ = ["HDFSClient", "multi_download", "multi_upload"]
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
_logger = logging.getLogger("hdfs_utils")
......@@ -93,13 +94,15 @@ class HDFSClient(object):
def upload(self, hdfs_path, local_path, overwrite=False, retry_times=5):
"""
upload the local file to hdfs
Args:
hdfs_path: hdfs path, target path
local_path: local file path, source path
overwrite: will overwrite the original file
retry_times: max times retry to upload
Returns:
upload the local file to hdfs
Args:
hdfs_path(str): the hdfs file path
local_path(str): the local file path
overwrite(bool|None): will overwrite the file on HDFS or not
retry_times(int|5): retry times
Returns:
True or False
"""
assert hdfs_path is not None
......@@ -109,7 +112,7 @@ class HDFSClient(object):
_logger.warn(
"The Local path: {} is dir and I will support it later, return".
format(local_path))
return
return False
base = os.path.basename(local_path)
if not self.is_exist(hdfs_path):
......@@ -141,14 +144,16 @@ class HDFSClient(object):
def download(self, hdfs_path, local_path, overwrite=False, unzip=False):
"""
download from hdfs
Args:
hdfs_path: hdfs path, target path
local_path: local file path, source path
overwrite: will remove original file and overwrite it.
unzip: ignore this param
Returns
True or False
download file from HDFS
Args:
hdfs_path(str): the hdfs file path
local_path(str): the local file path
overwrite(bool|None): will overwrite the file on HDFS or not
unzip(bool|False): if the download file is compressed by zip, unzip it or not.
Returns:
True or False
"""
_logger.info('Downloading %r to %r.', hdfs_path, local_path)
_logger.info('Download of %s to %r complete.', hdfs_path, local_path)
......@@ -188,13 +193,13 @@ class HDFSClient(object):
def is_exist(self, hdfs_path=None):
"""
whether the remote hdfs path exists?
Args:
hdfs_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
fs_name: The default values are the same as in the job configuration
fs_ugi: The default values are the same as in the job configuration
Returns:
True or False
whether the remote HDFS path exists
Args:
hdfs_path(str): the hdfs file path
Returns:
True or False
"""
exist_cmd = ['-test', '-e', hdfs_path]
returncode, output, errors = self.__run_hdfs_cmd(
......@@ -211,13 +216,13 @@ class HDFSClient(object):
def is_dir(self, hdfs_path=None):
"""
whether the remote hdfs path exists?
Args:
remote_file_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp)
fs_name: The default values are the same as in the job configuration
fs_ugi: The default values are the same as in the job configuration
Returns:
True or False
whether the remote HDFS path is directory
Args:
hdfs_path(str): the hdfs file path
Returns:
True or False
"""
if not self.is_exist(hdfs_path):
......@@ -237,17 +242,17 @@ class HDFSClient(object):
def delete(self, hdfs_path):
"""
Remove a file or directory from HDFS.
Remove a file or directory from HDFS.
whether the remote HDFS path exists
Args:
param hdfs_path: HDFS path.
param recursive: Recursively delete files and directories. By default,
this method will raise an :class:`HdfsError` if trying to delete a
non-empty directory.
hdfs_path: HDFS path.
Returns:
True or False
This function returns `True` if the deletion was successful and `False` if
no file or directory previously existed at `hdfs_path`.
"""
_logger.info('Deleting %r.', hdfs_path)
......@@ -273,16 +278,14 @@ class HDFSClient(object):
def rename(self, hdfs_src_path, hdfs_dst_path, overwrite=False):
"""
Rename a file or folder.
Args:
:param hdfs_src_path: Source path.
:param hdfs_dst_path: Destination path. If the path already exists and is
a directory, the source will be moved into it. If the path exists and is
a file, or if a parent destination directory is missing, this method will
raise an :class:`HdfsError`.
Move a file or folder on HDFS.
Args:
hdfs_path(str): HDFS path.
overwrite(bool|False): If the path already exists and overwrite is False, will return False.
Returns:
This function returns `True` if the rename was successful and `False` if
rename was faild.
True or False
"""
assert hdfs_src_path is not None
assert hdfs_dst_path is not None
......@@ -320,17 +323,20 @@ class HDFSClient(object):
raise
def makedirs(self, hdfs_path):
"""Create a remote directory, recursively if necessary.
"""
Create a remote directory, recursively if necessary.
Args:
:param hdfs_path: Remote path. Intermediate directories will be created
appropriately.
hdfs_path(str): Remote path. Intermediate directories will be created appropriately.
Returns:
True if make a directories was successful, False when make a directiries was failed.
True or False
"""
_logger.info('Creating directories to %r.', hdfs_path)
assert hdfs_path is not None
if self.is_exist(hdfs_path):
_logger.error("HDFS path is exist: {}".format(hdfs_path))
return
mkdirs_commands = ['-mkdir', hdfs_path]
......@@ -346,11 +352,13 @@ class HDFSClient(object):
def ls(self, hdfs_path):
"""
ls a hdfs_path.
Args:
:param hdfs_path: hdfs_path will be ls.
ls directory contents about HDFS hdfs_path
Args:
hdfs_path(str): Remote HDFS path will be ls.
Returns:
This function returns a `list` that contaion all files in the hdfs_path.
List: a contents list about hdfs_path.
"""
assert hdfs_path is not None
......@@ -378,11 +386,15 @@ class HDFSClient(object):
def lsr(self, hdfs_path, only_file=True, sort=True):
"""
ls a hdfs_path sort by time.
Args:
:param hdfs_path: hdfs_path will be ls.
list directory contents about HDFS hdfs_path recursively
Args:
hdfs_path(str): Remote HDFS path.
only_file(bool|True): will discard folders.
sort(bool|True): will be sorted by create time.
Returns:
This function returns a `list` that contaion all files sorted by time in the hdfs_path.
List: a contents list about hdfs_path.
"""
def sort_by_time(v1, v2):
......@@ -422,21 +434,106 @@ class HDFSClient(object):
return ret_lines
def multi_download(client,
hdfs_path,
local_path,
trainer_id,
trainers,
multi_processes=5):
"""
Download files from HDFS using multi process.
Args:
client(HDFSClient): instance of HDFSClient
hdfs_path(str): path on hdfs
local_path(str): path on local
trainer_id(int): current trainer id
trainers(int): all trainers number
multi_processes(int|5): the download data process at the same time, default=5
Returns:
List:
Download files in local folder.
"""
def __subprocess_download(datas):
for data in datas:
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
if re_path == os.curdir:
sub_local_re_path = local_path
else:
sub_local_re_path = os.path.join(local_path, re_path)
client.download(data, sub_local_re_path)
assert isinstance(client, HDFSClient)
client.make_local_dirs(local_path)
_logger.info("Make local dir {} successfully".format(local_path))
all_need_download = client.lsr(hdfs_path, sort=True)
need_download = all_need_download[trainer_id::trainers]
_logger.info("Get {} files From all {} files need to be download from {}".
format(len(need_download), len(all_need_download), hdfs_path))
_logger.info("Start {} multi process to download datas".format(
multi_processes))
procs = []
for i in range(multi_processes):
process_datas = need_download[i::multi_processes]
p = multiprocessing.Process(
target=__subprocess_download, args=(process_datas, ))
procs.append(p)
p.start()
# complete the processes
for proc in procs:
proc.join()
_logger.info("Finish {} multi process to download datas".format(
multi_processes))
local_downloads = []
for data in need_download:
data_name = os.path.basename(data)
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
if re_path == os.curdir:
local_re_path = os.path.join(local_path, data_name)
else:
local_re_path = os.path.join(local_path, re_path, data_name)
local_downloads.append(local_re_path)
return local_downloads
def getfilelist(path):
rlist = []
for dir, folder, file in os.walk(path):
for i in file:
t = os.path.join(dir, i)
rlist.append(t)
for r in rlist:
print(r)
def multi_upload(client,
hdfs_path,
local_path,
multi_processes=5,
overwrite=False):
overwrite=False,
sync=True):
"""
Upload file to hdfs.
Upload files to HDFS using multi process.
Args:
:param overwrite: will overwrite hdfs file or not
:param multi_processes: the upload data process at the same time, default=5
:param client: instance of HDFSClient
:param hdfs_path: path on hdfs
:param local_path: path on local
client(HDFSClient): instance of HDFSClient
hdfs_path(str): path on hdfs
local_path(str): path on local
multi_processes(int|5): the upload data process at the same time, default=5
overwrite(bool|False): will overwrite file on HDFS or not
sync(bool|True): upload files sync or not.
Returns:
None
"""
def __subprocess_upload(datas):
......@@ -446,13 +543,6 @@ def multi_upload(client,
client.upload(hdfs_re_path, data, overwrite, retry_times=5)
def get_local_files(path):
"""
Get all local files
Args:
path: local file path
Returns:
A list that contation all files in the path.
"""
rlist = []
if not os.path.isdir(path):
......@@ -488,71 +578,6 @@ def multi_upload(client,
multi_processes))
def multi_download(client,
hdfs_path,
local_path,
trainer_id,
trainers,
file_cnt,
multi_processes=5):
"""
multi_download
Args:
:param client: instance of HDFSClient
:param hdfs_path: path on hdfs
:param local_path: path on local
:param trainer_id: current trainer id
:param trainers: all trainers number
:param file_cnt: all file number
:param multi_processes: the download data process at the same time, default=5
:return: None
Returns:
A list that be downloaded.
"""
def __subprocess_download(datas):
for data in datas:
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
local_re_path = os.path.join(local_path, re_path)
client.download(data, local_re_path)
assert isinstance(client, HDFSClient)
client.make_local_dirs(local_path)
_logger.info("Make local dir {} successfully".format(local_path))
all_need_download = client.lsr(hdfs_path, sort=True)[:file_cnt]
need_download = all_need_download[trainer_id::trainers]
_logger.info("Get {} files From all {} files need to be download from {}".
format(len(need_download), len(all_need_download), hdfs_path))
_logger.info("Start {} multi process to download datas".format(
multi_processes))
procs = []
for i in range(multi_processes):
process_datas = need_download[i::multi_processes]
p = multiprocessing.Process(
target=__subprocess_download, args=(process_datas, ))
procs.append(p)
p.start()
# complete the processes
for proc in procs:
proc.join()
_logger.info("Finish {} multi process to download datas".format(
multi_processes))
local_downloads = []
for data in need_download:
data_name = os.path.basename(data)
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
local_re_path = os.path.join(local_path, re_path, data_name)
local_downloads.append(local_re_path)
return local_downloads
if __name__ == "__main__":
hadoop_home = "/home/client/hadoop-client/hadoop/"
......
......@@ -18,14 +18,12 @@ import os
import time
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import io
from paddle.fluid import Program
__all__ = [
"load_inference_model", "load_persistable_vars",
"load_persistables_for_increment", "load_persistables_for_inference",
"convert_dist_to_sparse_program"
]
......@@ -80,19 +78,28 @@ def __get_prefetch_op_tuples(main_program):
return prefetch_op_tuples
def convert_dist_to_sparse_program(main_program):
if not main_program._distributed_lookup_table:
def convert_dist_to_sparse_program(program):
"""
WARNING: this function will only be used for distributed training with distributed lookup table.
when we train model with distributed lookup table but want to do the local inference, we can use
this function to convert the train program with distributed lookup table to sparse lookup table.
:param program(Program): the program must be the trainer program, which will be get by the distribute transpiler.
:return:
program: The `program` is a Program, it's the program replace distributed lookup table to sparse lookup table.
"""
if not program._distributed_lookup_table:
_logger.warn(
"There are no distributed lookup tables need to be converted")
return
# create table param and grad var in pserver program
origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table)
emb_var = main_program._distributed_lookup_table
main_program.global_block()._rename_var(emb_var, origin_emb_var)
origin_param_var = main_program.global_block().vars[origin_emb_var]
origin_emb_var = "{}.origin".format(program._distributed_lookup_table)
emb_var = program._distributed_lookup_table
program.global_block()._rename_var(emb_var, origin_emb_var)
origin_param_var = program.global_block().vars[origin_emb_var]
param_var = main_program.global_block().create_var(
param_var = program.global_block().create_var(
name=emb_var,
shape=origin_param_var.shape,
dtype=origin_param_var.dtype,
......@@ -100,28 +107,28 @@ def convert_dist_to_sparse_program(main_program):
persistable=True)
# parameter must be selected rows
param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
main_program._sync_with_cpp()
program._sync_with_cpp()
prefetch_op_tuples = __get_prefetch_op_tuples(main_program)
prefetch_op_tuples = __get_prefetch_op_tuples(program)
split_ids_id = prefetch_op_tuples[0]
for idx in range(split_ids_id + 2, split_ids_id - 1, -1):
main_program.global_block()._remove_op(idx)
main_program.desc.flush()
program.global_block()._remove_op(idx)
program.desc.flush()
in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2])
for in_out_pair in in_out_pairs:
idx = split_ids_id
ids = main_program.global_block().vars[in_out_pair[0]]
out = main_program.global_block().vars[in_out_pair[1]]
__insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out)
main_program.desc.flush()
return main_program
ids = program.global_block().vars[in_out_pair[0]]
out = program.global_block().vars[in_out_pair[1]]
__insert_lookup_sparse_table_op(program, idx, ids, param_var, out)
program.desc.flush()
return program
def load_persistable_vars(executor, dirname, program, lookup_table_var):
def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
def _is_checkpoint_var(exclude_fluid_vars=None):
"""
the checkpoint will not save or load all the variables.
......@@ -159,8 +166,82 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var):
return is_valid
def _load_lookup_table_vars(executor, dirname, main_program,
lookup_table_vars):
io.load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var(lookup_table_vars),
filename=None)
def load_persistables_for_increment(dirname, executor, program,
lookup_table_var, lookup_table_var_path):
"""
WARNING: this function will only be used for distributed training with distributed lookup table.
for increment trainning, the pserver will not only load dense variables,
but also load the suitable lookup table var. Because of slice lookup table
var with HASH, we must load the correct slice var.
:param dirname(str): The directory path
:param executor(Executor): The executor to run for loading inference model.
:param program(Program): The parameter server program, which will run on Pserver.
:param lookup_table_var: the distributed lookup tables var name.
:param lookup_table_var_path: the the distributed lookup tables var location.
:return: None
"""
def __load_lookup_table_vars(executor, main_program, lookup_table_var,
lookup_table_var_path):
emb_var = main_program.global_block().var(lookup_table_var)
load_program = Program()
load_block = load_program.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [emb_var]},
attrs={'file_path': lookup_table_var_path})
executor.run(load_program)
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
if not os.path.exists(lookup_table_var_path):
raise ValueError("There is no file named '%s'", lookup_table_var_path)
if not isinstance(program, Program):
raise ValueError("program must be an instance of fluid.Program")
_logger.info("Start Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
_load_persistable_vars(executor, dirname, program, [lookup_table_var])
__load_lookup_table_vars(executor, program, lookup_table_var,
lookup_table_var_path)
_logger.info("Finish Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
def load_persistables_for_inference(dirname, executor, program,
lookup_table_var_name):
"""
WARNING: this function will only be used for inference with distributed lookup table.
Inference with distributed lookup table is a little funky, this function will load distributed
lookup table vars into sparse var, can be used in local inference mode.
:param dirname(str): The directory path
:param executor(Executor): The executor to run for loading inference model.
:param program(Program): The parameter server program, which will run on Pserver.
:param lookup_table_var_name: the distributed lookup tables var name.
:return: None
"""
def __load_lookup_table_vars(executor, dirname, main_program,
lookup_table_vars):
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
......@@ -209,48 +290,34 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var):
global_block.append_op(type='delete_var', inputs={'X': sums})
executor.run(convert_program)
_logger.info("Start Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
lookup_table_vars = [lookup_table_var]
io.load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var(lookup_table_vars),
filename=None)
_load_lookup_table_vars(executor, dirname, program, lookup_table_vars)
_logger.info("Finish Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
def load_inference_model(dirname, executor, lookup_table_var_name):
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
local_model = os.path.join(dirname, model_filename)
if program:
if not isinstance(program, Program):
raise ValueError("program must be an instance of fluid.Program")
else:
local_model = os.path.join(dirname, model_filename)
with open(local_model, "rb") as f:
program_desc_str = f.read()
with open(local_model, "rb") as f:
program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str)
program = Program.parse_from_string(program_desc_str)
if not core._is_program_version_supported(program._version()):
raise ValueError("Unsupported program version: %d\n" %
program._version())
if not core._is_program_version_supported(program._version()):
raise ValueError("Unsupported program version: %d\n" %
program._version())
# Binary data also need version.
load_persistable_vars(executor, dirname, program, lookup_table_var_name)
_logger.info("Start Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
_load_persistable_vars(executor, dirname, program, [lookup_table_var_name])
__load_lookup_table_vars(executor, dirname, program,
[lookup_table_var_name])
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [
program.global_block().var(name) for name in fetch_target_names
]
_logger.info("Finish Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
return [program, feed_target_names, fetch_targets]
return program
......@@ -44,6 +44,8 @@ class DataToLoDTensorConverter(object):
self.dtype = 'int64'
elif dtype == core.VarDesc.VarType.FP64:
self.dtype = 'float64'
elif dtype == core.VarDesc.VarType.FP16:
self.dtype = 'float16'
elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32'
elif dtype == core.VarDesc.VarType.UINT8:
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import collections
import contextlib
import os
import re
import six
import sys
......@@ -27,11 +28,18 @@ from .proto import framework_pb2
try:
from . import core
except ImportError as e:
raise ImportError(
"""NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\"
if you encounters \"libmkldnn.so not found\" errors. If you have python
installed in other directory, replace \"/usr/local/lib\" with your own
directory. The original error is: \n""" + cpt.get_exception_message(e))
if os.name == 'nt':
raise ImportError(
"""NOTE: You may need to run \"set PATH=c:\python27\lib:%PATH%\"
if you encounters \"mkldnn.dll not found\" errors. If you have python
installed in other directory, replace \"c:\python27\lib" with your own
directory. The original error is: \n""" + cpt.get_exception_message(e))
else:
raise ImportError(
"""NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\"
if you encounters \"libmkldnn.so not found\" errors. If you have python
installed in other directory, replace \"/usr/local/lib\" with your own
directory. The original error is: \n""" + cpt.get_exception_message(e))
except Exception as e:
raise e
from . import unique_name
......@@ -563,8 +571,8 @@ class Operator(object):
OP_WITHOUT_KERNEL_SET = {
'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id'
'listen_and_serv', 'save_combine', 'load_combine', 'ncclInit', 'select',
'checkpoint_notify', 'gen_nccl_id'
}
def __init__(self,
......@@ -1316,6 +1324,9 @@ class Block(object):
def _prepend_op(self, *args, **kwargs):
op_desc = self.desc._prepend_op()
op = Operator(self, op_desc, *args, **kwargs)
if _in_imperative_mode():
_imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
[v._ivar for v in op.outputs], self.desc)
self.ops.insert(0, op)
return op
......
......@@ -28,7 +28,8 @@ def enabled():
def guard():
train = framework.Program()
startup = framework.Program()
tracer = core.Tracer(train.current_block().desc)
tracer = core.Tracer(train.current_block().desc,
startup.current_block().desc)
with framework.program_guard(train, startup):
with framework.unique_name.guard():
with framework._imperative_guard(tracer):
......
......@@ -25,11 +25,9 @@ __all__ = ['PyLayer']
class PyLayer(core.Layer):
def __init__(self):
pass
self._built = False
def __call__(self, inputs):
# TODO(panyx0718): Support declarative mode as well.
assert base.enabled()
if not isinstance(inputs, list) and not isinstance(inputs, tuple):
inputs = [inputs]
......@@ -37,8 +35,15 @@ class PyLayer(core.Layer):
for x in inputs:
py_var = base.to_variable(x)
var_inputs.append(py_var)
if not self._built:
self._build_once(inputs)
self._built = True
outputs = self.forward(var_inputs)
return outputs
def _build_once(self, inputs):
pass
def forward(self, inputs):
return []
......@@ -18,6 +18,7 @@ from . import framework
import numpy as np
import contextlib
from .core import VarDesc
from . import unique_name
__all__ = [
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
......@@ -207,16 +208,39 @@ class UniformInitializer(Initializer):
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op(
type="uniform_random",
outputs={"Out": var},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
"dtype": int(var.dtype),
"dtype": out_dtype,
"min": self._low,
"max": self._high,
"seed": self._seed
})
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op
return op
......@@ -261,17 +285,39 @@ class NormalInitializer(Initializer):
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op(
type="gaussian_random",
outputs={"Out": var},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
"dtype": int(var.dtype),
"dtype": out_dtype,
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed,
"use_mkldnn": False
})
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op
return op
......
......@@ -226,156 +226,6 @@ class BlockGuard(object):
return True
class ParallelDo(object):
"""
ParallelDo is used to represent multi-thread data parallel processing.
Its vanilla implementation can be shown as the following (:math:`|` means
single thread and :math:`||||` means multiple threads)
.. code-block:: text
In the forward pass
| Split input onto different devices
| Copy parameter onto different devices
|||| Compute forward pass in parallel
| Merge output from different devices
In the backward pass
| Split output@grad onto different devices
|||| Compute backward pass in parallel
| accumulate param@grad from different devices to the first device
| Merge input@grad from different devices
| Copy param@grad to the place of parallel_do_op
Examples:
.. code-block:: python
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# ParallelDo version & Single-thread version
if thread_num > 1:
places = fluid.layers.get_places(thread_num)
pd = fluid.layers.control_flow.ParallelDo(places)
with pd.do():
images = pd.read_input(images)
label = pd.read_input(label)
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
pd.write_output(avg_cost)
avg_cost = pd()
avg_cost = fluid.layers.mean(avg_cost)
else:
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
.. warning::
It will be soon deprecated, please use ParallelExecutor instead.
"""
def __init__(self, places, use_nccl=False, name=None):
warnings.warn(
"API ParallelDo is deprecated since 0.15.0. Please use ParallelExecutor instead.",
Warning)
self.helper = LayerHelper("parallel_do", name=name)
self.inputs = []
self.places = places
self.outputs = []
self.status = StaticRNN.BEFORE_RNN_BLOCK
self.use_nccl = use_nccl
def do(self):
return BlockGuardWithCompletion(self)
def parent_block(self):
prog = self.helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def __call__(self, *args, **kwargs):
if self.status != StaticRNN.AFTER_RNN_BLOCK:
raise ValueError("RNN output can only be retrieved after rnn block")
if len(self.outputs) == 0:
raise ValueError("RNN has no output")
elif len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def read_input(self, var):
self.inputs.append(var)
return var
def write_output(self, var):
self.outputs.append(var)
def get_parameters(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
local_inputs = set()
params = list()
for var in self.inputs:
local_inputs.add(var.name)
for op in current_block.ops:
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in local_inputs:
params.append(in_var_name)
for oname in op.output_names:
for out_var_name in op.output(oname):
local_inputs.add(out_var_name)
params = list(set(params))
return [parent_block.var(name) for name in params]
def _complete_op(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
self.outputs = [
parent_block.create_var(
name=o.name,
shape=o.shape,
dtype=o.dtype,
lod_level=o.lod_level,
persistable=o.persistable,
stop_gradient=o.stop_gradient) for o in self.outputs
]
inputs = [parent_block.var(i.name) for i in self.inputs]
outputs = [parent_block.var(o.name) for o in self.outputs]
parent_block.append_op(
type='parallel_do',
inputs={
'inputs': inputs,
'parameters': self.get_parameters(),
'places': self.places
},
outputs={'outputs': outputs,
'parallel_scopes': [step_scope]},
attrs={'sub_block': current_block,
'use_nccl': self.use_nccl})
class BlockGuardWithCompletion(BlockGuard):
"""
BlockGuardWithCompletion class.
......@@ -384,9 +234,8 @@ class BlockGuardWithCompletion(BlockGuard):
"""
def __init__(self, rnn):
if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)):
raise TypeError(
"BlockGuardWithCompletion takes a StaticRNN or ParallelDo")
if not isinstance(rnn, StaticRNN):
raise TypeError("BlockGuardWithCompletion takes a StaticRNN")
super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program)
self.rnn = rnn
......
......@@ -63,14 +63,18 @@ def noam_decay(d_model, warmup_steps):
Returns:
The decayed learning rate.
"""
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter(1)
a = global_step**-0.5
b = (warmup_steps**-1.5) * global_step
lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter(1)
return lr_value
a = global_step**-0.5
b = (warmup_steps**-1.5) * global_step
lr_value = (d_model**-0.5) * nn.elementwise_min(a, b)
return lr_value
return _lr_schedule
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
......@@ -109,15 +113,19 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
sgd_optimizer.minimize(avg_cost)
"""
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate**div_res)
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * (decay_rate**div_res)
return decayed_lr
return decayed_lr
return _lr_schedule
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
......@@ -138,15 +146,19 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
Returns:
The decayed learning rate
"""
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)
return decayed_lr
return decayed_lr
return _lr_schedule
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
......@@ -184,16 +196,20 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
staircase=True))
sgd_optimizer.minimize(avg_cost)
"""
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
decayed_lr = learning_rate / (1 + decay_rate * div_res)
div_res = global_step / decay_steps
if staircase:
div_res = ops.floor(div_res)
return decayed_lr
decayed_lr = learning_rate / (1 + decay_rate * div_res)
return decayed_lr
return _lr_schedule
def polynomial_decay(learning_rate,
......@@ -224,28 +240,33 @@ def polynomial_decay(learning_rate,
Returns:
Variable: The decayed learning rate
"""
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
if cycle:
div_res = ops.ceil(global_step / decay_steps)
zero_var = tensor.fill_constant(
shape=[1], dtype='float32', value=0.0)
one_var = tensor.fill_constant(
shape=[1], dtype='float32', value=1.0)
def _lr_schedule(dtype, decay_steps=decay_steps):
with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter()
if cycle:
div_res = ops.ceil(global_step / decay_steps)
zero_var = tensor.fill_constant(
shape=[1], dtype=dtype, value=0.0)
one_var = tensor.fill_constant(
shape=[1], dtype=dtype, value=1.0)
with control_flow.Switch() as switch:
with switch.case(global_step == zero_var):
tensor.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
decay_steps_var = tensor.fill_constant(
shape=[1], dtype=dtype, value=float(decay_steps))
global_step = nn.elementwise_min(
x=global_step, y=decay_steps_var)
with control_flow.Switch() as switch:
with switch.case(global_step == zero_var):
tensor.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
decay_steps_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(decay_steps))
global_step = nn.elementwise_min(x=global_step, y=decay_steps_var)
decayed_lr = (learning_rate - end_learning_rate) * \
((1 - global_step / decay_steps) ** power) + end_learning_rate
return decayed_lr
decayed_lr = (learning_rate - end_learning_rate) * \
((1 - global_step / decay_steps) ** power) + end_learning_rate
return decayed_lr
return _lr_schedule
def piecewise_decay(boundaries, values):
......@@ -273,38 +294,42 @@ def piecewise_decay(boundaries, values):
"""
with default_main_program()._lr_schedule_guard():
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
global_step = _decay_step_counter()
lr = tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
with control_flow.Switch() as switch:
for i in range(len(boundaries)):
boundary_val = tensor.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
tensor.assign(value_var, lr)
last_value_var = tensor.fill_constant(
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard():
if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1")
global_step = _decay_step_counter()
lr = tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
value=float(values[len(values) - 1]))
with switch.default():
tensor.assign(last_value_var, lr)
persistable=True,
name="learning_rate")
with control_flow.Switch() as switch:
for i in range(len(boundaries)):
boundary_val = tensor.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
tensor.assign(value_var, lr)
last_value_var = tensor.fill_constant(
shape=[1],
dtype='float32',
value=float(values[len(values) - 1]))
with switch.default():
tensor.assign(last_value_var, lr)
return lr
return lr
return _lr_schedule
def append_LARS(params_grads, learning_rate, weight_decay):
......
......@@ -29,6 +29,7 @@ from . import utils
from .. import unique_name
from functools import reduce
from .. import core
from ..imperative import layers
__all__ = [
'fc',
......@@ -2797,6 +2798,10 @@ def batch_norm(input,
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
# use fp32 for bn parameter
if dtype == core.VarDesc.VarType.FP16:
dtype = core.VarDesc.VarType.FP32
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
......@@ -2831,7 +2836,7 @@ def batch_norm(input,
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=input.dtype)
dtype=dtype)
mean.stop_gradient = True
variance = helper.create_parameter(
......@@ -2841,7 +2846,7 @@ def batch_norm(input,
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=input.dtype)
dtype=dtype)
variance.stop_gradient = True
# create output
......@@ -9426,3 +9431,47 @@ def huber_loss(input, label, delta):
'Residual': residual},
attrs={'delta': delta})
return out
class FC(layers.PyLayer):
def __init__(self,
size,
param_attr=None,
num_flatten_dims=1,
dtype=core.VarDesc.VarType.FP32):
super(FC, self).__init__()
self._size = size
self._num_flatten_dims = num_flatten_dims
self._dtype = dtype
self._helper = LayerHelper('FC', param_attr=param_attr)
def _build_once(self, inputs):
input_shape = inputs[0].shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], 1)
] + [self._size]
self._w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=False)
def forward(self, inputs):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": inputs[0],
"Y": self._w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": [tmp]},
outputs={"Out": out},
attrs={"use_mkldnn": False})
return out
......@@ -50,17 +50,21 @@ class Optimizer(object):
def __init__(self, learning_rate, regularization=None, name=None):
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable):
raise TypeError("learning rate should be float or Variable")
not isinstance(learning_rate, framework.Variable) and \
not callable(learning_rate):
raise TypeError(
"learning rate should be float or Variable or callable(dtype)")
self._name = name
self.regularization = regularization
self._learning_rate = learning_rate
# the learning rate type should be inferenced from loss
self._dtype = None
# each program should have a independent learning rate
# program -> Variable(learning_rate)
# program -> Variable(learning_rate) or:
# program -> callable(return learning_rate Variable)
self._learning_rate_map = dict()
if isinstance(self._learning_rate, framework.Variable):
if isinstance(self._learning_rate, framework.Variable) or \
callable(self._learning_rate):
self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate
# Dictionary of accumulators. Some optimizer subclasses need to
......@@ -75,6 +79,11 @@ class Optimizer(object):
if isinstance(lr, framework.Variable):
return
elif callable(lr):
dtype = 'float32' if self._dtype is None else self._dtype
self._learning_rate_map[framework.default_main_program()] = lr(
dtype)
return
else:
if not isinstance(self._learning_rate, float):
raise TypeError(
......
......@@ -15,7 +15,6 @@
from __future__ import print_function
from paddle.fluid.layers.device import get_places
from paddle.fluid.layers.control_flow import ParallelDo
import unittest
import paddle.fluid as fluid
import paddle
......@@ -147,22 +146,7 @@ def train(word_dict,
cost, acc_out, prediction = net_method(
data, label, input_dim=dict_dim, class_dim=class_dim)
else:
places = get_places()
pd = ParallelDo(places)
with pd.do():
cost, acc, _ = net_method(
pd.read_input(data),
pd.read_input(label),
input_dim=dict_dim,
class_dim=class_dim)
pd.write_output(cost)
pd.write_output(acc)
cost, acc = pd()
cost = fluid.layers.mean(cost)
acc_out = fluid.layers.mean(acc)
prediction = None
assert save_dirname is None
raise NotImplementedError()
adagrad = fluid.optimizer.Adagrad(learning_rate=0.002)
adagrad.minimize(cost)
......
......@@ -25,7 +25,6 @@ import numpy
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
from paddle.fluid.layers.control_flow import ParallelDo
BATCH_SIZE = 64
......@@ -82,19 +81,7 @@ def train(nn_type,
net_conf = conv_net
if parallel:
places = get_places()
pd = ParallelDo(places)
with pd.do():
img_ = pd.read_input(img)
label_ = pd.read_input(label)
prediction, avg_loss, acc = net_conf(img_, label_)
for o in [avg_loss, acc]:
pd.write_output(o)
avg_loss, acc = pd()
# get mean loss and acc through every devices.
avg_loss = fluid.layers.mean(avg_loss)
acc = fluid.layers.mean(acc)
raise NotImplementedError()
else:
prediction, avg_loss, acc = net_conf(img, label)
......@@ -273,7 +260,7 @@ def inject_all_tests():
for use_cuda in (False, True):
if use_cuda and not core.is_compiled_with_cuda():
continue
for parallel in (False, True):
for parallel in (False, ):
for nn_type in ('mlp', 'conv'):
inject_test_method(use_cuda, parallel, nn_type, True)
......
......@@ -17,7 +17,6 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
from paddle.fluid.layers.control_flow import ParallelDo
import unittest
import os
import numpy as np
......@@ -84,18 +83,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
avg_cost, predict_word = __network__(
[first_word, second_word, third_word, forth_word, next_word])
else:
places = get_places()
pd = ParallelDo(places)
with pd.do():
avg_cost, predict_word = __network__(
list(
map(pd.read_input, [
first_word, second_word, third_word, forth_word,
next_word
])))
pd.write_output(avg_cost)
avg_cost = fluid.layers.mean(pd())
raise NotImplementedError()
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
......@@ -262,7 +250,7 @@ def inject_test_method(use_cuda, is_sparse, is_parallel):
for use_cuda in (False, True):
for is_sparse in (False, True):
for is_parallel in (False, True):
for is_parallel in (False, ):
inject_test_method(use_cuda, is_sparse, is_parallel)
if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
import sys
import paddle
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
from paddle.fluid.layers.control_flow import ParallelDo
# need to fix random seed and training data to compare the loss
# value accurately calculated by the default and the memory optimization
# version.
fluid.default_startup_program().random_seed = 111
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
device_type = 'CPU'
use_nccl = False
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
device_type = 'CUDA'
use_nccl = False
place = fluid.CUDAPlace(0)
places = get_places(device_count=0, device_type=device_type)
pd = ParallelDo(places, use_nccl=use_nccl)
with pd.do():
x_ = pd.read_input(x)
y_ = pd.read_input(y)
y_predict = fluid.layers.fc(input=x_, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y_)
avg_cost = fluid.layers.mean(x=cost)
pd.write_output(avg_cost)
cost = pd()
avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program(), print_log=True)
# fluid.release_memory(fluid.default_main_program())
BATCH_SIZE = 200
# fix the order of training data
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE, drop_last=False)
# train_reader = paddle.batch(
# paddle.reader.shuffle(
# paddle.dataset.uci_housing.train(), buf_size=500),
# batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
avg_loss_value, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
if avg_loss_value[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good.
print(avg_loss_value[0])
if math.isnan(float(avg_loss_value)):
sys.exit("got NaN loss, training failed.")
exit(1)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh
class TestNGRAPHReluDim2(TestRelu):
def setUp(self):
super(TestNGRAPHReluDim2, self).setUp()
class TestNGRAPHTanhDim2(TestTanh):
def setUp(self):
super(TestNGRAPHTanhDim2, self).setUp()
class TestNGRAPHReluDim4(TestRelu):
def setUp(self):
super(TestNGRAPHReluDim4, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
class TestNGRAPHTanhDim4(TestTanh):
def setUp(self):
super(TestNGRAPHTanhDim4, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_mul_op import TestMulOp, TestMulOp2, TestFP16MulOp1, TestFP16MulOp2
class TestNGRAPHMulOp(TestMulOp):
def init_dtype_type(self):
pass
class TestNGRAPHMulOp2(TestMulOp2):
def init_dtype_type(self):
pass
class TestNGRAPHFP16MulOp1(TestFP16MulOp1):
def init_dtype_type(self):
pass
class TestNGRAPHFP16MulOp2(TestFP16MulOp2):
def init_dtype_type(self):
pass
if __name__ == "__main__":
unittest.main()
......@@ -368,6 +368,8 @@ class OpTest(unittest.TestCase):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
return [place]
else:
return []
else:
return []
places = [fluid.CPUPlace()]
......
......@@ -22,8 +22,10 @@ from op_test import OpTest
class TestAccuracyOp(OpTest):
def setUp(self):
self.op_type = "accuracy"
self.dtype = np.float32
self.init_dtype()
n = 8192
infer = np.random.random((n, 1)).astype("float32")
infer = np.random.random((n, 1)).astype(self.dtype)
indices = np.random.randint(0, 2, (n, 1))
label = np.random.randint(0, 2, (n, 1))
self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
......@@ -34,14 +36,25 @@ class TestAccuracyOp(OpTest):
num_correct += 1
break
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
'Correct': np.array([num_correct]).astype("int32"),
'Total': np.array([n]).astype("int32")
}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestAccuracyOpFp16(TestAccuracyOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
if __name__ == '__main__':
unittest.main()
......@@ -21,14 +21,16 @@ from op_test import OpTest
class ElementwiseDivOp(OpTest):
def setUp(self):
self.op_type = "elementwise_div"
self.dtype = np.float32
self.init_dtype()
""" Warning
CPU gradient check error!
'X': np.random.random((32,84)).astype("float32"),
'Y': np.random.random((32,84)).astype("float32")
"""
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
......@@ -46,6 +48,9 @@ class ElementwiseDivOp(OpTest):
self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
def init_dtype(self):
pass
class TestElementwiseDivOp_scalar(ElementwiseDivOp):
def setUp(self):
......@@ -126,5 +131,21 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
}
class TestElementwiseDivOpFp16(ElementwiseDivOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
if __name__ == '__main__':
unittest.main()
......@@ -135,5 +135,10 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
}
class TestElementwiseMulOpFp16(ElementwiseMulOp):
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__':
unittest.main()
......@@ -22,12 +22,22 @@ from op_test import OpTest
class TestFillZerosLikeOp(OpTest):
def setUp(self):
self.op_type = "fill_zeros_like"
self.inputs = {'X': np.random.random((219, 232)).astype("float32")}
self.dtype = np.float32
self.init_dtype()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestFillZerosLikeOpFp16(TestFillZerosLikeOp):
def init_dtype(self):
self.dtype = np.float16
if __name__ == "__main__":
unittest.main()
......@@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import unittest
import sys
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.layers.nn import FC
@contextlib.contextmanager
def new_program_scope():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
class MyLayer(fluid.imperative.PyLayer):
......@@ -30,6 +41,23 @@ class MyLayer(fluid.imperative.PyLayer):
return [fluid.layers.elementwise_mul(x, x)]
class MLP(fluid.imperative.PyLayer):
def __init__(self):
super(MLP, self).__init__()
self._fc1 = FC(3,
fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
self._fc2 = FC(4,
fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
def forward(self, inputs):
x = self._fc1(inputs[0])
x = self._fc2(x)
x = fluid.layers.reduce_sum(x)
return x
class TestImperative(unittest.TestCase):
def test_layer(self):
with fluid.imperative.guard():
......@@ -39,13 +67,56 @@ class TestImperative(unittest.TestCase):
l.forward([])
def test_layer_in_out(self):
np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
with fluid.imperative.guard():
l = MyLayer()
x = l(np.array([1.0, 2.0, -1.0], dtype=np.float32))[0]
x = l(np_inp)[0]
self.assertIsNotNone(x)
sys.stderr.write("%s output: %s\n" % (x, x._numpy()))
dy_out = x._numpy()
x._backward()
sys.stderr.write("grad %s\n" % l._x_for_debug._gradient())
dy_grad = l._x_for_debug._gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[3], append_batch_size=False)
l = MyLayer()
x = l(inp)[0]
param_grads = fluid.backward.append_backward(
x, parameter_list=[l._x_for_debug.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
fetch_list=[x.name, param_grads[1].name])
self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad, static_grad))
def test_mlp(self):
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
with fluid.imperative.guard():
mlp = MLP()
out = mlp(np_inp)
dy_out = out._numpy()
out._backward()
dy_grad = mlp._fc1._w._gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[2, 2], append_batch_size=False)
mlp = MLP()
out = mlp(inp)
param_grads = fluid.backward.append_backward(
out, parameter_list=[mlp._fc1._w.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
fetch_list=[out.name, param_grads[1].name])
self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad, static_grad))
if __name__ == '__main__':
......
......@@ -97,7 +97,7 @@ class TestLearningRateDecay(unittest.TestCase):
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
decayed_lr = fluid_decay_fn(**kwargs)
decayed_lr = fluid_decay_fn(**kwargs)("float32")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......
......@@ -24,11 +24,13 @@ from op_test import OpTest
class TestMomentumOp1(OpTest):
def setUp(self):
self.op_type = "momentum"
self.dtype = np.float32
self.init_dtype()
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(self.dtype)
mu = 0.0001
use_nesterov = False
......@@ -50,10 +52,21 @@ class TestMomentumOp1(OpTest):
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestMomentumOpFp16(TestMomentumOp1):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
class TestMomentumOp2(OpTest):
'''Test Momentum with default values for attributes
'''
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from paddle.fluid.layers.device import get_places
from paddle.fluid.layers.control_flow import ParallelDo
import paddle.fluid.profiler as profiler
import numpy
import six
class BaseParallelForTest(unittest.TestCase):
def run_test(self, callback, feed, fetch):
"""
Run the unittest for parallel.for
Args:
callback(callable): A callable function returns a generator. There
are two yields in the generator function. The first yield
returns the data layers, and the second yield returns the loss.
The modified data variables will be sent back during the first
yield.
feed(dict): The executor feeding dictionary.
fetch(list|basestr): The fetch name lists.
Returns:
None
Raises:
AssertionError when the computation of cpu, parallel.for in cpu,
gpu, parallel.for in gpu are different.
"""
cpu = fluid.CPUPlace()
result_cpu = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=cpu,
use_parallel=False)
result_cpu_parallel = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=cpu,
use_parallel=True)
if fluid.core.is_compiled_with_cuda():
gpu = fluid.CUDAPlace(0)
result_gpu = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=gpu,
use_parallel=False,
use_gpu=True)
result_gpu_parallel = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=gpu,
use_parallel=True,
use_gpu=True)
result_gpu_nccl = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=gpu,
use_parallel=True,
use_nccl=True,
use_gpu=True)
self._assert_same_(fetch, result_cpu, result_cpu_parallel,
result_gpu, result_gpu_parallel, result_gpu_nccl)
else:
self._assert_same_(fetch, result_cpu, result_cpu_parallel)
def _run_test_impl_(self,
callback,
feed,
fetch,
place,
use_parallel=False,
use_nccl=False,
use_gpu=False):
"""
Run a single test, returns the fetch values
Args:
place(Place): the computation place.
use_parallel(bool): Whether use parallel.for or not.
Returns:
Fetched numpy arrays.
"""
if isinstance(fetch, six.string_types):
fetch = [fetch]
main = fluid.Program()
startup = fluid.Program()
# Fix seed
main.random_seed = 10
startup.random_seed = 10
with fluid.program_guard(main, startup):
generator = callback()
# Automatically insert parallel do if use_parallel = True
if use_parallel:
thread_num = fluid.core.get_cuda_device_count(
) if use_gpu else 8
places = get_places(thread_num)
pd = ParallelDo(places, use_nccl=use_nccl)
data = next(generator)
if isinstance(data, fluid.framework.Variable):
data = [data]
with pd.do():
ins = list(map(pd.read_input, data))
if len(ins) == 1:
ins = ins[0]
loss = generator.send(ins) # patch input
pd.write_output(loss)
loss = pd()
else:
data = next(generator)
loss = generator.send(data)
self.assertIsNotNone(loss)
avg_loss = fluid.layers.mean(loss)
fluid.backward.append_backward(loss=avg_loss)
exe = fluid.Executor(place)
exe.run(startup)
if use_gpu:
profile_type = 'GPU'
else:
profile_type = 'CPU'
with profiler.profiler(profile_type, 'total', '/tmp/profiler'):
return exe.run(main, feed=feed, fetch_list=fetch)
def _assert_same_(self, fetch, *args):
"""
Assert the return values of `run_test` are same.
Args:
fetch: Fetch list. Used for print error message
*args: The fetch result lists of each situations.
Returns:
None
Raises:
AssertionError
"""
def _impl_(a, b, fetch_id, item_id):
item_str = [
'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL'
]
flag = numpy.allclose(a, b, rtol=0.1, atol=1e-3)
self.assertTrue(flag,
"The {0} are different in {1}, {2} vs {3}".format(
fetch[fetch_id], item_str[item_id], a, b))
for i, items in enumerate(zip(*args)):
self.assertGreater(len(items), 0)
for j in range(1, len(items)):
_impl_(items[0], items[j], fetch_id=i, item_id=j)
class ParallelOpTest(BaseParallelForTest):
@staticmethod
def __network__():
x = fluid.layers.data(shape=[784], dtype='float32', name='img')
x = yield x
hidden = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden = fluid.layers.batch_norm(input=hidden)
loss = fluid.layers.mean(hidden)
yield loss
def test_simple_fc(self):
self.run_test(
callback=self.__network__,
feed={
'img': numpy.random.random(size=(51, 784)).astype('float32')
},
fetch=['fc1.w@GRAD'])
def test_fc_with_tiny_data(self):
self.run_test(
callback=self.__network__,
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
fetch=['fc1.w@GRAD'])
class ParallelOpTestMultipleInput(BaseParallelForTest):
@staticmethod
def __network__():
x = fluid.layers.data(
shape=[784], dtype='float32', name='img1', stop_gradient=False)
y = fluid.layers.data(
shape=[784], dtype='float32', name='img2', stop_gradient=False)
yield [x, y]
x = x + y
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.fc(input=hidden1, size=200, param_attr='fc2.w')
hidden3 = fluid.layers.fc(input=hidden2, size=200, param_attr='fc3.w')
loss = fluid.layers.mean(hidden3)
yield loss
def test_simple_fc(self):
self.run_test(
callback=self.__network__,
feed={
'img1': numpy.random.random(size=(51, 784)).astype('float32'),
'img2': numpy.random.random(size=(51, 784)).astype('float32')
},
fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD'])
if __name__ == '__main__':
unittest.main()
......@@ -23,8 +23,11 @@ class TestTopkOp(OpTest):
def setUp(self):
self.set_args()
self.op_type = "top_k"
self.dtype = np.float32
self.init_dtype()
k = self.top_k
input = np.random.random((self.row, k)).astype("float32")
input = np.random.random((self.row, k)).astype(self.dtype)
output = np.ndarray((self.row, k))
indices = np.ndarray((self.row, k)).astype("int64")
......@@ -38,6 +41,9 @@ class TestTopkOp(OpTest):
self.outputs = {'Out': output, 'Indices': indices}
def init_dtype(self):
pass
def set_args(self):
self.row = 32
self.top_k = 1
......@@ -46,6 +52,11 @@ class TestTopkOp(OpTest):
self.check_output()
class TestTopkOpFp16(TestTopkOp):
def init_dtype(self):
self.dtype = np.float16
class TestTopkOp3d(OpTest):
def setUp(self):
self.op_type = "top_k"
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_transpose_op import TestTransposeOp
class TestTransposeMKLDNN(TestTransposeOp):
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = True
self.is_test = True
return
def test_check_grad(self):
return
def test_check_grad_no_input(self):
return
def test_check_grad_no_filter(self):
return
class TestCase0MKLDNN(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (3, )
self.axis = (0, )
class TestCase1a(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (3, 4, 5)
self.axis = (0, 2, 1)
class TestCase1b(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (3, 4, 5)
self.axis = (2, 1, 0)
class TestCase2(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
class TestCase3(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
class TestCase4(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
if __name__ == '__main__':
unittest.main()
......@@ -21,15 +21,24 @@ from op_test import OpTest
class TestTransposeOp(OpTest):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.op_type = "transpose2"
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {'axis': list(self.axis)}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': self.use_mkldnn,
'is_test': self.is_test,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float32"),
'Out': self.inputs['X'].transpose(self.axis)
}
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = False
self.is_test = False
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
......
......@@ -35,11 +35,10 @@ dtype_to_size = {
}
SUB_BLOCK_OPS = [
"while", "while_grad", "parallel_do", "parallel_do_grad",
"conditional_block", "conditional_block_grad"
"while", "while_grad", "conditional_block", "conditional_block_grad"
]
SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
SUB_BLOCK_PAIR = [("while", "while_grad"),
("conditional_block", "conditional_block_grad")]
PRINT_LOG = False
......
......@@ -107,9 +107,9 @@ packages=['paddle',
'paddle.fluid.distributed',
'paddle.fluid.layers',
'paddle.fluid.contrib',
'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
......@@ -160,10 +160,11 @@ if '${WITH_FLUID_ONLY}'== 'OFF':
# put all thirdparty libraries in paddle.libs
libs_path='${PADDLE_BINARY_DIR}/python/paddle/libs'
if os.name != 'nt':
package_data['paddle.libs']= []
package_data['paddle.libs']=['libwarpctc' + ext_name]
shutil.copy('${WARPCTC_LIBRARIES}', libs_path)
package_data['paddle.libs']= []
package_data['paddle.libs']=[('libwarpctc' if os.name != 'nt' else 'warpctc') + ext_name]
shutil.copy('${WARPCTC_LIBRARIES}', libs_path)
if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_LIB}', libs_path)
shutil.copy('${MKLML_IOMP_LIB}', libs_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册