提交 bf0c90f2 编写于 作者: Y Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_async_update_failed

......@@ -40,12 +40,12 @@ ExternalProject_Add(
# NOTE(wuyi):
# this package is generated by following steps:
# 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
# 2. submodule update --init
# 2. git submodule update --init
# 3. keep only zlib, cares, protobuf, boringssl under "third_party",
# checkout and clean other dirs under third_party
# 4. remove .git, and package the directory.
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
URL_MD5 "c9c58ee7d0e8929a63155af6a2ecdbd0"
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x.tar.gz"
URL_MD5 "1f268a2aff6759839dccd256adcc91cf"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
int op_dev_id = -1;
if (op.Type() == "split_byref") {
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
......
......@@ -70,6 +70,7 @@ $$Out = values$$
namespace ops = paddle::operators;
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker);
REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
ops::AssignValueKernel<float>);
......@@ -269,14 +269,15 @@ void GRPCClient::Proceed() {
}
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::lock_guard<std::mutex> guard(chan_mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
}
// Channel configurations:
grpc::ChannelArguments args;
args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
......
......@@ -76,6 +76,7 @@ class BaseProcessor {
virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
context_.reset(new grpc::ClientContext());
var_h_ = var_info;
context_->set_wait_for_ready(true);
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
......@@ -85,6 +86,7 @@ class BaseProcessor {
virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());
context_->set_wait_for_ready(true);
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
......@@ -176,26 +178,24 @@ class GRPCClient : public RPCClient {
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
int64_t time_out = FLAGS_grpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
int64_t time_out = FLAGS_grpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = RPCClient::rpc_time_out) override;
int64_t time_out = FLAGS_grpc_deadline) override;
void AsyncSendBatchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_grpc_deadline) override;
void AsyncSendFetchBarrier(
const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_grpc_deadline) override;
void Wait() override;
......@@ -211,7 +211,7 @@ class GRPCClient : public RPCClient {
void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = RPCClient::rpc_time_out);
int64_t time_out = FLAGS_grpc_deadline);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
......
......@@ -97,7 +97,7 @@ class RequestSend final : public RequestBase {
void Process() override {
std::string varname = GetReqName();
VLOG(3) << "RequestSend var_name:" << varname;
VLOG(4) << "RequestSend var_name:" << varname;
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
......@@ -132,7 +132,7 @@ class RequestGet final : public RequestBase {
void Process() override {
// proc request.
std::string varname = request_.varname();
VLOG(3) << "RequestGet " << varname;
VLOG(4) << "RequestGet " << varname;
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
......@@ -178,7 +178,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
auto scope = request_->GetMutableLocalScope();
......@@ -201,10 +201,10 @@ class RequestPrefetch final : public RequestBase {
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready";
VLOG(4) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
VLOG(4) << "AsyncGRPCServer WaitSeverReady";
}
void AsyncGRPCServer::StartServer() {
......@@ -243,7 +243,7 @@ void AsyncGRPCServer::StartServer() {
for (int i = 0; i < threadnum; i++) {
rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
&AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
VLOG(3) << t.first << " creates threads!";
VLOG(4) << t.first << " creates threads!";
}
}
......@@ -260,7 +260,7 @@ void AsyncGRPCServer::StartServer() {
auto& threads = t.second;
for (size_t i = 0; i < threads.size(); ++i) {
threads[i]->join();
VLOG(3) << t.first << " threads ends!";
VLOG(4) << t.first << " threads ends!";
}
}
}
......@@ -268,7 +268,7 @@ void AsyncGRPCServer::StartServer() {
void AsyncGRPCServer::ShutdownQueue() {
for (auto& t : rpc_cq_) {
t.second->Shutdown();
VLOG(3) << t.first << " shutdown!";
VLOG(4) << t.first << " queue shutdown!";
}
}
......@@ -277,7 +277,7 @@ void AsyncGRPCServer::ShutDownImpl() {
is_shut_down_ = true;
ShutdownQueue();
VLOG(3) << "server_ shutdown!";
VLOG(4) << "server_ shutdown!";
server_->Shutdown();
}
......@@ -285,7 +285,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
LOG(WARNING) << "shutdown, do not TryToRegisterNewSendOne";
VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
......
......@@ -13,6 +13,10 @@
// limitations under the License.
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "gflags/gflags.h"
// default to 3min to avoid temprary network failures.
DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc");
namespace paddle {
namespace operators {
......
......@@ -15,11 +15,14 @@
#pragma once
#include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
DECLARE_int32(grpc_deadline);
namespace paddle {
namespace operators {
namespace distributed {
......@@ -32,26 +35,26 @@ class RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = rpc_time_out) = 0;
int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = rpc_time_out) = 0;
int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = rpc_time_out) = 0;
int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
virtual void AsyncSendBatchBarrier(
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = rpc_time_out) = 0;
virtual void AsyncSendFetchBarrier(
const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
......@@ -60,8 +63,6 @@ class RPCClient {
virtual void Wait() = 0;
static constexpr int64_t rpc_time_out = 120 * 1000;
template <typename T>
static RPCClient* GetInstance() {
std::call_once(init_flag_, &RPCClient::Init<T>);
......
......@@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
});
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
VLOG(3) << "batch_barrier_: " << rpc_name << " "
<< barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0;
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
......@@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
}
void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer WaitCond " << rpc_name;
VLOG(4) << "RPCServer WaitCond " << rpc_name;
int cond = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
......
......@@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop(
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
......@@ -207,7 +206,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
VLOG(3) << "RunAsyncLoop into while";
while (true) {
if (rpc_service_->IsExit()) {
LOG(INFO) << "get exit!rpc_processor break!";
......
......@@ -37,6 +37,11 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("SeedOut", "The random seed after random cropping.")
.AsIntermediate();
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
AddAttr<int>("startup_seed",
"If the input 'Seed' is not initialized, the 'startup_seed' "
"will be used to replace it. Even so, the seed after random "
"crop will also be outputed to the 'SeedOut'.")
.SetDefault(0);
AddComment(R"DOC(
This operator takes a batch of instance, and do random cropping on each instance.
It means that cropping positions differs on each instance, which is determined
......@@ -49,8 +54,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
class RandomCropOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
auto seed_dim = ctx->GetInputDim("Seed");
PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
......@@ -62,7 +65,6 @@ class RandomCropOpInferShape : public framework::InferShapeBase {
out_dim[x_i] = shape[shape_i];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
}
};
......
......@@ -142,16 +142,22 @@ template <typename DeviceContext, typename T>
class RandomCropKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
int64_t seed = 0;
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
if (seed_tensor.IsInitialized()) {
if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>();
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
}
} else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program";
framework::LoDTensor cpu_seed;
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
seed = *cpu_seed.data<int64_t>();
VLOG(5) << "WARNING: The input 'Seed' is not initialized, use attribute "
"'startup_seed' instead.";
seed = ctx.Attr<int>("startup_seed");
}
auto shape = ctx.Attr<std::vector<int>>("shape");
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
......@@ -171,7 +177,7 @@ class RandomCropKernel : public framework::OpKernel<T> {
engine.discard(functor.prod_batchsize_dims_ *
(functor.rank_ - functor.num_batchsize_dims_));
*ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
platform::CPUPlace()) = engine();
framework::make_ddim({1}), platform::CPUPlace()) = engine();
}
};
......
......@@ -39,6 +39,7 @@ class CustomReader : public framework::DecoratedReader {
const framework::ProgramDesc program_;
int sub_block_id_;
framework::Executor exe_;
framework::Scope scope_;
std::vector<std::string> source_var_names_;
std::vector<std::string> sink_var_names_;
......@@ -158,23 +159,24 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
// The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and
// compution cannot be concurrent.
framework::Scope scope;
framework::Scope* exe_scope = &scope_.NewScope();
// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = scope.Var(source_var_names_[i]);
framework::Variable* var = exe_scope->Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
exe_.Run(program_, &scope, sub_block_id_, false, true);
exe_.Run(program_, exe_scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
const auto& tensor = detail::Ref(exe_scope->FindVar(sink_var_names_[i]))
.Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
}
scope_.DeleteScope(exe_scope);
}
} // namespace reader
......
......@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 3;
static constexpr size_t kCacheSize = 5;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
......
......@@ -559,19 +559,8 @@ class Operator(object):
self.attrs[attr_name] is None):
continue
attr_val = self.attrs[attr_name]
if isinstance(attr_val, Block):
self.desc.set_block_attr(attr_name,
self.attrs[attr_name].desc)
elif isinstance(attr_val, list) and attr_val and \
all(isinstance(v, Block) for v in attr_val):
self.desc.set_blocks_attr(attr_name,
[v.desc for v in attr_val])
elif isinstance(attr_val, core.BlockDesc) or \
isinstance(attr_val, core.ProgramDesc):
self.desc.set_serialized_attr(
attr_name, attr_val.serialize_to_string())
else:
self.desc.set_attr(attr_name, attr_val)
self._update_desc_attr(attr_name, attr_val)
self.desc.check_attrs()
if self.has_kernel(type):
self.desc.infer_var_type(self.block.desc)
......@@ -718,6 +707,19 @@ class Operator(object):
ValueError: If the type of value doesn't match with desc.attr_type(name).
"""
self.attrs[name] = val
self._update_desc_attr(name, val)
def _update_desc_attr(self, name, val):
"""
Update the value of desc's attribute by attribute's name.
Args:
name(str): the attribute name.
val(bool|int|str|float|list): the value of the attribute.
Raises:
ValueError: If the type of value doesn't match with desc.attr_type(name).
"""
if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and all(
......
......@@ -110,7 +110,7 @@ class BlockGuardServ(BlockGuard):
class ListenAndServ(object):
"""
**ListenAndServ Layer**
ListenAndServ is used to create a rpc server bind and listen
on specific TCP port, this server will run the sub-block when
received variables from clients.
......@@ -212,7 +212,7 @@ def Send(endpoints, send_vars, sync=True):
of send_vars to send
send_vars (list): variables to send to server
sync (bool): whether to wait the request finish
"""
assert (type(send_vars) == list)
......@@ -469,10 +469,13 @@ def open_files(filenames,
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
buffer_size(int|None): The size of prefetch buffer. If it is setted None,
buffer size will be thread_num * 3.
Default: None
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Default: True
Returns:
Variable: A Reader Variable via which we can get file data.
......@@ -492,7 +495,7 @@ def open_files(filenames,
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num
buffer_size = thread_num * 3
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
......
......@@ -23,6 +23,7 @@ from layer_function_generator import autodoc, templatedoc
from tensor import concat
import utils
import random
from .. import unique_name
__all__ = [
'fc',
......@@ -4266,14 +4267,18 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
say :attr:`actual_shape` has a higher priority
than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, a new output tensor is created
whose data is copied from input x, otherwise the output
shares data with input without copying.
inplace(bool): If this flag is set true, the output
shares data with input without copying, otherwise
a new output tensor is created
whose data is copied from input x.
name (str): The name of this layer. It is optional.
Returns:
Variable: The output tensor.
Raises:
TypeError: if actual_shape is neither Variable nor None.
Examples:
.. code-block:: python
......@@ -4285,6 +4290,11 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python lsit or tuple.")
inputs = {"X": x}
if isinstance(actual_shape, Variable):
inputs["Shape"] = actual_shape
elif actual_shape is not None:
raise TypeError("actual_shape should either be Variable or None")
# Validate the shape
unk_dim_idx = -1
......@@ -4305,9 +4315,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reshape",
inputs={"X": x,
"Shape": actual_shape}
if isinstance(actual_shape, Variable) else {"X": x},
inputs=inputs,
attrs={"shape": shape,
"inplace": inplace},
outputs={"Out": reshaped})
......@@ -4889,47 +4897,39 @@ def random_crop(x, shape, seed=None):
>>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224])
"""
helper = LayerHelper("random_crop", **locals())
dtype = helper.input_dtype()
dtype = x.dtype
out = helper.create_tmp_variable(dtype)
if seed is None:
seed = random.randint(-65536, 65535)
op_attrs = {"shape": shape}
if isinstance(seed, int):
seed_value = seed
seed = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="fill_constant",
inputs={},
outputs={"Out": seed},
attrs={
"dtype": seed.dtype,
"shape": [1],
"value": float(seed_value),
"force_cpu": True
})
op_attrs["startup_seed"] = seed
seed = helper.create_variable(
name=unique_name.generate("random_crop_seed"),
dtype="int64",
persistable=True)
elif not isinstance(seed, Variable):
raise ValueError("'seed' must be a Variable or an int.")
seed_out = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="random_crop",
inputs={"X": x,
"Seed": seed},
outputs={"Out": out,
"SeedOut": seed_out},
attrs={"shape": shape})
"SeedOut": seed},
attrs=op_attrs)
return out
def log(input):
def log(x):
"""
Calculates the natural log of the given input tensor, element-wise.
.. math::
Out = \\ln(input)
Out = \\ln(x)
Args:
input (Variable): Input tensor.
x (Variable): Input tensor.
Returns:
Variable: The natural log of the input tensor computed element-wise.
......@@ -4938,7 +4938,7 @@ def log(input):
.. code-block:: python
output = fluid.layers.log(input)
output = fluid.layers.log(x)
"""
helper = LayerHelper('log', **locals())
dtype = helper.input_dtype(input_param_name='x')
......@@ -4947,18 +4947,18 @@ def log(input):
return out
def relu(input):
def relu(x):
"""
Relu takes one input data (Tensor) and produces one output data (Tensor)
where the rectified linear function, y = max(0, input), is applied to
where the rectified linear function, y = max(0, x), is applied to
the tensor elementwise.
.. math::
Out = \\max(0, input)
Out = \\max(0, x)
Args:
input (Variable): The input tensor.
x (Variable): The input tensor.
Returns:
Variable: The output tensor with the same shape as input.
......@@ -4967,7 +4967,7 @@ def relu(input):
.. code-block:: python
output = fluid.layers.relu(input)
output = fluid.layers.relu(x)
"""
helper = LayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='x')
......
......@@ -155,7 +155,7 @@ def cast(x, dtype):
Examples:
.. code-block:: python
data = fluid.layers.data(name='x', shape=[13], dtype='float32')
result = fluid.layers.cast(x=data, dtype='float64')
"""
......@@ -188,7 +188,7 @@ def concat(input, axis=0, name=None):
Examples:
.. code-block:: python
out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
"""
helper = LayerHelper('concat', **locals())
......@@ -238,7 +238,7 @@ def sums(input, out=None):
return out
def assign(input, output):
def assign(input, output=None):
"""
**Assign**
......@@ -246,7 +246,7 @@ def assign(input, output):
Args:
input(Variable|numpy.ndarray): The source variable
output(Variable): The destination variable
output(Variable|None): The destination variable
Returns:
Variable: The destination variable that was supplied as the *output*.
......@@ -259,6 +259,8 @@ def assign(input, output):
fluid.layers.assign(hidden, out)
"""
helper = LayerHelper('assign', **locals())
if output is None:
output = helper.create_tmp_variable(dtype=input.dtype)
if isinstance(input, Variable):
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
......
......@@ -596,12 +596,12 @@ class Auc(MetricBase):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if predictions[i, 1] >= thresh:
if preds[i, 1] >= thresh:
tp += 1
else:
fn += 1
else:
if predictions[i, 1] >= thresh:
if preds[i, 1] >= thresh:
fp += 1
else:
tn += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册