Commit 36dce65b authored by minqiyang

Take DataType and VarType apart

test=develop
Parent db0c9708
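The diff below splits the single dtype_ field on VarBase into two independent pieces of information: the variable type (the container kind, e.g. LOD_TENSOR vs SELECTED_ROWS, exposed as Type()/SetType()) and the data type (the element type, e.g. FP32, exposed as DataType()/SetDataType()). A minimal standalone sketch of that distinction, with toy enums and names rather than Paddle's actual types:

#include <cassert>

// Illustrative enums only; the real framework uses proto::VarType::Type for both.
enum class VarKind { kLoDTensor, kSelectedRows };  // container kind ("VarType")
enum class DataKind { kFP32, kINT64 };             // element type ("DataType")

// A toy variable that tracks the two independently, as VarBase does after
// this commit (Type()/SetType() vs DataType()/SetDataType()).
struct ToyVar {
  VarKind type{VarKind::kLoDTensor};
  DataKind dtype{DataKind::kFP32};
};

int main() {
  ToyVar v;
  v.type = VarKind::kSelectedRows;  // change the container kind ...
  v.dtype = DataKind::kINT64;       // ... independently of the element type
  assert(v.type == VarKind::kSelectedRows && v.dtype == DataKind::kINT64);
  return 0;
}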
@@ -129,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](InferVarTypeContext& context) {
+    info->infer_var_type_ = [](InferVarTypeContext* context) {
       T inference;
-      inference(context);
+      inference(*context);
     };
   }
 };
...
@@ -48,7 +48,7 @@ class SumOpVarTypeInference : public VarTypeInference {
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
           return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
...
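The test's lambdas now capture the context with [&ctx] instead of [ctx], so the closure refers to the existing context object rather than copying it; presumably this is needed once the context type can no longer be copied. A small standalone illustration with a toy context type:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Toy context; capturing it by reference ([&ctx]) means the lambda works on
// the original object instead of a copy, mirroring the change above.
struct ToyCtx {
  int GetType(const std::string&) const { return 1; }  // pretend 1 == LOD_TENSOR
};

int main() {
  ToyCtx ctx;
  std::vector<std::string> inputs{"X", "Y"};
  bool any_lod_tensor = std::any_of(
      inputs.begin(), inputs.end(),
      [&ctx](const std::string& name) { return ctx.GetType(name) == 1; });
  std::cout << std::boolalpha << any_lod_tensor << "\n";  // prints true
  return 0;
}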
@@ -679,7 +679,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   auto &info = OpInfoMap::Instance().Get(this->Type());
   if (info.infer_var_type_) {
     InferVarTypeContext context(this, block);
-    info.infer_var_type_(context);
+    info.infer_var_type_(&context);
   }
 }
...
@@ -54,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;

 using InferVarTypeFN =
-    std::function<void(framework::InferVarTypeContext& /*context*/)>;
+    std::function<void(framework::InferVarTypeContext* /*context*/)>;

 using InferShapeFN = std::function<void(InferShapeContext*)>;
...
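Alongside the rename, InferVarTypeFN now receives framework::InferVarTypeContext* instead of a reference, and the registration lambda in OpInfoFiller dereferences the pointer before invoking the inference functor. A minimal standalone sketch of the same wiring, using toy types rather than Paddle's:

#include <functional>
#include <iostream>

// Toy stand-ins for InferVarTypeContext and a VarTypeInference functor.
struct Context {
  int type = 0;
};

struct ToyInference {
  void operator()(Context& ctx) const { ctx.type = 7; }  // functor still takes a reference
};

// Mirrors the new InferVarTypeFN: callers hand over a pointer ...
using InferFn = std::function<void(Context*)>;

int main() {
  // ... and the registered lambda dereferences it, as OpInfoFiller now does.
  InferFn infer = [](Context* ctx) {
    ToyInference inference;
    inference(*ctx);
  };

  Context ctx;
  infer(&ctx);
  std::cout << ctx.type << "\n";  // prints 7
  return 0;
}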
@@ -49,7 +49,7 @@ class SumOpVarTypeInference : public VarTypeInference {
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
           return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
...
@@ -243,12 +243,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
       for (size_t i = 0; i < it.second.size(); ++i) {
+        VarBase* origin_grad_var_base = it.second[i];
+
         // Allocate a new variable
-        Variable* tmp_var = new framework::Variable();
-        tmp_var->GetMutable<framework::LoDTensor>();
-        VarBase* tmp_var_base =
-            new VarBase(it.second[i]->Name(), tmp_var, nullptr, true);
-        outputs.emplace_back(tmp_var_base);
+        VarBase* tmp_grad_var_base = new VarBase(
+            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
+            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
+            place_, true, false);
+        outputs.emplace_back(tmp_grad_var_base);
       }
     }
@@ -259,13 +261,12 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       std::unique_ptr<framework::OperatorBase> opbase =
          framework::OpRegistry::CreateOp(*grad_op_desc);

-      // auto& info =
-      //     framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
-      // if (info.infer_var_type_) {
-      //   framework::RuntimeInferVarTypeContext infer_var_type_ctx(
-      //       this, &grad_inputs, &outputs, &attrs_map);
-      //   info.infer_var_type_(infer_var_type_ctx);
-      // }
+      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
+      if (info.infer_var_type_) {
+        RuntimeInferVarTypeContext infer_var_type_ctx(
+            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
+        info.infer_var_type_(&infer_var_type_ctx);
+      }

       framework::OperatorWithKernel* op_kernel =
          dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
@@ -298,7 +299,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
      }

      framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
      framework::Scope scope;
-
      PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
      p.op.RuntimeInferShape(scope, place_, ctx);
...
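In OpBase::ApplyGrad the temporary gradient outputs are now full VarBase objects that inherit the originating variable's data type, dims and place and are named with an @IGrad suffix, instead of wrapping an empty Variable. A simplified standalone sketch of that allocation pattern, with a toy struct and a hypothetical helper name:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for VarBase: only the metadata the grad buffer copies.
struct ToyVarBase {
  std::string name;
  std::string dtype;          // e.g. "FP32"
  std::vector<int64_t> dims;  // e.g. {784, 10}
};

// Allocate a gradient buffer that mirrors the origin variable's metadata,
// following the "%s@IGrad" naming used in the diff above. Helper name is
// illustrative, not part of the framework.
ToyVarBase MakeGradBuffer(const ToyVarBase& origin) {
  return ToyVarBase{origin.name + "@IGrad", origin.dtype, origin.dims};
}

int main() {
  ToyVarBase x{"fc_0.w_0", "FP32", {784, 10}};
  ToyVarBase gx = MakeGradBuffer(x);
  std::cout << gx.name << " " << gx.dtype << " rank=" << gx.dims.size() << "\n";
  return 0;
}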
@@ -137,13 +137,13 @@ class VarBase {
             persistable) {}

  private:
+  // TODO(minqiyang): need support SelectedRows
   VarBase(const std::string& name, framework::proto::VarType::Type dtype,
           const framework::DDim& shape, const platform::Place& place,
           framework::Variable* var, VarBase* grad, bool stop_gradient,
           bool persistable)
       : name_(name),
-        dtype_(dtype),
-        place_(place),
+        type_(framework::proto::VarType::LOD_TENSOR),
         var_(var),
         grads_(grad),
         stop_gradient_(stop_gradient),
@@ -153,10 +153,12 @@ class VarBase {
         pre_op_out_idx_(-1) {
     if (!var_) {
       var_ = new framework::Variable();
-      auto tensor = var_->GetMutable<framework::LoDTensor>();
-      tensor->Resize(shape);
-      tensor->mutable_data(place_, dtype_);
     }
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->Resize(shape);
+    tensor->mutable_data(place, dtype);
+    VLOG(10) << "create varbase: " << name_ << " type: " << dtype
+             << " place: " << place;
   }

  public:
@@ -186,11 +188,23 @@ class VarBase {
     }
   }

-  inline void SetDType(framework::proto::VarType::Type type) {
+  inline framework::DDim Dims() const {
+    return var_->Get<framework::LoDTensor>().dims();
+  }
+
+  // data type. e.g.. FP32
+  inline void SetDataType(framework::proto::VarType::Type type) {
     auto tensor = var_->GetMutable<framework::LoDTensor>();
-    tensor->mutable_data(place_, dtype_);
+    tensor->mutable_data(place_, type);
   }

-  inline framework::proto::VarType::Type DType() const { return dtype_; }
+  inline framework::proto::VarType::Type DataType() const {
+    auto tensor = var_->Get<framework::LoDTensor>();
+    return tensor.type();
+  }
+
+  // tensor type. e.g.. LoDTensor
+  inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
+  inline framework::proto::VarType::Type Type() const { return type_; }

   inline void SetStopGradient(bool stop_gradient) {
     stop_gradient_ = stop_gradient;
@@ -244,7 +258,7 @@ class VarBase {
   }

   std::string name_;
-  framework::proto::VarType::Type dtype_;
+  framework::proto::VarType::Type type_;
   platform::Place place_;

   framework::Variable* var_;
@@ -339,6 +353,8 @@ class PYBIND11_HIDDEN OpBase {
   std::vector<VarBasePtrMap> grad_output_vars_;

   std::vector<py::object> backward_hooks_;
+
+  framework::AttributeMap attrs_;
 };

 class Layer {
@@ -437,22 +453,22 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
   framework::proto::VarType::Type GetType(
       const std::string& name) const override {
-    return var_set_.at(name)->DType();
+    return var_set_.at(name)->Type();
   }

   void SetType(const std::string& name,
                framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDType(type);
+    var_set_[name]->SetType(type);
   }

   framework::proto::VarType::Type GetDataType(
       const std::string& name) const override {
-    return var_set_.at(name)->DType();
+    return var_set_.at(name)->DataType();
   }

   void SetDataType(const std::string& name,
                    framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDType(type);
+    var_set_[name]->SetDataType(type);
   }

   std::vector<framework::proto::VarType::Type> GetDataTypes(
...
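With the accessors split, RuntimeInferVarTypeContext forwards GetType/SetType to VarBase::Type()/SetType() and GetDataType/SetDataType to DataType()/SetDataType(). A toy standalone sketch of a context delegating the two kinds of queries to a variable map in the same way (illustrative names only, not the framework's classes):

#include <cassert>
#include <string>
#include <unordered_map>

// Toy variable holding both kinds of type information, like VarBase.
struct ToyVar {
  int type = 0;   // container kind, e.g. LOD_TENSOR
  int dtype = 0;  // element type, e.g. FP32
};

// Toy runtime context: each query is forwarded to the matching accessor,
// mirroring what RuntimeInferVarTypeContext does after this commit.
class ToyContext {
 public:
  explicit ToyContext(std::unordered_map<std::string, ToyVar*>* vars)
      : vars_(vars) {}

  int GetType(const std::string& n) const { return vars_->at(n)->type; }
  void SetType(const std::string& n, int t) { vars_->at(n)->type = t; }

  int GetDataType(const std::string& n) const { return vars_->at(n)->dtype; }
  void SetDataType(const std::string& n, int t) { vars_->at(n)->dtype = t; }

 private:
  std::unordered_map<std::string, ToyVar*>* vars_;
};

int main() {
  ToyVar x;
  std::unordered_map<std::string, ToyVar*> vars{{"X", &x}};
  ToyContext ctx(&vars);
  ctx.SetType("X", 1);      // container kind
  ctx.SetDataType("X", 2);  // element type
  assert(ctx.GetType("X") == 1 && ctx.GetDataType("X") == 2);
  return 0;
}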
@@ -232,7 +232,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   if (info.infer_var_type_) {
     RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
                                                   &attrs_map);
-    info.infer_var_type_(infer_var_type_ctx);
+    info.infer_var_type_(&infer_var_type_ctx);
   }

   // TODO(minqiyang): Support infer var type in imperative mode
@@ -259,6 +259,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     VLOG(5) << "start construct backward op";

     // construct grad op descs
+    op->attrs_ = attrs_map;
     std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
         op->Type(), invars_name_map, outvars_name_map, attrs_map));
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
...
@@ -168,11 +168,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     }

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string& name) {
          return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

-    auto is_tensor_array = [ctx](const std::string& name) {
+    auto is_tensor_array = [&ctx](const std::string& name) {
       return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
     };
...
@@ -194,7 +194,7 @@ PYBIND11_MODULE(core, m) {
       .def_property("name", &imperative::VarBase::Name,
                     &imperative::VarBase::SetName)
       .def_property_readonly("shape", &imperative::VarBase::Shape)
-      .def_property_readonly("dtype", &imperative::VarBase::DType)
+      .def_property_readonly("dtype", &imperative::VarBase::DataType)
       .def_property("persistable", &imperative::VarBase::IsPersistable,
                     &imperative::VarBase::SetPersistable)
       .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
...
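On the Python side, the read-only dtype property now reports the element type via VarBase::DataType(). A minimal pybind11 sketch of exposing such a read-only property (toy class and module name, not the real binding):

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for imperative::VarBase: the C++ object owns the
// element type, and Python reads it back through a read-only "dtype" property.
struct FakeVarBase {
  int DataType() const { return data_type_; }
  int data_type_{5};  // pretend 5 == FP32
};

PYBIND11_MODULE(fake_core, m) {
  py::class_<FakeVarBase>(m, "VarBase")
      .def(py::init<>())
      // Read-only property, mirroring the DType -> DataType switch above.
      .def_property_readonly("dtype", &FakeVarBase::DataType);
}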