Commit 36dce65b authored by minqiyang

Take DataType and VarType apart

test=develop
Parent db0c9708
@@ -129,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](InferVarTypeContext& context) {
+    info->infer_var_type_ = [](InferVarTypeContext* context) {
       T inference;
-      inference(context);
+      inference(*context);
     };
   }
 };
......
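The first hunk (together with the `InferVarTypeFN` typedef change further down) switches the registered callback from `InferVarTypeContext&` to `InferVarTypeContext*`: call sites now pass `&context`, and the lambda dereferences once so the `VarTypeInference` functor keeps its by-reference interface. A minimal sketch of the pattern, using stand-in types rather than Paddle's real headers:

```cpp
// Sketch only: InferVarTypeContext and DummyVarTypeInference are stand-ins,
// not Paddle's declarations. It shows the pointer-taking std::function
// wrapping a reference-taking functor, as in the hunk above.
#include <functional>
#include <iostream>

struct InferVarTypeContext {
  const char* op_type;
};

using InferVarTypeFN = std::function<void(InferVarTypeContext*)>;

struct DummyVarTypeInference {
  void operator()(InferVarTypeContext& ctx) const {
    std::cout << "infer var type for " << ctx.op_type << "\n";
  }
};

template <typename T>
InferVarTypeFN MakeInferVarTypeFN() {
  return [](InferVarTypeContext* context) {
    T inference;
    inference(*context);  // dereference once; the functor keeps a reference API
  };
}

int main() {
  InferVarTypeFN fn = MakeInferVarTypeFN<DummyVarTypeInference>();
  InferVarTypeContext ctx{"sum"};
  fn(&ctx);  // callers now pass an explicit address
}
```

A pointer parameter makes the possible mutation visible at each call site, which is the usual motivation for this style in Google-style C++ codebases.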
@@ -48,7 +48,7 @@ class SumOpVarTypeInference : public VarTypeInference {
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
        });
     if (any_input_is_lod_tensor) {
......
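The `[ctx]` to `[&ctx]` capture fix recurs in three hunks of this commit. Capturing a context object by value copies it, which is wasteful at best and a compile error once the context's static type is an abstract base; capturing by reference lets the predicate read through the original object. A self-contained sketch (the context types below are stand-ins):

```cpp
// Sketch only: the context types below are stand-ins. [ctx] would have to
// copy-construct the (abstract) context and would not compile; [&ctx]
// captures a reference to the live object instead.
#include <algorithm>
#include <string>
#include <vector>

enum class VarType { LOD_TENSOR, SELECTED_ROWS };

struct InferVarTypeContext {  // abstract, hence non-copyable in practice
  virtual ~InferVarTypeContext() = default;
  virtual VarType GetType(const std::string& name) const = 0;
};

struct FakeContext : InferVarTypeContext {
  VarType GetType(const std::string& name) const override {
    return name == "x" ? VarType::LOD_TENSOR : VarType::SELECTED_ROWS;
  }
};

bool AnyInputIsLoDTensor(const InferVarTypeContext& ctx,
                         const std::vector<std::string>& inputs) {
  return std::any_of(inputs.begin(), inputs.end(),
                     [&ctx](const std::string& name) {
                       return ctx.GetType(name) == VarType::LOD_TENSOR;
                     });
}

int main() {
  FakeContext ctx;
  std::vector<std::string> inputs{"x", "y"};
  return AnyInputIsLoDTensor(ctx, inputs) ? 0 : 1;
}
```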
@@ -679,7 +679,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   auto &info = OpInfoMap::Instance().Get(this->Type());
   if (info.infer_var_type_) {
     InferVarTypeContext context(this, block);
-    info.infer_var_type_(context);
+    info.infer_var_type_(&context);
   }
 }
......
@@ -54,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;

 using InferVarTypeFN =
-    std::function<void(framework::InferVarTypeContext& /*context*/)>;
+    std::function<void(framework::InferVarTypeContext* /*context*/)>;

 using InferShapeFN = std::function<void(InferShapeContext*)>;
......
@@ -49,7 +49,7 @@ class SumOpVarTypeInference : public VarTypeInference {
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
        });
     if (any_input_is_lod_tensor) {
......
@@ -243,12 +243,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
       for (size_t i = 0; i < it.second.size(); ++i) {
+        VarBase* origin_grad_var_base = it.second[i];
+
         // Allocate a new variable
-        Variable* tmp_var = new framework::Variable();
-        tmp_var->GetMutable<framework::LoDTensor>();
-        VarBase* tmp_var_base =
-            new VarBase(it.second[i]->Name(), tmp_var, nullptr, true);
-        outputs.emplace_back(tmp_var_base);
+        VarBase* tmp_grad_var_base = new VarBase(
+            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
+            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
+            place_, true, false);
+        outputs.emplace_back(tmp_grad_var_base);
       }
     }
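In `OpBase::ApplyGrad`, each temporary grad output used to wrap a freshly allocated `Variable` under the original grad var's name; it is now a new `VarBase` described by the origin's name with an `@IGrad` suffix plus its `DataType()`, `Dims()`, and the op's place. A rough sketch of that descriptor-style construction (`GradVarDesc` and `MakeTmpGradVar` are invented names for illustration):

```cpp
// Sketch only: GradVarDesc and MakeTmpGradVar are invented names. The point
// is that the temporary grad output copies name/dtype/dims from the origin
// var and appends "@IGrad", as the hunk above does via string::Sprintf.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

enum class DType { FP32, FP64 };

struct GradVarDesc {
  std::string name;
  DType dtype;
  std::vector<int64_t> dims;
  std::string place;  // e.g. "CPUPlace"
};

GradVarDesc MakeTmpGradVar(const GradVarDesc& origin, const std::string& place) {
  char buf[256];
  std::snprintf(buf, sizeof(buf), "%s@IGrad", origin.name.c_str());
  return GradVarDesc{buf, origin.dtype, origin.dims, place};
}

int main() {
  GradVarDesc origin{"fc_0.w_0@GRAD", DType::FP32, {128, 64}, "CPUPlace"};
  GradVarDesc tmp = MakeTmpGradVar(origin, origin.place);
  std::printf("%s\n", tmp.name.c_str());  // prints fc_0.w_0@GRAD@IGrad
}
```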
@@ -259,13 +261,12 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       std::unique_ptr<framework::OperatorBase> opbase =
           framework::OpRegistry::CreateOp(*grad_op_desc);
-      // auto& info =
-      //     framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
-      // if (info.infer_var_type_) {
-      //   framework::RuntimeInferVarTypeContext infer_var_type_ctx(
-      //       this, &grad_inputs, &outputs, &attrs_map);
-      //   info.infer_var_type_(infer_var_type_ctx);
-      // }
+      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
+      if (info.infer_var_type_) {
+        RuntimeInferVarTypeContext infer_var_type_ctx(
+            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
+        info.infer_var_type_(&infer_var_type_ctx);
+      }

       framework::OperatorWithKernel* op_kernel =
           dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
@@ -298,7 +299,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       }
       framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
       framework::Scope scope;
       PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
       p.op.RuntimeInferShape(scope, place_, ctx);
......
@@ -137,13 +137,13 @@ class VarBase {
            persistable) {}
  private:
   // TODO(minqiyang): need support SelectedRows
   VarBase(const std::string& name, framework::proto::VarType::Type dtype,
           const framework::DDim& shape, const platform::Place& place,
           framework::Variable* var, VarBase* grad, bool stop_gradient,
           bool persistable)
       : name_(name),
-        dtype_(dtype),
         place_(place),
+        type_(framework::proto::VarType::LOD_TENSOR),
         var_(var),
         grads_(grad),
         stop_gradient_(stop_gradient),
@@ -153,10 +153,12 @@ class VarBase {
         pre_op_out_idx_(-1) {
     if (!var_) {
       var_ = new framework::Variable();
-      auto tensor = var_->GetMutable<framework::LoDTensor>();
-      tensor->Resize(shape);
-      tensor->mutable_data(place_, dtype_);
     }
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->Resize(shape);
+    tensor->mutable_data(place, dtype);
+    VLOG(10) << "create varbase: " << name_ << " type: " << dtype
+             << " place: " << place;
   }

  public:
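The two constructor hunks above move `Resize`/`mutable_data` out of the `if (!var_)` branch, so the tensor is shaped and allocated from `(shape, dtype, place)` whether or not the caller supplies a `Variable`. A minimal stand-alone sketch of that behavior (all types are toy stand-ins):

```cpp
// Sketch only: toy Variable/Tensor types showing the constructor change:
// creation of the Variable stays conditional, but Resize + allocation now
// happen on both paths.
#include <cstddef>
#include <cstdint>
#include <vector>

struct Tensor {
  std::vector<int64_t> dims;
  std::vector<float> buf;  // pretend the dtype is always float here
  void Resize(const std::vector<int64_t>& d) { dims = d; }
  void MutableData() {  // stand-in for mutable_data(place, dtype)
    std::size_t n = 1;
    for (int64_t d : dims) n *= static_cast<std::size_t>(d);
    buf.resize(n);
  }
};

struct Variable {
  Tensor tensor;  // stand-in for GetMutable<LoDTensor>()
};

struct VarBase {
  Variable* var_;
  VarBase(Variable* var, const std::vector<int64_t>& shape) : var_(var) {
    if (!var_) {
      var_ = new Variable();  // only creation is conditional now
    }
    // allocation happens on both paths, so a caller-supplied Variable also
    // ends up with a correctly shaped, allocated tensor
    var_->tensor.Resize(shape);
    var_->tensor.MutableData();
  }
};

int main() {
  Variable preexisting;
  VarBase a(&preexisting, {2, 3});  // caller-supplied Variable gets allocated
  VarBase b(nullptr, {4});          // fresh Variable too (leaked; toy example)
  return preexisting.tensor.buf.size() == 6 ? 0 : 1;
}
```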
@@ -186,11 +188,23 @@ class VarBase {
     }
   }

-  inline void SetDType(framework::proto::VarType::Type type) {
+  inline framework::DDim Dims() const {
+    return var_->Get<framework::LoDTensor>().dims();
+  }
+
+  // data type, e.g. FP32
+  inline void SetDataType(framework::proto::VarType::Type type) {
     auto tensor = var_->GetMutable<framework::LoDTensor>();
-    tensor->mutable_data(place_, dtype_);
+    tensor->mutable_data(place_, type);
   }
-  inline framework::proto::VarType::Type DType() const { return dtype_; }
+  inline framework::proto::VarType::Type DataType() const {
+    auto tensor = var_->Get<framework::LoDTensor>();
+    return tensor.type();
+  }
+
+  // tensor type, e.g. LoDTensor
+  inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
+  inline framework::proto::VarType::Type Type() const { return type_; }

   inline void SetStopGradient(bool stop_gradient) {
     stop_gradient_ = stop_gradient;
@@ -244,7 +258,7 @@ class VarBase {
   }

   std::string name_;
-  framework::proto::VarType::Type dtype_;
+  framework::proto::VarType::Type type_;
   platform::Place place_;

   framework::Variable* var_;
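These `VarBase` hunks carry the point of the commit: the old `dtype_`/`DType()` pair, which mixed two ideas, is split into `Type()` (the variable kind, e.g. `LOD_TENSOR`, kept in the renamed `type_` member) and `DataType()` (the element type, e.g. `FP32`, read from the underlying tensor). A compact sketch of the resulting interface, with stand-in enums in place of `framework::proto::VarType::Type`:

```cpp
// Sketch only: stand-in enums and a toy tensor, showing the split between
// the variable kind (Type) and the element type (DataType).
#include <cassert>

enum class VarType { LOD_TENSOR, SELECTED_ROWS };
enum class DType { FP32, INT64 };

struct Tensor {
  DType dtype = DType::FP32;
  DType type() const { return dtype; }
};

class VarBase {
 public:
  // tensor kind, e.g. LOD_TENSOR; stored directly on the VarBase
  void SetType(VarType t) { type_ = t; }
  VarType Type() const { return type_; }

  // element type, e.g. FP32; delegated to the underlying tensor
  void SetDataType(DType t) { tensor_.dtype = t; }
  DType DataType() const { return tensor_.type(); }

 private:
  VarType type_ = VarType::LOD_TENSOR;
  Tensor tensor_;
};

int main() {
  VarBase v;
  v.SetType(VarType::SELECTED_ROWS);  // change the variable kind...
  v.SetDataType(DType::INT64);        // ...independently of the element type
  assert(v.Type() == VarType::SELECTED_ROWS);
  assert(v.DataType() == DType::INT64);
}
```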
@@ -339,6 +353,8 @@ class PYBIND11_HIDDEN OpBase {
   std::vector<VarBasePtrMap> grad_output_vars_;

   std::vector<py::object> backward_hooks_;
+
+  framework::AttributeMap attrs_;
 };

 class Layer {
@@ -437,22 +453,22 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
   framework::proto::VarType::Type GetType(
       const std::string& name) const override {
-    return var_set_.at(name)->DType();
+    return var_set_.at(name)->Type();
   }

   void SetType(const std::string& name,
                framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDType(type);
+    var_set_[name]->SetType(type);
   }

   framework::proto::VarType::Type GetDataType(
       const std::string& name) const override {
-    return var_set_.at(name)->DType();
+    return var_set_.at(name)->DataType();
   }

   void SetDataType(const std::string& name,
                    framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDType(type);
+    var_set_[name]->SetDataType(type);
   }

   std::vector<framework::proto::VarType::Type> GetDataTypes(
......
@@ -232,7 +232,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   if (info.infer_var_type_) {
     RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
                                                   &attrs_map);
-    info.infer_var_type_(infer_var_type_ctx);
+    info.infer_var_type_(&infer_var_type_ctx);
   }

   // TODO(minqiyang): Support infer var type in imperative mode
@@ -259,6 +259,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   VLOG(5) << "start construct backward op";

   // construct grad op descs
+  op->attrs_ = attrs_map;
   std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
       op->Type(), invars_name_map, outvars_name_map, attrs_map));
   std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
......
@@ -168,11 +168,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     }

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string& name) {
           return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR;
         });

-    auto is_tensor_array = [ctx](const std::string& name) {
+    auto is_tensor_array = [&ctx](const std::string& name) {
       return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
     };
......
@@ -194,7 +194,7 @@ PYBIND11_MODULE(core, m) {
       .def_property("name", &imperative::VarBase::Name,
                     &imperative::VarBase::SetName)
       .def_property_readonly("shape", &imperative::VarBase::Shape)
-      .def_property_readonly("dtype", &imperative::VarBase::DType)
+      .def_property_readonly("dtype", &imperative::VarBase::DataType)
       .def_property("persistable", &imperative::VarBase::IsPersistable,
                     &imperative::VarBase::SetPersistable)
       .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
......
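On the Python side, the `dtype` property is simply rebound to the renamed accessor, so `var.dtype` now reports the tensor's element type. A minimal pybind11 sketch of the same binding shape (the module, class, and return value here are invented for illustration):

```cpp
// Sketch only: a toy VarBase with a DataType() accessor exposed to Python
// as a read-only "dtype" property, mirroring the pybind hunk above.
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct VarBase {
  int DataType() const { return 5; }  // pretend 5 encodes FP32 in the proto enum
};

PYBIND11_MODULE(core_sketch, m) {
  py::class_<VarBase>(m, "VarBase")
      .def(py::init<>())
      .def_property_readonly("dtype", &VarBase::DataType);
}
```

Built as a module named `core_sketch`, `core_sketch.VarBase().dtype` would evaluate to 5.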