提交 b8051e79 编写于 作者: S sneaxiy

merge develop

test=develop
......@@ -208,6 +208,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
......
......@@ -123,6 +123,8 @@ class OpDesc {
BlockDesc *Block() { return this->block_; }
const BlockDesc *Block() const { return this->block_; }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -25,6 +25,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
class InferShapeContext {
......
......@@ -254,5 +254,16 @@ TEST(Analyzer_dam, compare) { compare(); }
TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
// Compare Deterministic result
TEST(Analyzer_dam, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
......@@ -180,6 +180,17 @@ TEST(Analyzer_LAC, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_LAC, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -179,5 +179,16 @@ TEST(Analyzer_Chinese_ner, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_Chinese_ner, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
......@@ -85,6 +85,17 @@ TEST(Analyzer_resnet50, compare) { compare(); }
TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
// Compare Deterministic result
TEST(Analyzer_resnet50, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -265,6 +265,17 @@ TEST(Analyzer_rnn1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_rnn1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
// Test Multi-Thread.
TEST(Analyzer_rnn1, multi_thread) {
contrib::AnalysisConfig cfg;
......
......@@ -158,5 +158,16 @@ TEST(Analyzer_rnn2, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_rnn2, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
......@@ -204,5 +204,16 @@ TEST(Analyzer_seq_conv1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_seq_conv1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
......@@ -106,6 +106,17 @@ TEST(Analyzer_Text_Classification, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_Text_Classification, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
AnalysisConfig cfg;
SetConfig(&cfg);
......
......@@ -145,6 +145,17 @@ TEST(Analyzer_vis, compare) { compare(); }
TEST(Analyzer_vis, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
// Compare Deterministic result
TEST(Analyzer_vis, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -45,6 +45,7 @@ DEFINE_bool(use_analysis, true,
"Running the inference program in analysis mode.");
DEFINE_bool(record_benchmark, false,
"Record benchmark after profiling the model");
DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);
......@@ -85,7 +86,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3);
EXPECT_NEAR(pdata_ref[j], pdata[j], FLAGS_accuracy);
}
break;
}
......@@ -283,6 +284,26 @@ void TestPrediction(const PaddlePredictor::Config *config,
}
}
void CompareDeterministic(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat;
auto predictor = CreateTestPredictor(config, FLAGS_use_analysis);
// warmup run
std::vector<PaddleTensor> warmup_outputs, outputs;
predictor->Run(inputs[0], &warmup_outputs, batch_size);
// run num_times to Compare Deterministic Result.
for (int i = 0; i < num_times; i++) {
for (size_t j = 0; j < inputs.size(); j++) {
predictor->Run(inputs[j], &outputs, batch_size);
CompareResult(outputs, warmup_outputs);
}
}
}
void CompareNativeAndAnalysis(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
......
......@@ -42,8 +42,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32)
......@@ -92,4 +91,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/py_func_op.h"
#include <set>
#include <string>
#include <vector>
#include "Python.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
namespace py = ::pybind11;
static std::vector<py::object> g_py_callables;
const char kForwardPythonCallableId[] = "forward_callable_id";
const char kBackwardPythonCallableId[] = "backward_callable_id";
const char kPyFuncBackwardSkipVars[] = "backward_skip_vars";
size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
g_py_callables.emplace_back(py_obj);
return g_py_callables.size() - 1;
}
// Return py::object* instead of py::object
// Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe
static py::object *GetPythonCallableObject(size_t i) {
PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id");
return &g_py_callables[i];
}
static std::string PythonFuncDebugString(const py::object &py_callable) {
py::gil_scoped_acquire guard;
std::string wrapper_func_str = py::str(py_callable);
auto inner_func = py_callable.attr("_func");
std::string inner_func_str = py::str(inner_func);
return inner_func_str + " wrapped by " + wrapper_func_str;
}
static void CallPythonFunc(py::object *callable,
const std::vector<framework::LoDTensor> &ins,
std::vector<framework::LoDTensor *> *outs) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
}
auto ret = (*callable)(*in_args);
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
size_t out_num = outs->size();
if (UNLIKELY(ret_num != out_num)) {
// Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num
PADDLE_ENFORCE(
ret_num == 1 && out_num == 0 &&
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr,
"Output number not match. Expected %d, actual %d", out_num, ret_num);
}
for (size_t i = 0; i < out_num; ++i) {
auto *out = (*outs)[i];
if (out == nullptr) {
continue;
}
try {
auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
"Output tensor %d should not be nullptr", i);
out->set_lod(py_out_tensor->lod());
out->ShareDataWith(*py_out_tensor);
} catch (py::cast_error &) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}
}
}
class PyFuncOpVarTypInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op,
framework::BlockDesc *block) const override {
auto &outs = op.Outputs();
bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
auto &ins = op.Inputs();
bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
/**
* X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output
*/
PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
"Function id cannot be less than 0");
if (!has_out) return;
/**
* Traverse all outputs, check if name of any output ends with @GRAD.
* If found, set its shape, dtype, lod_level, type to be the same as
* the corresponding forward variable
*/
const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto &out_var_names = outs.at("Out");
for (auto &out_var_name : out_var_names) {
if (out_var_name == framework::kEmptyVarName ||
out_var_name.size() < kGradVarSuffix.size()) {
continue;
}
size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len);
auto *out_var_desc = block->FindVarRecursive(out_var_name);
auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
out_var_name);
PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
fwd_var_name);
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")";
out_var_desc->SetShape(fwd_var_desc->GetShape());
out_var_desc->SetDataType(fwd_var_desc->GetDataType());
out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
out_var_desc->SetType(fwd_var_desc->GetType());
}
}
}
};
class PyFuncOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"Infer shape cannot be called in runtime.");
}
};
class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Inputs of py_func op.").AsDuplicable();
AddOutput("Out", "Outputs of py_func op").AsDuplicable();
AddAttr<int>(kForwardPythonCallableId,
"Index of registered forward Python function.")
.SetDefault(0);
AddAttr<int>(kBackwardPythonCallableId,
"Index of registered backward Python function.")
.SetDefault(-1);
AddAttr<std::vector<std::string>>(kPyFuncBackwardSkipVars,
"Unused forward in/out in backward op")
.SetDefault(std::vector<std::string>());
AddComment(R"DOC("PyFunc Op")DOC");
}
};
/**
* There are several benefits when backward op of py_func op is
* still py_func op.
*
* - Less codes are needed, since codes of backward is almost
* the same as forward.
*
* - To support high order derivative, so that py_func is
* infinite-order differentiable
*/
class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
private:
static std::string DebugString(const std::vector<std::string> &strs) {
if (strs.empty()) return "";
std::string ret = strs[0];
for (size_t i = 1; i < strs.size(); ++i) {
ret += " ";
ret += strs[i];
}
return ret;
}
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
auto &fwd_attrs = Attrs();
// no backward op when backward_id is less than 0
if (boost::get<int>(fwd_attrs.at(kBackwardPythonCallableId)) < 0) {
return {};
}
std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
grad_op->SetType("py_func");
framework::AttributeMap bwd_attrs;
bwd_attrs[kForwardPythonCallableId] =
fwd_attrs.at(kBackwardPythonCallableId);
bwd_attrs[kBackwardPythonCallableId] = -1;
grad_op->SetAttrMap(bwd_attrs);
// All forward inputs
auto fwd_ins = Input("X");
// All forward outputs
auto fwd_outs = Output("Out");
// For memory reused, some inputs/output in forward part may be not needed
// in backward part. Skipping these vars helps to save memory
auto &backward_skip_var_list = boost::get<std::vector<std::string>>(
fwd_attrs.at(kPyFuncBackwardSkipVars));
std::unordered_set<std::string> backward_skip_var_set(
backward_skip_var_list.begin(), backward_skip_var_list.end());
std::vector<std::string> bwd_ins;
bwd_ins.reserve(fwd_ins.size() + fwd_outs.size());
for (auto &fwd_in : fwd_ins) {
if (backward_skip_var_set.count(fwd_in) == 0) {
bwd_ins.emplace_back(fwd_in);
}
}
for (auto &fwd_out : fwd_outs) {
if (backward_skip_var_set.count(fwd_out) == 0) {
bwd_ins.emplace_back(fwd_out);
}
}
// Backward OG cannot be skipped
// But in Python side, if OG is kEmptyVarName, input tensor would be None
auto fwd_out_grads = OutputGrad("Out");
bwd_ins.reserve(bwd_ins.size() + fwd_out_grads.size());
bwd_ins.insert(bwd_ins.end(), fwd_out_grads.begin(), fwd_out_grads.end());
// Backward IG cannot be skipped
// But in Python side, if IG is not needed, users can just return None
auto bwd_outs = InputGrad("X", false);
VLOG(10) << "PyFunc Grad Input: " << DebugString(bwd_ins);
VLOG(10) << "PyFunc Grad Output: " << DebugString(bwd_outs);
grad_op->SetInput("X", bwd_ins);
grad_op->SetOutput("Out", bwd_outs);
std::vector<std::unique_ptr<framework::OpDesc>> ret(1);
ret[0] = std::move(grad_op);
return ret;
}
};
class PyFuncOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
protected:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &in_arg_names = Inputs("X");
auto &out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> inputs(in_arg_names.size());
for (size_t i = 0; i < in_arg_names.size(); ++i) {
auto in_var = scope.FindVar(in_arg_names[i]);
// When py_func op is called in backward, in_var may be null
if (in_var == nullptr) {
continue;
}
auto &in_tensor = in_var->Get<framework::LoDTensor>();
if (!in_tensor.IsInitialized()) {
continue;
}
if (platform::is_gpu_place(in_tensor.place())) {
framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]);
} else {
inputs[i].ShareDataWith(in_tensor);
}
inputs[i].set_lod(in_tensor.lod());
}
std::vector<framework::LoDTensor *> outputs(out_arg_names.size());
for (size_t i = 0; i < out_arg_names.size(); ++i) {
auto *out_var = scope.FindVar(out_arg_names[i]);
outputs[i] =
out_var ? out_var->GetMutable<framework::LoDTensor>() : nullptr;
}
auto callable_id = static_cast<size_t>(Attr<int>(kForwardPythonCallableId));
auto *py_callable = GetPythonCallableObject(callable_id);
VLOG(10) << "Call Python function with id " << callable_id << ": "
<< PythonFuncDebugString(*py_callable);
CallPythonFunc(py_callable, inputs, &outputs);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
ops::PyFuncOpGradDescMaker);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
namespace paddle {
namespace operators {
size_t AppendPythonCallableObjectAndReturnId(const ::pybind11::object &py_obj);
} // namespace operators
} // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc)
if(WITH_PYTHON)
......
......@@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) {
.def("infer_var_type", &pd::OpDesc::InferVarType)
.def("set_is_target", &pd::OpDesc::SetIsTarget)
.def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", &pd::OpDesc::Block,
.def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference);
}
......
......@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -110,6 +111,12 @@ PYBIND11_MODULE(core, m) {
BindException(&m);
m.def(
"_append_python_callable_object_and_return_id",
[](py::object py_obj) -> size_t {
return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj);
});
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
.def(py::init<>())
.def("_run_backward",
......
......@@ -18,7 +18,9 @@ All layers just related to the neural network.
from __future__ import print_function
import numpy as np
import six
import os
import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable, OpProtoHolder
......@@ -176,6 +178,7 @@ __all__ = [
'merge_selected_rows',
'get_tensor_from_selected_rows',
'lstm',
'py_func',
'psroi_pool',
'huber_loss',
]
......@@ -9327,6 +9330,224 @@ def get_tensor_from_selected_rows(x, name=None):
return out
class PyFuncRegistry(object):
_register_funcs = []
def __init__(self, func):
if func is None or not callable(func):
raise TypeError('func must be a Python function')
self._func = func
# find named args using reflection
args = inspect.getargspec(self._func)
if len(args[0]) == 0 and args[1] is None and args[2] is None:
# Function with no inputs
self._named_args = None
else:
self._named_args = args[0]
self._id = core._append_python_callable_object_and_return_id(self)
'''
Why record self here?
1. For debug usage. Users can call
:code:`py_func.registered_func(idx)` method
to find the registered function corresponding
to :code:`idx`.
2. For increasing reference count of self.
It seems that to release Python object
whose reference count is 1 would cause
segmentation fault error in C++ side.
May be lack of Python GC in C++ side?
'''
PyFuncRegistry._register_funcs.append(self)
@classmethod
def registered_func(cls, idx):
return cls._register_funcs[idx]._func
@classmethod
def registered_func_num(cls):
return len(cls._register_funcs)
@property
def id(self):
return self._id
def __call__(self, *args):
if self._named_args is None:
func_ret = self._func()
else:
kwargs = dict()
idx = 0
for arg in self._named_args:
kwargs[arg] = args[idx]
idx += 1
func_ret = self._func(*args[idx:], **kwargs)
if not isinstance(func_ret, (list, tuple)):
func_ret = (func_ret, )
ret = []
for each_ret in func_ret:
if each_ret is None or isinstance(each_ret, core.LoDTensor):
ret.append(each_ret)
continue
if not isinstance(each_ret, np.ndarray):
each_ret = np.array(each_ret)
tensor = core.LoDTensor()
tensor.set(each_ret, core.CPUPlace())
ret.append(tensor)
return tuple(ret)
@templatedoc()
def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
"""
PyFunc Operator.
User can use :code:`py_func` to register operators in Python side.
The inputs of :code:`func` is :code:`LoDTensor` and outputs can be
numpy array or :code:`LoDTensor`. Paddle would call the registered
:code:`func` in forward part, and call :code:`backward_func` in
backward part (if :code:`backward_func` is not None).
User should set the right data type and shape of :code:`out` before
calling this function. However, data types and shapes of gradients of
:code:`out` and :code:`x` would be inferred automatically.
Input orders of :code:`backward_func` would be: forward inputs
:code:`x`, forward outputs :code:`out` and backward input gradients of
:code:`out`. If some variables of :code:`out` have no gradient, the input
tensor would be None in Python side. If some variables of :code:`in` have
no gradient, users should return None.
This function can also be used to debug the running network. User can
add a :code:`py_func` operator without output, and print input
:code:`x` inside :code:`func`.
Args:
func (callable): forward Python function.
x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`.
out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`.
Paddle cannot infer shapes and data types of :code:`out`. Users
should create :code:`out` beforehand.
backward_func (callable|None): backward Python function.
None means no backward. Default None.
skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)):
Variables that are not needed in :code:`backward_func` inputs.
These variables must be any of :code:`x` and :code:`out`.
If set, these vars would not be inputs of :code:`backward_func`,
Only useful when :code:`backward_func` is not None. Default None.
Returns:
out (Variable|list(Variable)|tuple(Variable)): input :code:`out`
Examples:
>>> import paddle.fluid as fluid
>>> import six
>>>
>>> def create_tmp_var(name, dtype, shape):
>>> return fluid.default_main_program().current_block().create_var(
>>> name=name, dtype=dtype, shape=shape)
>>>
>>> # tanh activation has been provided by Paddle C++ op
>>> # Here, we only use tanh to be an example to show the usage
>>> # of py_func
>>> def tanh(x):
>>> return np.tanh(x)
>>>
>>> # forward input x is skipped
>>> def tanh_grad(y, dy):
>>> return np.array(dy) * (1 - np.square(np.array(y)))
>>>
>>> def debug_func(x):
>>> print(x)
>>>
>>> def simple_net(img, label):
>>> hidden = img
>>> for idx in six.moves.range(4):
>>> hidden = fluid.layers.fc(hidden, size=200)
>>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
>>> dtype=hidden.dtype, shape=hidden.shape)
>>>
>>> # user-defined layers with forward and backward
>>> hidden = fluid.layers.py_func(func=tanh, x=hidden,
>>> out=new_hidden, backward_func=tanh_grad,
>>> skip_vars_in_backward_input=hidden)
>>>
>>> # user-defined debug layers to print variables
>>> fluid.layers.py_func(func=debug_func, x=hidden, out=None)
>>>
>>> prediction = fluid.layers.fc(hidden, size=10, act='softmax')
>>> loss = fluid.layers.cross_entropy(input=prediction, label=label)
>>> return fluid.layers.mean(loss)
"""
helper = LayerHelper('py_func', **locals())
if x is None:
x = []
elif isinstance(x, Variable):
x = [x]
elif not isinstance(x, (list, tuple)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
if out is None:
out_list = []
elif isinstance(out, Variable):
out_list = [out]
elif isinstance(out, (list, tuple)):
out_list = out
else:
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')
fwd_func_id = PyFuncRegistry(func).id
bwd_func_id = PyFuncRegistry(
backward_func).id if backward_func is not None else -1
for each_out in out_list:
if len(each_out.shape) == 0:
raise ValueError(
'Output shapes of py_func op should be provided by users manually'
)
backward_skip_vars = set()
if backward_func is not None and skip_vars_in_backward_input is not None:
if isinstance(skip_vars_in_backward_input, Variable):
skip_vars_in_backward_input = [skip_vars_in_backward_input]
fwd_in_out = [v.name for v in x]
fwd_in_out.extend([v.name for v in out_list])
fwd_in_out = set(fwd_in_out)
backward_skip_vars = set()
for v in skip_vars_in_backward_input:
if not v.name in fwd_in_out:
raise ValueError(
'Variable {} is not found in forward inputs and outputs'
.format(v.name))
backward_skip_vars.add(v.name)
helper.append_op(
type='py_func',
inputs={'X': x},
outputs={'Out': out_list},
attrs={
'forward_callable_id': fwd_func_id,
'backward_callable_id': bwd_func_id,
'backward_skip_vars': list(backward_skip_vars)
})
return out
# For debug usage
py_func.registered_func = PyFuncRegistry.registered_func
py_func.registered_func_num = PyFuncRegistry.registered_func_num
@templatedoc()
def psroi_pool(input,
rois,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import paddle
import unittest
import six
import numpy as np
dev_cnt = 2
if fluid.core.is_compiled_with_cuda():
dev_cnt = fluid.core.get_cuda_device_count()
os.environ['CPU_NUM'] = str(dev_cnt)
def dummy_func_with_no_input():
return float(1.0)
def dummy_func_with_no_output(x):
pass
def tanh(x):
return np.tanh(x)
def tanh_grad(y, dy):
return np.array(dy) * (1 - np.square(np.array(y)))
def cross_entropy(logits, labels):
logits = np.array(logits)
labels = np.array(labels)
M = logits.shape[0]
N = logits.shape[1]
ret = np.ndarray([M, 1]).astype(logits.dtype)
for idx in six.moves.range(M):
ret[idx][0] = -np.log(logits[idx][labels[idx][0]])
return ret
def cross_entropy_grad(logits, labels, bwd_dout):
logits = np.array(logits)
labels = np.array(labels)
bwd_dout = np.array(bwd_dout)
M = logits.shape[0]
N = logits.shape[1]
dlogits = np.zeros([M, N]).astype(logits.dtype)
for idx in six.moves.range(M):
dlogits[idx][labels[idx][0]] = -bwd_dout[idx] / logits[idx][labels[idx][
0]]
return dlogits, None
def simple_fc_net(img, label, use_py_func_op):
hidden = img
for idx in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
if not use_py_func_op:
hidden = fluid.layers.tanh(hidden)
else:
new_hidden = fluid.default_main_program().current_block(
).create_var(
name='hidden_{}'.format(idx),
dtype='float32',
shape=hidden.shape)
hidden = fluid.layers.py_func(
func=tanh,
x=hidden,
out=new_hidden,
backward_func=tanh_grad,
skip_vars_in_backward_input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
if not use_py_func_op:
loss = fluid.layers.cross_entropy(input=prediction, label=label)
else:
loss = fluid.default_main_program().current_block().create_var(
name='loss', dtype='float32', shape=[-1, 1])
loss = fluid.layers.py_func(
func=cross_entropy,
x=[prediction, label],
out=loss,
backward_func=cross_entropy_grad,
skip_vars_in_backward_input=loss)
dummy_var = fluid.default_main_program().current_block().create_var(
name='test_tmp_var', dtype='float32', shape=[1])
fluid.layers.py_func(
func=dummy_func_with_no_input, x=None, out=dummy_var)
fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None)
loss = fluid.layers.mean(loss)
return loss
def reader():
for _ in six.moves.range(dev_cnt * 100):
yield np.random.random([784]), np.random.random_integers(
size=[1], low=0, high=9)
def test_main(use_cuda, use_py_func_op, use_parallel_executor):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return None
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.core.Scope()):
fluid.default_main_program().random_seed = 1
fluid.default_startup_program().random_seed = 1
np.random.seed(1)
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = simple_fc_net(img, label, use_py_func_op)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
r = paddle.batch(reader, batch_size=10)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name)
fetch_list = [loss.name]
else:
fetch_list = [loss]
ret = []
for epoch_id in six.moves.range(2):
for d in r():
L, = exe.run(feed=feeder.feed(d), fetch_list=fetch_list)
ret.append(L)
return np.array(ret)
class TestPyFuncOpUseExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = False
def test_loss_diff(self):
losses = []
for use_cuda in [True, False]:
for use_py_func_op in [True, False]:
L = test_main(use_cuda, use_py_func_op,
self.use_parallel_executor)
if L is not None:
losses.append(L)
for idx in six.moves.range(len(losses) - 1):
max_diff = np.max(np.abs(losses[idx] - losses[0]))
self.assertAlmostEqual(max_diff, 0, delta=1e-3)
class TestPyFuncOpUseParallelExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册