未验证 提交 16dfedb8 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14103 from jacquesqiao/cpu-for-1.1-merge-with-shape

[1.1] Cpu for 1.1 merge with shape
...@@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { ...@@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
case proto::AttrType::LONG: { case proto::AttrType::LONG: {
return attr_desc.l(); return attr_desc.l();
} }
case proto::AttrType::LONGS: {
std::vector<int64_t> val(attr_desc.longs_size());
for (int i = 0; i < attr_desc.longs_size(); ++i) {
val[i] = attr_desc.longs(i);
}
return val;
}
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
} }
......
...@@ -26,6 +26,113 @@ limitations under the License. */ ...@@ -26,6 +26,113 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<std::vector<int64_t>> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
std::vector<int64_t>* operator()(Attribute& attr) const {
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
std::vector<int> val = boost::get<std::vector<int>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
std::vector<float> val = boost::get<std::vector<float>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
}
std::vector<int64_t>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
...@@ -42,7 +149,11 @@ class AttrReader { ...@@ -42,7 +149,11 @@ class AttrReader {
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name); name);
return boost::get<T>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
} }
private: private:
...@@ -82,7 +193,7 @@ class DefaultValueSetter { ...@@ -82,7 +193,7 @@ class DefaultValueSetter {
public: public:
explicit DefaultValueSetter(T default_value) explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {} : default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; } void operator()(T& value) const { value = default_value_; } // NOLINT
private: private:
T default_value_; T default_value_;
...@@ -117,84 +228,6 @@ class EnumInContainer { ...@@ -117,84 +228,6 @@ class EnumInContainer {
std::unordered_set<T> container_; std::unordered_set<T> container_;
}; };
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
...@@ -235,7 +268,7 @@ class TypedAttrChecker { ...@@ -235,7 +268,7 @@ class TypedAttrChecker {
return *this; return *this;
} }
void operator()(AttributeMap& attr_map) const { void operator()(AttributeMap& attr_map) const { // NOLINT
if (!attr_map.count(attr_name_)) { if (!attr_map.count(attr_name_)) {
// user do not set this attr // user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(), PADDLE_ENFORCE(!default_value_setter_.empty(),
...@@ -271,7 +304,7 @@ class OpAttrChecker { ...@@ -271,7 +304,7 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>()); return *(checker.target<TypedAttrChecker<T>>());
} }
void Check(AttributeMap& attr_map) const { void Check(AttributeMap& attr_map) const { // NOLINT
for (const auto& checker : attr_checkers_) { for (const auto& checker : attr_checkers_) {
checker(attr_map); checker(attr_map);
} }
......
...@@ -59,6 +59,10 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -59,6 +59,10 @@ void BroadcastOpHandle::BroadcastOneVar(
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
if (UNLIKELY(!in_tensor.IsInitialized())) {
VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!";
return;
}
InitOutputValue(in_var_handle, out_var_handles); InitOutputValue(in_var_handle, out_var_handles);
......
...@@ -722,7 +722,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -722,7 +722,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
} }
if (node->Op()->Type() == "split_byref" || if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") { node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
......
...@@ -35,6 +35,7 @@ enum AttrType { ...@@ -35,6 +35,7 @@ enum AttrType {
BLOCK = 8; BLOCK = 8;
LONG = 9; LONG = 9;
BLOCKS = 10; BLOCKS = 10;
LONGS = 11;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -55,6 +56,7 @@ message OpDesc { ...@@ -55,6 +56,7 @@ message OpDesc {
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14; repeated int32 blocks_idx = 14;
repeated int64 longs = 15;
}; };
message Var { message Var {
......
...@@ -419,8 +419,15 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -419,8 +419,15 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
} }
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx()); VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(const std::vector<int64_t> &v) const {
VectorToRepeated(v, attr_->mutable_longs());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
......
...@@ -187,6 +187,10 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -187,6 +187,10 @@ void ParallelExecutor::BCastParamsToDevices(
} }
auto &main_tensor = main_var->Get<LoDTensor>(); auto &main_tensor = main_var->Get<LoDTensor>();
if (!main_tensor.IsInitialized()) {
VLOG(3) << "one in var not inited, return!";
continue;
}
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) { if (paddle::platform::is_gpu_place(main_tensor.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -36,7 +36,7 @@ using Attribute = ...@@ -36,7 +36,7 @@ using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t, std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>>; std::vector<BlockDesc*>, std::vector<int64_t>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class FakeInitInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeInitOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape));
}
};
class FakeInitOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
framework::Tensor *tensor = nullptr;
auto &out_var = *scope.FindVar(Output("Out"));
if (out_var.IsType<framework::LoDTensor>()) {
tensor = out_var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else if (out_var.IsType<framework::SelectedRows>()) {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else {
PADDLE_THROW(
"fake init op's output only"
"supports SelectedRows and LoDTensor");
}
}
};
class FakeInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
};
class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddComment(R"DOC(
FakeInit Operator.
Init an variable but not alloc memory for it, it is used for init the
table parameter at trainer side in distributed lookup table.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fake_init, ops::FakeInitOp, ops::FakeInitInferShape,
ops::FakeInitOpMaker, paddle::framework::EmptyGradOpMaker,
ops::FakeInitOpVarTypeInference);
...@@ -24,7 +24,7 @@ class FillConstantInferShape : public framework::InferShapeBase { ...@@ -24,7 +24,7 @@ class FillConstantInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null."); "Output(Out) of FillConstantOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape)); ctx->SetOutputDim("Out", framework::make_ddim(shape));
} }
}; };
...@@ -47,10 +47,10 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -47,10 +47,10 @@ class FillConstantOp : public framework::OperatorBase {
if (out_var.IsType<framework::LoDTensor>()) { if (out_var.IsType<framework::LoDTensor>()) {
tensor = out_var.GetMutable<framework::LoDTensor>(); tensor = out_var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else if (out_var.IsType<framework::SelectedRows>()) { } else if (out_var.IsType<framework::SelectedRows>()) {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"fill constant op's output only" "fill constant op's output only"
...@@ -83,7 +83,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,7 +83,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::VarType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output"); AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
AddAttr<bool>("force_cpu", AddAttr<bool>("force_cpu",
......
...@@ -52,7 +52,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -52,7 +52,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GaussianRandomOp should not be null."); "Output(Out) of GaussianRandomOp should not be null.");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(shape.size()); temp.reserve(shape.size());
for (auto dim : shape) { for (auto dim : shape) {
...@@ -88,8 +88,8 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,8 +88,8 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddOutput("Out", "Output matrix of gaussian random op"); AddOutput("Out", "Output matrix of gaussian random op");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int64_t>>("shape",
"(vector<int>) " "(vector<int64_t>) "
"The dimension of random tensor."); "The dimension of random tensor.");
AddAttr<float>("mean", AddAttr<float>("mean",
"(float, default 0.0) " "(float, default 0.0) "
......
...@@ -27,6 +27,10 @@ limitations under the License. */ ...@@ -27,6 +27,10 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -332,11 +336,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -332,11 +336,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sync_mode, checkpoint_block_id)); sync_mode, checkpoint_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get()); request_send_handler_.get(),
FLAGS_rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestGet, rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get()); request_get_handler_.get(),
FLAGS_rpc_get_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestPrefetch, rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get()); request_prefetch_handler_.get(),
FLAGS_rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get()); request_checkpoint_handler_.get());
......
...@@ -121,7 +121,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -121,7 +121,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <map>
#include <set> #include <set>
#include <vector> #include <unordered_map>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -230,8 +229,24 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>; ...@@ -230,8 +229,24 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
// add or mul. // add or mul.
namespace scatter { namespace scatter {
size_t FindPos(const std::vector<int64_t>& rows, int64_t value) { template <typename DeviceContext, typename T>
return std::find(rows.begin(), rows.end(), value) - rows.begin(); typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) {
blas->AXPY(data_len, 1., in, out);
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) {
for (int64_t i = 0; i < data_len; i++) {
out[i] += in[i];
}
} }
template <typename T> template <typename T>
...@@ -246,48 +261,84 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -246,48 +261,84 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input, const framework::SelectedRows& input,
framework::SelectedRows* output) { framework::SelectedRows* output) {
framework::SelectedRows& out = *output; std::vector<const framework::SelectedRows*> inputs;
std::vector<int64_t> input_rows(input.rows()); inputs.push_back(&input);
(*this)(context, inputs, output);
std::map<int64_t, std::vector<int64_t>> merge_row_map;
for (size_t i = 0; i < input_rows.size(); ++i) {
merge_row_map[input_rows[i]].push_back(i);
} }
std::vector<int64_t> merge_rows(merge_row_map.size()); void operator()(const platform::CPUDeviceContext& context,
size_t idx = 0; const std::vector<const framework::SelectedRows*>& inputs,
int64_t input_width = input.value().dims()[1]; framework::SelectedRows* output) {
out.set_height(input.height()); if (inputs.size() == 0) {
VLOG(3) << "no input! return";
T* out_data = out.mutable_value()->mutable_data<T>( return;
}
const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) {
if (in->rows().size() > 0) {
has_value_input = in;
break;
}
}
if (has_value_input == nullptr) {
VLOG(3) << "no input has value! just return" << std::endl;
return;
}
auto input_width = has_value_input->value().dims()[1];
auto input_height = has_value_input->height();
framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set;
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
"all input should have same "
"dimension except for the first one");
PADDLE_ENFORCE_EQ(input_height, input->height(),
"all input should have same height");
merged_row_set.insert(input->rows().begin(), input->rows().end());
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}
out.set_rows(merge_rows);
out.set_height(input_height);
out.mutable_value()->mutable_data<T>(
framework::make_ddim( framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}), {static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace()); context.GetPlace());
const T* in_data = input.value().data<T>();
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
for (auto& row_pair : merge_row_map) { constant_functor(context, out.mutable_value(), 0.0);
auto* out_ptr = out_data + idx * input_width;
auto& rows = row_pair.second; auto* out_data = out.mutable_value()->data<T>();
merge_rows[idx] = row_pair.first;
++idx; auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
// rows.size() is always larger than 0 for (auto* input : inputs) {
std::memcpy(out_ptr, in_data + rows[0] * input_width, if (input->rows().size() == 0) {
sizeof(T) * input_width); continue;
for (size_t i = 1; i < rows.size(); ++i) {
auto* in_ptr = in_data + rows[i] * input_width;
for (int64_t j = 0; j < input_width; ++j) {
out_ptr[j] += in_ptr[j];
} }
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
elementwise_add_to<platform::CPUDeviceContext, T>(
context, &blas, static_cast<size_t>(input_width),
&input_data[i * input_width], &out_data[out_i * input_width]);
} }
} }
out.set_rows(merge_rows);
} }
}; };
template struct MergeAdd<platform::CPUDeviceContext, int>; template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>; template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template <typename T> template <typename T>
struct UpdateToTensor<platform::CPUDeviceContext, T> { struct UpdateToTensor<platform::CPUDeviceContext, T> {
......
...@@ -267,10 +267,15 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -267,10 +267,15 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input, const framework::SelectedRows& input,
framework::SelectedRows* output) { framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
framework::Vector<int64_t> input_rows(input.rows()); framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
}
framework::SelectedRows& out = *output;
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
auto input_width = input.value().dims()[1]; auto input_width = input.value().dims()[1];
...@@ -296,6 +301,73 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -296,6 +301,73 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width); out.rows().size(), input_width);
} }
void operator()(const platform::CUDADeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) {
if (inputs.size() == 0) {
VLOG(3) << "no input! return";
return;
}
const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) {
if (in->rows().size() > 0) {
has_value_input = in;
break;
}
}
if (has_value_input == nullptr) {
VLOG(3) << "no input has value! just return" << std::endl;
return;
}
auto input_width = has_value_input->value().dims()[1];
auto input_height = has_value_input->height();
framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set;
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
"all input should have same "
"dimension except for the first one");
PADDLE_ENFORCE_EQ(input_height, input->height(),
"all input should have same height");
merged_row_set.insert(input->rows().begin(), input->rows().end());
}
std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
merged_row_set.end());
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
out.set_rows(merge_rows);
out.set_height(input_height);
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
dim3 grid1(input_rows.size(), 1);
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width);
}
}
}; };
template struct MergeAdd<platform::CUDADeviceContext, float>; template struct MergeAdd<platform::CUDADeviceContext, float>;
......
...@@ -83,104 +83,9 @@ struct MergeAdd { ...@@ -83,104 +83,9 @@ struct MergeAdd {
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const framework::SelectedRows& input, const framework::SelectedRows& input,
framework::SelectedRows* output); framework::SelectedRows* output);
}; void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
template <> framework::SelectedRows* output);
struct MergeAdd<platform::CPUDeviceContext, float> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
std::vector<int64_t> input_rows(input.rows());
std::map<int64_t, std::vector<int64_t>> merge_row_map;
for (size_t i = 0; i < input_rows.size(); ++i) {
merge_row_map[input_rows[i]].push_back(i);
}
std::vector<int64_t> merge_rows(merge_row_map.size());
size_t idx = 0;
int64_t input_width = input.value().dims()[1];
out.set_height(input.height());
auto* out_data = out.mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
auto* in_data = input.value().data<float>();
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
for (auto& row_pair : merge_row_map) {
auto* out_ptr = out_data + idx * input_width;
auto& rows = row_pair.second;
merge_rows[idx] = row_pair.first;
++idx;
// rows.size() is always larger than 0
blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
for (size_t i = 1; i < rows.size(); ++i) {
blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
}
}
out.set_rows(merge_rows);
}
};
template <>
struct MergeAdd<platform::CPUDeviceContext, double> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
std::vector<int64_t> input_rows(input.rows());
std::map<int64_t, std::vector<int64_t>> merge_row_map;
for (size_t i = 0; i < input_rows.size(); ++i) {
merge_row_map[input_rows[i]].push_back(i);
}
std::vector<int64_t> merge_rows(merge_row_map.size());
size_t idx = 0;
int64_t input_width = input.value().dims()[1];
out.set_height(input.height());
auto* out_data = out.mutable_value()->mutable_data<double>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
auto* in_data = input.value().data<double>();
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
for (auto& row_pair : merge_row_map) {
auto* out_ptr = out_data + idx * input_width;
auto& rows = row_pair.second;
merge_rows[idx] = row_pair.first;
++idx;
// rows.size() is always larger than 0
blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
for (size_t i = 1; i < rows.size(); ++i) {
blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
}
}
out.set_rows(merge_rows);
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -302,6 +302,64 @@ TEST(selected_rows_functor, cpu_merge_add_int) { ...@@ -302,6 +302,64 @@ TEST(selected_rows_functor, cpu_merge_add_int) {
EXPECT_EQ(out_data[1 * row_numel], 2); EXPECT_EQ(out_data[1 * row_numel], 2);
EXPECT_EQ(out_data[2 * row_numel], 1); EXPECT_EQ(out_data[2 * row_numel], 1);
} }
TEST(selected_rows_functor, cpu_merge_add_multi) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
float>
set_const;
int64_t height = 10;
int64_t row_numel = 8;
std::vector<int64_t> rows1{5, 2, 5, 3, 5};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
new paddle::framework::SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows1.size()), row_numel}),
cpu_place);
set_const(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{2, 5, 3, 5, 3};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
new paddle::framework::SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows2.size()), row_numel}),
cpu_place);
set_const(ctx, in2_value, 1.0);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
output->set_height(height);
paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
float>
merge_add_functor;
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.push_back(selected_rows1.get());
inputs.push_back(selected_rows2.get());
merge_add_functor(ctx, inputs, output.get());
EXPECT_EQ(output->height(), height);
EXPECT_EQ(output->value().dims(),
paddle::framework::make_ddim({3, row_numel}));
std::vector<int64_t> ret_rows{2, 3, 5};
EXPECT_EQ(output->rows(), ret_rows);
auto* out_data = output->value().data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
}
}
}
TEST(selected_rows_functor, cpu_sum_to) { TEST(selected_rows_functor, cpu_sum_to) {
paddle::platform::CPUPlace cpu_place; paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place); paddle::platform::CPUDeviceContext ctx(cpu_place);
...@@ -318,6 +376,7 @@ TEST(selected_rows_functor, cpu_sum_to) { ...@@ -318,6 +376,7 @@ TEST(selected_rows_functor, cpu_sum_to) {
paddle::framework::make_ddim( paddle::framework::make_ddim(
{static_cast<int64_t>(rows1.size()), row_numel}), {static_cast<int64_t>(rows1.size()), row_numel}),
cpu_place); cpu_place);
functor(ctx, in1_value, 1.0); functor(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{0, 5, 7, 9}; std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{ std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
...@@ -327,6 +386,7 @@ TEST(selected_rows_functor, cpu_sum_to) { ...@@ -327,6 +386,7 @@ TEST(selected_rows_functor, cpu_sum_to) {
paddle::framework::make_ddim( paddle::framework::make_ddim(
{static_cast<int64_t>(rows2.size()), row_numel}), {static_cast<int64_t>(rows2.size()), row_numel}),
cpu_place); cpu_place);
functor(ctx, in2_value, 2.0); functor(ctx, in2_value, 2.0);
std::unique_ptr<paddle::framework::SelectedRows> output{ std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()}; new paddle::framework::SelectedRows()};
......
...@@ -241,3 +241,67 @@ TEST(selected_rows_functor, gpu_add_to) { ...@@ -241,3 +241,67 @@ TEST(selected_rows_functor, gpu_add_to) {
// row9: 2.0 + 3.0 // row9: 2.0 + 3.0
EXPECT_EQ(tensor1_cpu_data[9 * row_numel + 6], 5.0); EXPECT_EQ(tensor1_cpu_data[9 * row_numel + 6], 5.0);
} }
TEST(selected_rows_functor, gpu_merge_add) {
paddle::platform::CUDAPlace gpu_place(0);
paddle::platform::CPUPlace cpu_place;
paddle::platform::CUDADeviceContext& ctx =
*reinterpret_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
float>
set_const;
int64_t height = 10;
int64_t row_numel = 8;
std::vector<int64_t> rows1{5, 2, 5, 3, 5};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
new paddle::framework::SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows1.size()), row_numel}),
gpu_place);
set_const(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{2, 5, 3, 5, 3};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
new paddle::framework::SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows2.size()), row_numel}),
gpu_place);
set_const(ctx, in2_value, 1.0);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
output->set_height(height);
paddle::operators::math::scatter::MergeAdd<
paddle::platform::CUDADeviceContext, float>
merge_add_functor;
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.push_back(selected_rows1.get());
inputs.push_back(selected_rows2.get());
merge_add_functor(ctx, inputs, output.get());
paddle::framework::Tensor output_cpu;
paddle::framework::TensorCopy(output->value(), cpu_place, ctx, &output_cpu);
ctx.Wait();
EXPECT_EQ(output->height(), height);
EXPECT_EQ(output->value().dims(),
paddle::framework::make_ddim({3, row_numel}));
std::vector<int64_t> ret_rows{2, 3, 5};
EXPECT_EQ(output->rows(), ret_rows);
auto* out_data = output_cpu.data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
}
}
}
...@@ -20,13 +20,16 @@ namespace operators { ...@@ -20,13 +20,16 @@ namespace operators {
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker { class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
AddInput( .AsDuplicable();
"X", AddInput("Rows", "(LoDTensor) the input ids with shape{row_size, 1}, ")
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the " .AsDuplicable();
AddInput("X",
"(LoDTensors) multi input tensor with shape{Rows, N}, N is the "
"size of embedding table") "size of embedding table")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors."); AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num. Merge multi LoDTensor's into one according to Ids's shard num.
...@@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids."); PADDLE_ENFORCE(ctx->HasInputs("Ids"),
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X."); "MergeIdsOp must has multi input Ids.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out."); PADDLE_ENFORCE(ctx->HasInputs("Rows"),
"MergeIdsOp must has multi input Rows.");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has multi input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"MergeIdsOp must has multi output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1); PADDLE_ENFORCE_EQ(ids_dims[0][1], 1);
} }
auto x_var_type = ctx->GetInputsVarType("X"); auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) { for (auto &var_type : x_var_type) {
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <tuple>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
if (!platform::is_cpu_place(place)) { if (!platform::is_cpu_place(place)) {
PADDLE_THROW("MergeIds do not support GPU kernel"); PADDLE_THROW("MergeIds do not support GPU kernel");
} }
VLOG(3) << "run in MergeIdsOpKernel";
const auto *ids_var = ctx.InputVar("Ids"); const auto ids = ctx.MultiInput<framework::LoDTensor>("Ids");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(), const auto row_ids = ctx.MultiInput<framework::LoDTensor>("Rows");
"only support to merge Ids of LoDTensor"); const auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>(); PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(),
const auto &ids_dims = ids_tensor.dims(); "the number of Rows and X should be the same");
const int64_t *ids = ids_tensor.data<int64_t>(); PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same");
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X"); int row_ids_size = 0;
int row_size = 0;
int embedding_size = 0;
auto *out = ctx.Output<framework::LoDTensor>("Out"); for (int i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i];
const auto *row_id = row_ids[i];
int batch_size = 0;
int embedding_size = 0;
for (auto &input : x_tensors) {
if (framework::product(input->dims()) != 0) {
if (embedding_size == 0) { if (embedding_size == 0) {
embedding_size = input->dims()[1]; embedding_size = x_tensor->dims()[1];
} }
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1],
"embedding size of all input should be the same"); "embedding size of all input should be the same");
batch_size += input->dims()[0]; row_size += x_tensor->dims()[0];
} row_ids_size += row_id->dims()[0];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0], row_size, row_ids_size,
"the batch size of ids and merged embedding value should be the same"); "the merged X dim[0] and merged Rows dim[0] should be the same");
std::unordered_map<int64_t, std::tuple<int64_t, int64_t>>
selected_rows_idx_map;
for (int i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i];
for (int j = 0; j < row_id->numel(); ++j) {
int64_t key = row_id->data<int64_t>()[j];
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
selected_rows_idx_map.insert(std::make_pair(key, val));
}
}
PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
"the rows and tensor map size should be the same");
for (int i = 0; i < outs.size(); ++i) {
auto *out_ids = ids[i];
auto *out = outs[i];
const size_t shard_num = x_tensors.size(); out->set_lod(out_ids->lod());
if (shard_num == 1) { int nums = static_cast<int>(out_ids->dims()[0]);
VLOG(3) << "only one shard, we can copy the data directly";
TensorCopy(*x_tensors[0], place, out);
} else {
std::vector<int> in_indexs(shard_num, 0);
auto *out_data = out->mutable_data<T>( auto *out_data = out->mutable_data<T>(
framework::make_ddim({batch_size, embedding_size}), place); framework::make_ddim({nums, embedding_size}), place);
// copy data from ins[shard_num] to out. for (int j = 0; j < nums; ++j) {
for (int i = 0; i < ids_dims[0]; ++i) { int id = out_ids->data<int64_t>()[j];
int64_t id = ids[i]; auto row_tuple = selected_rows_idx_map[id];
size_t shard_id = static_cast<size_t>(id) % shard_num; int64_t row_idx = std::get<1>(row_tuple);
int index = in_indexs[shard_id]; const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
memcpy(out_data + embedding_size * i,
x_tensors[shard_id]->data<T>() + index * embedding_size, memcpy(out_data + embedding_size * j,
x_tensor->data<T>() + row_idx * embedding_size,
sizeof(T) * embedding_size); sizeof(T) * embedding_size);
in_indexs[shard_id] += 1;
}
for (size_t i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
} }
} }
} }
......
...@@ -20,17 +20,24 @@ namespace operators { ...@@ -20,17 +20,24 @@ namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker { class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
AddOutput("Out", "(LoDTensor) The outputs of the input Ids.") .AsDuplicable();
AddOutput("Out", "(LoDTensors) The outputs of the input Ids.")
.AsDuplicable(); .AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number
Example: Example:
Input: Input:
X = [1,2,3,4,5,6] X = [[1,2,3,4,5,6],[2,3]]
Out(3 output): Out(3 output):
if compress is True:
out0 = [3, 3, 6]
out1 = [1, 4]
out2 = [2, 2, 5]
else:
out0 = [3, 6] out0 = [3, 6]
out1 = [1, 4] out1 = [1, 4]
out2 = [2, 5] out2 = [2, 5]
...@@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids."); PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
} }
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("Ids").front()->type()),
ctx.GetPlace());
}
}; };
class SplitIdsOpInferVarType : public framework::VarTypeInference { class SplitIdsOpInferVarType : public framework::VarTypeInference {
...@@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference { ...@@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference {
} }
}; };
class SplitIdsOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto grad = new framework::OpDesc();
grad->SetType("concat");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("Ids"));
grad->SetAttr("axis", 0);
return std::unique_ptr<framework::OpDesc>(grad);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker, REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType); ops::SplitIdsOpGradMaker, ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>, split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>); ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <iterator>
#include <set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
PADDLE_THROW("SplitIds do not support GPU kernel"); PADDLE_THROW("SplitIds do not support GPU kernel");
} }
const auto *ids_var = ctx.InputVar("Ids"); const auto ids_vars = ctx.MultiInputVar("Ids");
PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0");
auto *ids_var = ids_vars[0];
if (ids_var->IsType<framework::LoDTensor>()) { if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims(); int batch_size = 0;
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>(); const auto ids_tensors = ctx.MultiInput<framework::LoDTensor>("Ids");
for (size_t i = 0; i < ids_tensors.size(); ++i) {
batch_size += ids_tensors[i]->dims()[0];
}
VLOG(4) << "Get Total BatchSize is: " << batch_size;
std::vector<T> all_ids(batch_size);
int offset = 0;
for (size_t i = 0; i < ids_tensors.size(); ++i) {
const auto *ids = ids_tensors[i];
std::memcpy(all_ids.data() + offset, ids->data<T>(),
ids->numel() * sizeof(T));
offset += ids->numel();
}
std::set<T> st(all_ids.begin(), all_ids.end());
all_ids.assign(st.begin(), st.end());
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out"); auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size(); const size_t shard_num = outs.size();
std::vector<std::vector<T>> out_ids; std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size()); out_ids.resize(outs.size());
// split id by their shard_num. // split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) { for (int i = 0; i < all_ids.size(); ++i) {
T id = ids[i]; T id = all_ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num; size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id); out_ids[shard_id].push_back(id);
} }
...@@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(ids_dims[0], PADDLE_ENFORCE_EQ(ids_dims[0],
static_cast<int64_t>(ids_selected_rows->rows().size()), static_cast<int64_t>(ids_selected_rows->rows().size()),
""); "");
const T *ids = ids_selected_rows->value().data<T>(); const T *ids_data = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows(); const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size(); const size_t shard_num = outs.size();
...@@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
T *output = out->mutable_value()->mutable_data<T>(ddim, place); T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (int64_t i = 0; i < ddim[0]; ++i) { for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, memcpy(output + i * row_width,
ids + id_to_index[out->rows()[i]] * row_width, ids_data + id_to_index[out->rows()[i]] * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
......
...@@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "The input SelectedRows."); AddInput("X", "The input SelectedRows.");
AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable(); AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
AddAttr<std::vector<int>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int>({})); .SetDefault(std::vector<int64_t>({}));
AddComment(R"DOC( AddComment(R"DOC(
Split a SelectedRows with a specified rows section. Split a SelectedRows with a specified rows section.
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static int FindOutIdx(int row, const std::vector<int>& abs_sections) { static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
for (size_t i = 1; i < abs_sections.size(); ++i) { for (size_t i = 1; i < abs_sections.size(); ++i) {
if (row < abs_sections[i]) { if (row < abs_sections[i]) {
return i - 1; return i - 1;
...@@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) { ...@@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
return abs_sections.size() - 1; return abs_sections.size() - 1;
} }
static std::vector<int> ToAbsoluteSection( static std::vector<int64_t> ToAbsoluteSection(
const std::vector<int>& height_sections) { const std::vector<int64_t>& height_sections) {
std::vector<int> abs_sections; std::vector<int64_t> abs_sections;
abs_sections.resize(height_sections.size()); abs_sections.resize(height_sections.size());
abs_sections[0] = 0; abs_sections[0] = 0;
for (size_t i = 1; i < height_sections.size(); ++i) { for (size_t i = 1; i < height_sections.size(); ++i) {
...@@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> { ...@@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::SelectedRows>("X"); auto* x = ctx.Input<framework::SelectedRows>("X");
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
auto height_sections = ctx.Attr<std::vector<int>>("height_sections"); auto height_sections = ctx.Attr<std::vector<int64_t>>("height_sections");
auto abs_sections = ToAbsoluteSection(height_sections); auto abs_sections = ToAbsoluteSection(height_sections);
......
...@@ -83,79 +83,54 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -83,79 +83,54 @@ class SumKernel : public framework::OpKernel<T> {
} }
} }
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
std::unique_ptr<framework::SelectedRows> in0; if (in_place && in_vars.size() < 2) {
if (in_place) { return;
// If is in_place, we store the input[0] to in0
auto &in_sel0 = in_vars[0]->Get<SelectedRows>();
auto &rows = in_sel0.rows();
#ifdef PADDLE_WITH_CUDA
std::vector<int64_t> rows_in_cpu;
rows_in_cpu.reserve(rows.size());
for (auto item : rows) {
rows_in_cpu.push_back(item);
}
in0.reset(new framework::SelectedRows(rows_in_cpu, in_sel0.height()));
#else
in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
#endif
in0->mutable_value()->ShareDataWith(in_sel0.value());
} }
auto get_selected_row = [&](size_t i) -> const SelectedRows & { std::vector<const paddle::framework::SelectedRows *> inputs;
if (i == 0 && in0) { SelectedRows temp_in0;
return *in0.get();
if (in_place) {
auto &in0 = in_vars[0]->Get<SelectedRows>();
temp_in0.set_height(in0.height());
temp_in0.set_rows(in0.rows());
framework::TensorCopy(in0.value(), in0.place(),
context.device_context(),
temp_in0.mutable_value());
inputs.push_back(&temp_in0);
for (size_t i = 1; i < in_vars.size(); ++i) {
auto &in = in_vars[i]->Get<SelectedRows>();
if (in.rows().size() > 0) {
inputs.push_back(&in);
}
}
} else { } else {
return in_vars[i]->Get<SelectedRows>(); for (auto &in_var : in_vars) {
auto &in = in_var->Get<SelectedRows>();
if (in.rows().size() > 0) {
inputs.push_back(&in_var->Get<SelectedRows>());
}
}
} }
};
auto *out = context.Output<SelectedRows>("Out"); auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear(); out->mutable_rows()->clear();
auto *out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i);
first_dim += sel_row.rows().size();
}
std::vector<int64_t> in_dim; bool has_data = false;
for (size_t i = 0; i < in_num; i++) { for (auto &in : inputs) {
auto &sel_row = get_selected_row(i); if (in->rows().size() > 0) {
if (sel_row.rows().size() > 0) { has_data = true;
in_dim = framework::vectorize(sel_row.value().dims());
break; break;
} }
} }
if (in_dim.empty()) { if (has_data) {
VLOG(3) << "WARNING: all the inputs are empty"; math::scatter::MergeAdd<DeviceContext, T> merge_add;
in_dim = merge_add(context.template device_context<DeviceContext>(), inputs,
framework::vectorize(get_selected_row(in_num - 1).value().dims()); out);
} else { } else {
in_dim[0] = static_cast<int64_t>(first_dim); // no data, just set a empty out tensor.
} out->mutable_value()->mutable_data<T>(framework::make_ddim({0}),
context.GetPlace());
out_value->Resize(framework::make_ddim(in_dim));
out_value->mutable_data<T>(context.GetPlace());
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
return;
}
math::SelectedRowsAddTo<DeviceContext, T> functor;
int64_t offset = 0;
for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(context.template device_context<DeviceContext>(), sel_row,
offset, out);
offset += sel_row.value().numel();
} }
} else if (out_var->IsType<framework::LoDTensorArray>()) { } else if (out_var->IsType<framework::LoDTensorArray>()) {
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>(); auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
......
...@@ -29,7 +29,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -29,7 +29,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int>>("shape"); auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>(); auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value(); tensor = selected_rows->mutable_value();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
...@@ -67,7 +67,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -67,7 +67,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"), ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max"); "uniform_random's min must less then max");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(shape.size()); temp.reserve(shape.size());
for (auto dim : shape) { for (auto dim : shape) {
...@@ -94,7 +94,7 @@ This operator initializes a tensor with random values sampled from a ...@@ -94,7 +94,7 @@ This operator initializes a tensor with random values sampled from a
uniform distribution. The random result is in set [min, max]. uniform distribution. The random result is in set [min, max].
)DOC"); )DOC");
AddAttr<std::vector<int>>("shape", "The shape of the output tensor"); AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor");
AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].") AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
.SetDefault(-1.0f); .SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].") AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
......
...@@ -48,7 +48,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -48,7 +48,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = context.Attr<std::vector<int>>("shape"); auto shape = context.Attr<std::vector<int64_t>>("shape");
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
} else { } else {
......
...@@ -57,6 +57,18 @@ struct variant_caster<V<Ts...>> { ...@@ -57,6 +57,18 @@ struct variant_caster<V<Ts...>> {
auto caster = make_caster<T>(); auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) { if (!load_success_ && caster.load(src, convert)) {
load_success_ = true; load_success_ = true;
if (std::is_same<T, std::vector<float>>::value) {
auto caster_ints = make_caster<std::vector<int64_t>>();
if (caster_ints.load(src, convert)) {
VLOG(4) << "This value are floats and int64_ts satisfy "
"simultaneously, will set it's type to "
"std::vector<int64_t>";
value = cast_op<std::vector<int64_t>>(caster_ints);
return true;
}
}
value = cast_op<T>(caster); value = cast_op<T>(caster);
return true; return true;
} }
...@@ -259,6 +271,8 @@ void BindOpDesc(pybind11::module *m) { ...@@ -259,6 +271,8 @@ void BindOpDesc(pybind11::module *m) {
pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "") pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "")
.value("INT", pd::proto::AttrType::INT) .value("INT", pd::proto::AttrType::INT)
.value("INTS", pd::proto::AttrType::INTS) .value("INTS", pd::proto::AttrType::INTS)
.value("LONG", pd::proto::AttrType::LONG)
.value("LONGS", pd::proto::AttrType::LONGS)
.value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOAT", pd::proto::AttrType::FLOAT)
.value("FLOATS", pd::proto::AttrType::FLOATS) .value("FLOATS", pd::proto::AttrType::FLOATS)
.value("STRING", pd::proto::AttrType::STRING) .value("STRING", pd::proto::AttrType::STRING)
......
...@@ -121,6 +121,9 @@ def __bootstrap__(): ...@@ -121,6 +121,9 @@ def __bootstrap__():
read_env_flags.append('rpc_server_profile_period') read_env_flags.append('rpc_server_profile_period')
read_env_flags.append('rpc_server_profile_path') read_env_flags.append('rpc_server_profile_path')
read_env_flags.append('enable_rpc_profiler') read_env_flags.append('enable_rpc_profiler')
read_env_flags.append('rpc_send_thread_num')
read_env_flags.append('rpc_get_thread_num')
read_env_flags.append('rpc_prefetch_thread_num')
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
read_env_flags += [ read_env_flags += [
......
...@@ -120,6 +120,8 @@ class OpDescCreationMethod(object): ...@@ -120,6 +120,8 @@ class OpDescCreationMethod(object):
new_attr.strings.extend(user_defined_attr) new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.BOOLEANS: elif attr.type == framework_pb2.BOOLEANS:
new_attr.bools.extend(user_defined_attr) new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS: elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr: for p in user_defined_attr:
pair = new_attr.int_pairs.add() pair = new_attr.int_pairs.add()
......
...@@ -18,6 +18,7 @@ import unittest ...@@ -18,6 +18,7 @@ import unittest
from test_dist_base import TestDistBase from test_dist_base import TestDistBase
# FIXME(tangwei): sum op can not handle when inputs is empty.
class TestDistCTR2x2(TestDistBase): class TestDistCTR2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
......
...@@ -42,7 +42,6 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): ...@@ -42,7 +42,6 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
#FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self): def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
...@@ -93,7 +92,6 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -93,7 +92,6 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
# FIXME(tangwei): Learningrate variable is not created on pserver. # FIXME(tangwei): Learningrate variable is not created on pserver.
"""
class TestDistSimnetBow2x2LookupTableSync(TestDistBase): class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
...@@ -146,7 +144,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase): ...@@ -146,7 +144,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
delta=1e-5, delta=1e-5,
check_error_log=False, check_error_log=False,
need_envs=need_envs) need_envs=need_envs)
"""
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -480,7 +480,7 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -480,7 +480,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep) pserver1, startup1 = self.get_pserver(self.pserver1_ep)
self.assertEqual(len(pserver1.blocks), 6) self.assertEqual(len(pserver1.blocks), 5)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
...@@ -491,26 +491,32 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -491,26 +491,32 @@ class TestDistLookupTable(TestDistLookupTableBase):
# 3 prefetch -> lookup_sparse_table for data0 # 3 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[3].ops], self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["lookup_sparse_table"]) ["lookup_sparse_table"])
# 4 prefetch -> lookup_sparse_table for data1 # 4 save table
self.assertEqual([op.type for op in pserver1.blocks[4].ops], self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
["lookup_sparse_table"])
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer() trainer, trainer_startup = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', 'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv', 'send', 'send_barrier', 'recv', 'recv', 'fetch_barrier'
'fetch_barrier'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv',
'fetch_barrier', 'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestAsyncLocalLookupTable(TestDistLookupTableBase): class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def net_conf(self): def net_conf(self):
...@@ -553,7 +559,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -553,7 +559,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 6) self.assertEqual(len(pserver1.blocks), 5)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
...@@ -563,22 +569,19 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -563,22 +569,19 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
# 3 prefetch -> lookup_sparse_table for data0 # 3 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[3].ops], self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["lookup_sparse_table"]) ["lookup_sparse_table"])
# 4 prefetch -> lookup_sparse_table for data1 # 4 save table
self.assertEqual([op.type for op in pserver1.blocks[4].ops], self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
["lookup_sparse_table"])
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer(config) trainer, _ = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', 'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
'sum', 'split_ids', 'send', 'recv', 'recv' 'send', 'recv', 'recv'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestFakeInitOpSelectedRows(unittest.TestCase):
def check_with_place(self, place, is_selected_rows):
scope = core.Scope()
out_var_name = 'Out'
if is_selected_rows:
out_tensor = scope.var(out_var_name).get_selected_rows().get_tensor(
)
else:
out_tensor = scope.var(out_var_name).get_tensor()
var_shape = [4, 784]
# create and run fake_init_op
fake_init_op = Operator("fake_init", Out=out_var_name, shape=var_shape)
fake_init_op.run(scope, place)
self.assertEqual(var_shape, out_tensor._get_dims())
def test_fake_init_selected_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
for is_selected_rows in [True, False]:
self.check_with_place(place, is_selected_rows)
if __name__ == "__main__":
unittest.main()
...@@ -22,15 +22,28 @@ from op_test import OpTest ...@@ -22,15 +22,28 @@ from op_test import OpTest
class TestMergeIdsOp(OpTest): class TestMergeIdsOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "merge_ids" self.op_type = "merge_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') ids1 = np.array([[0], [2], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32') ids2 = np.array([[0], [2], [2], [3]]).astype('int64')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], rows1 = np.array([[0], [2]]).astype('int64')
[0.5, 0.6]]).astype('float32') rows2 = np.array([[3], [5]]).astype('int64')
out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3], rows3 = np.array([[6]]).astype('int64')
[0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]} x0 = np.array([[0.1, 0.2], [0.2, 0.3]]).astype('float32')
self.outputs = {'Out': out} x1 = np.array([[0.3, 0.4], [0.4, 0.5]]).astype('float32')
x2 = np.array([[0.5, 0.6]]).astype('float32')
out1 = np.array(
[[0.1, 0.2], [0.2, 0.3], [0.4, 0.5], [0.5, 0.6]]).astype('float32')
out2 = np.array(
[[0.1, 0.2], [0.2, 0.3], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
self.inputs = {
'Ids': [('ids1', ids1), ('ids2', ids2)],
"Rows": [('rows1', rows1), ('rows2', rows2), ('rows3', rows3)],
"X": [('x0', x0), ('x1', x1), ('x2', x2)]
}
self.outputs = {'Out': [('out1', out1), ('out2', out2)]}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
...@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator ...@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator
class TestSplitIdsOp(OpTest): class TestSplitIdsOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "split_ids" self.op_type = "split_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') ids1 = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
ids2 = np.array([[6], [2], [3], [3], [5], [2], [6]]).astype('int64')
ids3 = np.array([[2], [2], [2], [3], [5], [5], [6]]).astype('int64')
out0 = np.array([[0], [3], [6]]).astype('int64') out0 = np.array([[0], [3], [6]]).astype('int64')
out1 = np.array([[]]).astype('int64') out1 = np.array([[]]).astype('int64')
out2 = np.array([[2], [2], [5], [5]]).astype('int64') out2 = np.array([[2], [5]]).astype('int64')
self.inputs = {'Ids': ids} self.inputs = {'Ids': [('ids1', ids1), ('ids2', ids2), ('ids3', ids3)]}
self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]} self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestSpliteIds(unittest.TestCase): class TestSplitSelectedRows(unittest.TestCase):
def get_places(self): def get_places(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
return places return places
......
...@@ -99,7 +99,6 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -99,7 +99,6 @@ class TestSpliteSelectedRows(unittest.TestCase):
out0_grad.set_height(height) out0_grad.set_height(height)
out0_grad_tensor = out0_grad.get_tensor() out0_grad_tensor = out0_grad.get_tensor()
np_array = np.ones((len(rows0), row_numel)).astype("float32") np_array = np.ones((len(rows0), row_numel)).astype("float32")
np_array[0, 0] = 2.0
out0_grad_tensor.set(np_array, place) out0_grad_tensor.set(np_array, place)
out1_grad = scope.var("out1@GRAD").get_selected_rows() out1_grad = scope.var("out1@GRAD").get_selected_rows()
...@@ -108,7 +107,6 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -108,7 +107,6 @@ class TestSpliteSelectedRows(unittest.TestCase):
out1_grad.set_height(height) out1_grad.set_height(height)
out1_grad_tensor = out1_grad.get_tensor() out1_grad_tensor = out1_grad.get_tensor()
np_array = np.ones((len(rows1), row_numel)).astype("float32") np_array = np.ones((len(rows1), row_numel)).astype("float32")
np_array[0, 1] = 4.0
out1_grad_tensor.set(np_array, place) out1_grad_tensor.set(np_array, place)
x_grad = scope.var("X@GRAD").get_selected_rows() x_grad = scope.var("X@GRAD").get_selected_rows()
...@@ -121,11 +119,13 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -121,11 +119,13 @@ class TestSpliteSelectedRows(unittest.TestCase):
grad_op.run(scope, place) grad_op.run(scope, place)
self.assertEqual(x_grad.rows(), rows0 + rows1) merged_rows = set(rows0 + rows1)
self.assertEqual(set(x_grad.rows()), set(rows0 + rows1))
self.assertEqual(x_grad.height(), height) self.assertEqual(x_grad.height(), height)
print(np.array(x_grad.get_tensor()))
self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0]) self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0])
self.assertAlmostEqual(4.0, np.array(x_grad.get_tensor())[2, 1]) self.assertAlmostEqual(1.0, np.array(x_grad.get_tensor())[2, 1])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -45,16 +45,30 @@ class TestSumOp(OpTest): ...@@ -45,16 +45,30 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest): class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place): def check_with_place(self, place, inplace):
scope = core.Scope() self.height = 10
self.check_input_and_optput(scope, place, True, True, True) self.row_numel = 12
self.check_input_and_optput(scope, place, False, True, True) self.rows = [0, 1, 2, 3, 4, 5, 6]
self.check_input_and_optput(scope, place, False, False, True)
self.check_input_and_optput(scope, place, False, False, False) self.check_input_and_optput(core.Scope(), place, inplace, True, True,
True)
self.check_input_and_optput(core.Scope(), place, inplace, False, True,
True)
self.check_input_and_optput(core.Scope(), place, inplace, False, False,
True)
self.check_input_and_optput(core.Scope(), place, inplace, False, False,
False)
def _get_array(self, row_num, row_numel):
array = np.ones((row_num, row_numel)).astype("float32")
for i in range(row_num):
array[i] *= i
return array
def check_input_and_optput(self, def check_input_and_optput(self,
scope, scope,
place, place,
inplace,
w1_has_data=False, w1_has_data=False,
w2_has_data=False, w2_has_data=False,
w3_has_data=False): w3_has_data=False):
...@@ -64,35 +78,43 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -64,35 +78,43 @@ class TestSelectedRowsSumOp(OpTest):
self.create_selected_rows(scope, place, "W3", w3_has_data) self.create_selected_rows(scope, place, "W3", w3_has_data)
# create Out Variable # create Out Variable
out = scope.var('Out').get_selected_rows() if inplace:
out_var_name = "W1"
else:
out_var_name = "Out"
out = scope.var(out_var_name).get_selected_rows()
# create and run sum operator # create and run sum operator
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') sum_op = Operator("sum", X=["W1", "W2", "W3"], Out=out_var_name)
sum_op.run(scope, place) sum_op.run(scope, place)
has_data_w_num = 0 has_data_w_num = 0
for w in [w1_has_data, w2_has_data, w3_has_data]: for has_data in [w1_has_data, w2_has_data, w3_has_data]:
if not w: if has_data:
has_data_w_num += 1 has_data_w_num += 1
self.assertEqual(7 * has_data_w_num, len(out.rows())) if has_data_w_num > 0:
self.assertEqual(len(out.rows()), 7)
self.assertTrue(
np.array_equal(
np.array(out.get_tensor()),
self._get_array(len(self.rows), self.row_numel) *
has_data_w_num))
else:
self.assertEqual(len(out.rows()), 0)
def create_selected_rows(self, scope, place, var_name, isEmpty): def create_selected_rows(self, scope, place, var_name, has_data):
# create and initialize W Variable # create and initialize W Variable
if not isEmpty: if has_data:
rows = [0, 1, 2, 3, 4, 5, 6] rows = self.rows
row_numel = 12
else: else:
rows = [] rows = []
row_numel = 12
var = scope.var(var_name) var = scope.var(var_name)
w_selected_rows = var.get_selected_rows() w_selected_rows = var.get_selected_rows()
w_selected_rows.set_height(len(rows)) w_selected_rows.set_height(self.height)
w_selected_rows.set_rows(rows) w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32") w_array = self._get_array(len(rows), self.row_numel)
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor() w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place) w_tensor.set(w_array, place)
...@@ -100,9 +122,11 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -100,9 +122,11 @@ class TestSelectedRowsSumOp(OpTest):
def test_w_is_selected_rows(self): def test_w_is_selected_rows(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
# currently only support CPU if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places: for place in places:
self.check_with_place(place) for inplace in [True, False]:
self.check_with_place(place, inplace)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -475,6 +475,26 @@ class DistributeTranspiler(object): ...@@ -475,6 +475,26 @@ class DistributeTranspiler(object):
delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), self.optimize_ops)
delete_ops(self.origin_program.global_block(), lr_ops) delete_ops(self.origin_program.global_block(), lr_ops)
# delete table init op
if self.has_distributed_lookup_table:
table_var = self.startup_program.global_block().vars[
self.table_name]
table_param_init_op = []
for op in self.startup_program.global_block().ops:
if self.table_name in op.output_arg_names:
table_param_init_op.append(op)
init_op_num = len(table_param_init_op)
if init_op_num != 1:
raise ValueError("table init op num should be 1, now is " + str(
init_op_num))
table_init_op = table_param_init_op[0]
self.startup_program.global_block().append_op(
type="fake_init",
inputs={},
outputs={"Out": table_var},
attrs={"shape": table_init_op.attr('shape')})
delete_ops(self.startup_program.global_block(), table_param_init_op)
self.origin_program.__str__() self.origin_program.__str__()
if wait_port: if wait_port:
...@@ -1034,15 +1054,11 @@ to transpile() call.") ...@@ -1034,15 +1054,11 @@ to transpile() call.")
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
# self.all_prefetch_input_vars = self.all_in_ids_vars = []
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_input_vars = [] self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = [] self.all_prefetch_output_vars = []
self.all_out_emb_vars = []
lookup_table_op_index = -1
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
while continue_search_lookup_table_op: while continue_search_lookup_table_op:
...@@ -1052,42 +1068,50 @@ to transpile() call.") ...@@ -1052,42 +1068,50 @@ to transpile() call.")
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
lookup_table_op_index = list(all_ops).index(op) lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(
all_ops).index(op)
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
ids_var = program.global_block().vars[ids_name[0]] ids_var = program.global_block().vars[ids_name[0]]
prefetch_input_vars = self._create_splited_vars( self.all_in_ids_vars.append(ids_var)
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
self.all_prefetch_input_vars.append(prefetch_input_vars)
out_var = program.global_block().vars[out_name[0]] out_var = program.global_block().vars[out_name[0]]
prefetch_output_vars = self._create_splited_vars( self.all_out_emb_vars.append(out_var)
source_var=out_var,
block=program.global_block(), # delete lookup_table_op
tag="_prefetch_out_") delete_ops(program.global_block(), [op])
self.all_prefetch_output_vars.append(prefetch_output_vars) # break for loop
break
for index in range(len(self.pserver_endpoints)):
in_var = program.global_block().create_var(
name=str("prefetch_compress_in_tmp_" + str(index)),
type=self.all_in_ids_vars[0].type,
shape=self.all_in_ids_vars[0].shape,
dtype=self.all_in_ids_vars[0].dtype)
self.all_prefetch_input_vars.append(in_var)
out_var = program.global_block().create_var(
name=str("prefetch_compress_out_tmp_" + str(index)),
type=self.all_out_emb_vars[0].type,
shape=self.all_out_emb_vars[0].shape,
dtype=self.all_out_emb_vars[0].dtype)
self.all_prefetch_output_vars.append(out_var)
# insert split_ids_op # insert split_ids_op
program.global_block()._insert_op( program.global_block()._insert_op(
index=lookup_table_op_index, index=lookup_table_op_index,
type="split_ids", type="split_ids",
inputs={ inputs={'Ids': self.all_in_ids_vars},
'Ids': [ outputs={"Out": self.all_prefetch_input_vars})
program.global_block().vars[varname]
for varname in ids_name
]
},
outputs={"Out": prefetch_input_vars})
# insert prefetch_op # insert prefetch_op
program.global_block()._insert_op( program.global_block()._insert_op(
index=lookup_table_op_index + 1, index=lookup_table_op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': prefetch_input_vars}, inputs={'X': self.all_prefetch_input_vars},
outputs={"Out": prefetch_output_vars}, outputs={"Out": self.all_prefetch_output_vars},
attrs={ attrs={
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
# FIXME(qiao) temporarily disable this config because prefetch # FIXME(qiao) temporarily disable this config because prefetch
...@@ -1100,23 +1124,11 @@ to transpile() call.") ...@@ -1100,23 +1124,11 @@ to transpile() call.")
index=lookup_table_op_index + 2, index=lookup_table_op_index + 2,
type="merge_ids", type="merge_ids",
inputs={ inputs={
'Ids': [ 'Ids': self.all_in_ids_vars,
program.global_block().vars[varname] 'Rows': self.all_prefetch_input_vars,
for varname in ids_name 'X': self.all_prefetch_output_vars
],
'X': prefetch_output_vars
}, },
outputs={ outputs={"Out": self.all_out_emb_vars})
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
})
# delete lookup_table_op
delete_ops(program.global_block(), [op])
# break for loop
break
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_op to send gradient to pservers # 2. add split_ids_op and send_op to send gradient to pservers
...@@ -1134,7 +1146,8 @@ to transpile() call.") ...@@ -1134,7 +1146,8 @@ to transpile() call.")
inputs={ inputs={
'Ids': [program.global_block().vars[table_grad_name]] 'Ids': [program.global_block().vars[table_grad_name]]
}, },
outputs={"Out": self.trainer_side_table_grad_list}) outputs={"Out": self.trainer_side_table_grad_list},
attrs={RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE})
program.global_block()._insert_op( program.global_block()._insert_op(
index=op_index + 2, index=op_index + 2,
type="send", type="send",
...@@ -1160,15 +1173,14 @@ to transpile() call.") ...@@ -1160,15 +1173,14 @@ to transpile() call.")
# STEP: create prefetch block # STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name] table_var = pserver_program.global_block().vars[self.table_name]
prefetch_var_name_to_block_id = [] prefetch_var_name_to_block_id = []
for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program._create_block(optimize_block.idx) prefetch_block = pserver_program._create_block(optimize_block.idx)
trainer_ids = self.all_prefetch_input_vars[index][pserver_index] trainer_ids = self.all_prefetch_input_vars[pserver_index]
pserver_ids = pserver_program.global_block().create_var( pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name, name=trainer_ids.name,
type=trainer_ids.type, type=trainer_ids.type,
shape=trainer_ids.shape, shape=trainer_ids.shape,
dtype=trainer_ids.dtype) dtype=trainer_ids.dtype)
trainer_out = self.all_prefetch_output_vars[index][pserver_index] trainer_out = self.all_prefetch_output_vars[pserver_index]
pserver_out = pserver_program.global_block().create_var( pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name, name=trainer_out.name,
type=trainer_out.type, type=trainer_out.type,
...@@ -1364,16 +1376,6 @@ to transpile() call.") ...@@ -1364,16 +1376,6 @@ to transpile() call.")
program.global_block()._sync_with_cpp() program.global_block()._sync_with_cpp()
return var_mapping return var_mapping
def _create_splited_vars(self, source_var, block, tag):
return [
block.create_var(
name=str(source_var.name + tag + str(index)),
type=source_var.type,
shape=source_var.shape,
dtype=source_var.dtype)
for index in range(len(self.pserver_endpoints))
]
def _clone_var(self, block, var, persistable=True): def _clone_var(self, block, var, persistable=True):
return block.create_var( return block.create_var(
name=var.name, name=var.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册