Commit cb284283 authored by dangqingqing

Replace Tensor with LoDTensor in elementwise_mul_op, pad_op and recurrent_op_utils.

Parent: 30a58b51
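The common thread through this diff is a single helper, GetTensorFromVar, that dispatches on a variable's runtime type so the templated Input/Output accessors work whether a Variable holds a plain Tensor or a LoDTensor. As a standalone illustration, here is a minimal sketch of that dispatch pattern; Variable, IsType, and Get are simplified stand-ins rather than Paddle's real API, and the LoDTensor-inherits-Tensor assumption mirrors the implicit pointer conversion the diff itself relies on:

// Sketch: runtime-type dispatch from Variable to Tensor.
// "Variable", "IsType", and "Get" are simplified stand-ins for
// Paddle's API; the inheritance LoDTensor : Tensor mirrors what the
// diff relies on when it returns a LoDTensor* as a Tensor*.
#include <cassert>

struct Tensor {
  virtual ~Tensor() = default;
};
struct LoDTensor : public Tensor {};

struct Variable {
  Tensor* holder = nullptr;  // holds either variant (ownership elided)
  template <typename T>
  bool IsType() const { return dynamic_cast<const T*>(holder) != nullptr; }
  template <typename T>
  const T& Get() const { return *dynamic_cast<const T*>(holder); }
};

Tensor* GetTensorFromVar(const Variable* var) {
  if (var->IsType<LoDTensor>()) {  // check the more derived type first
    return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
  }
  assert(var->IsType<Tensor>() && "must hold LoDTensor or Tensor");
  return const_cast<Tensor*>(&var->Get<Tensor>());
}

int main() {
  LoDTensor lod;
  Variable v;
  v.holder = &lod;
  Tensor* t = GetTensorFromVar(&v);  // same call site for both variants
  return t == &lod ? 0 : 1;
}

Centralizing the type check means the four accessors below collapse to one-line ternaries, and future container types need only one more branch in the helper.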
...
@@ -189,13 +189,7 @@ void OperatorBase::GenerateTemporaryNames() {
 template <>
 const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
   auto* var = InputVar(name);
-  if (var == nullptr) return nullptr;
-  if (var->IsType<LoDTensor>()) {
-    return &var->Get<LoDTensor>();
-  }
-  PADDLE_ENFORCE(var->IsType<Tensor>(),
-                 "The Input(%s) must be LoDTensor or Tensor.");
-  return &var->Get<Tensor>();
+  return var == nullptr ? nullptr : GetTensorFromVar(var);
 }
 
 template <>
@@ -204,9 +198,11 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
   auto names = op().Inputs(name);
   std::vector<const Tensor*> res;
   res.reserve(names.size());
-  std::transform(
-      names.begin(), names.end(), std::back_inserter(res),
-      [&](const std::string& sub_name) { return Input<Tensor>(sub_name); });
+  std::transform(names.begin(), names.end(), std::back_inserter(res),
+                 [&](const std::string& sub_name) {
+                   auto var = scope_.FindVar(sub_name);
+                   return var == nullptr ? nullptr : GetTensorFromVar(var);
+                 });
   return res;
 }
@@ -214,12 +210,7 @@ template <>
 Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
   auto* var = OutputVar(name);
   if (var == nullptr) return nullptr;
-  if (var->IsType<LoDTensor>()) {
-    return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
-  }
-  PADDLE_ENFORCE(var->IsType<Tensor>(),
-                 "The Input(%s) must be LoDTensor or Tensor.");
-  return const_cast<Tensor*>(&var->Get<Tensor>());
+  return GetTensorFromVar(var);
 }
 
 template <>
@@ -228,9 +219,11 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
   auto names = op().Outputs(name);
   std::vector<Tensor*> res;
   res.reserve(names.size());
-  std::transform(
-      names.begin(), names.end(), std::back_inserter(res),
-      [&](const std::string& sub_name) { return Output<Tensor>(sub_name); });
+  std::transform(names.begin(), names.end(), std::back_inserter(res),
+                 [&](const std::string& sub_name) {
+                   auto var = scope().FindVar(sub_name);
+                   return var == nullptr ? nullptr : GetTensorFromVar(var);
+                 });
   return res;
 }
...
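All four Multi* accessors rewritten above share one gathering idiom: std::transform with std::back_inserter maps each name to a lookup result, keeping the output aligned with the name order and yielding nullptr for absent variables instead of throwing. A small self-contained illustration of the idiom, with a std::map standing in for Paddle's Scope:

// Sketch of the MultiInput/MultiOutput gathering idiom: transform a
// list of names into a vector of lookups, preserving order and using
// nullptr for absent entries. std::map stands in for Paddle's Scope.
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, int> scope{{"x", 1}, {"y", 2}};
  std::vector<std::string> names{"x", "missing", "y"};

  std::vector<const int*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& n) -> const int* {
                   auto it = scope.find(n);
                   return it == scope.end() ? nullptr : &it->second;
                 });
  // res now holds {pointer to 1, nullptr, pointer to 2}.
  return res[1] == nullptr ? 0 : 1;
}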
@@ -306,9 +306,11 @@ class InferShapeContext {
   auto names = op_.Inputs(name);
   std::vector<const T*> res;
   res.reserve(names.size());
-  std::transform(
-      names.begin(), names.end(), std::back_inserter(res),
-      [&](const std::string& sub_name) { return Input<T>(sub_name); });
+  std::transform(names.begin(), names.end(), std::back_inserter(res),
+                 [&](const std::string& sub_name) {
+                   auto var = scope_.FindVar(sub_name);
+                   return var == nullptr ? nullptr : &var->Get<T>();
+                 });
   return res;
 }
@@ -317,12 +319,23 @@ class InferShapeContext {
   auto names = op_.Outputs(name);
   std::vector<T*> res;
   res.reserve(names.size());
-  std::transform(
-      names.begin(), names.end(), std::back_inserter(res),
-      [&](const std::string& sub_name) { return Output<T>(sub_name); });
+  std::transform(names.begin(), names.end(), std::back_inserter(res),
+                 [&](const std::string& sub_name) {
+                   auto var = scope_.FindVar(sub_name);
+                   return var == nullptr ? nullptr : var->GetMutable<T>();
+                 });
   return res;
 }
 
+  Tensor* GetTensorFromVar(const Variable* var) const {
+    if (var->IsType<LoDTensor>()) {
+      return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
+    }
+    PADDLE_ENFORCE(var->IsType<Tensor>(),
+                   "The Input(%s) must be LoDTensor or Tensor.");
+    return const_cast<Tensor*>(&var->Get<Tensor>());
+  }
+
  private:
   const OperatorBase& op_;
   const Scope& scope_;
...
@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
   auto y_dim = ctx.Input<Tensor>("Y")->dims();
   PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
                     "Rank of first input must >= rank of second input.")
-  ctx.Output<Tensor>("Out")->Resize(x_dim);
+  ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
 }
 };
@@ -80,8 +80,8 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
   auto x_dims = ctx.Input<Tensor>("X")->dims();
   auto y_dims = ctx.Input<Tensor>("Y")->dims();
   auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
-  auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-  auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
+  auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+  auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
   PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                     "Rank of first input must >= rank of second input.")
...
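In the gradient ops above and below, framework::GradVarName resolves the name under which a variable's gradient is stored; the "Input(Out@GRAD)" message in PadOpGrad shows the suffix convention. A sketch of that naming scheme (the helper body here is illustrative, not copied from Paddle):

// Sketch of the gradient-name convention visible in this diff: the
// gradient of variable "Out" lives under "Out@GRAD" (the suffix is
// confirmed by the "Input(Out@GRAD)" enforce message in PadOpGrad).
#include <iostream>
#include <string>

std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  std::cout << GradVarName("Out") << "\n";  // prints: Out@GRAD
  return 0;
}

Because a gradient variable may legitimately be absent (an input that needs no gradient), the Output calls return nullptr and callers guard with a null check, as PadOpGrad does.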
@@ -34,7 +34,8 @@ class PadOp : public framework::OperatorWithKernel {
   for (int i = 0; i < x_dim.size(); ++i) {
     out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
   }
-  ctx.Output<Tensor>("Out")->Resize(framework::make_ddim(out_dims));
+  ctx.Output<framework::LoDTensor>("Out")->Resize(
+      framework::make_ddim(out_dims));
 }
 };
@@ -95,9 +96,9 @@ class PadOpGrad : public framework::OperatorWithKernel {
   PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                           "Input(Out@GRAD) should not be null");
   auto x_dims = ctx.Input<Tensor>("X")->dims();
-  auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-  if (x_grad != nullptr) {
-    x_grad->Resize(x_dims);
+  auto *x_g = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+  if (x_g != nullptr) {
+    x_g->Resize(x_dims);
   }
 }
 };
...
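PadOp's shape inference reads the paddings attribute as flattened (before, after) pairs, one pair per dimension: out_dims[i] = x_dim[i] + paddings[2i] + paddings[2i+1]. A worked example of that arithmetic with made-up sizes:

// Worked example of PadOp's output-shape arithmetic: paddings are laid
// out as [before_0, after_0, before_1, after_1, ...]. The input shape
// (2 x 3) and the padding amounts are invented for illustration.
#include <cassert>
#include <vector>

int main() {
  std::vector<int> x_dim = {2, 3};
  std::vector<int> paddings = {0, 1, 2, 2};  // dim0: +0/+1, dim1: +2/+2

  std::vector<int> out_dims(x_dim.size());
  for (size_t i = 0; i < x_dim.size(); ++i) {
    out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
  }
  assert(out_dims[0] == 3 && out_dims[1] == 7);  // (2+0+1) x (3+2+2)
  return 0;
}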
@@ -21,6 +21,7 @@ namespace rnn {
 namespace f = paddle::framework;
 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks, const size_t seq_len,
@@ -31,7 +32,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
                    inlinks[i].external);
-    Tensor* input = input_var->GetMutable<Tensor>();
+    LoDTensor* input = input_var->GetMutable<LoDTensor>();
     f::DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -40,6 +41,8 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
       if (!infer_shape_mode) {
+        // The input of operators of each step is Tensor here.
+        // Maybe need to modify Slice function.
         *step_input = input->Slice<float>(j, j + 1);
       }
       step_input->Resize(step_dims);
@@ -54,21 +57,23 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
     auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
     PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
                    outlinks[i].external);
-    Tensor* output = output_var->GetMutable<Tensor>();
+    LoDTensor* output = output_var->GetMutable<LoDTensor>();
     if (infer_shape_mode) {
       auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
       PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
                      outlinks[i].internal);
-      f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
+      f::DDim step_dims =
+          step_scope_var->template GetMutable<LoDTensor>()->dims();
       std::vector<int64_t> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
       output->Resize(f::make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
       for (size_t j = 0; j < seq_len; j++) {
-        Tensor* step_output =
-            step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
+        LoDTensor* step_output = step_scopes[j]
+                                     ->FindVar(outlinks[i].internal)
+                                     ->GetMutable<LoDTensor>();
         // TODO(luotao02) data type and platform::DeviceContext() should set
         // correctly
         (output->Slice<float>(j, j + 1))
@@ -94,8 +99,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
   auto scope = scopes[step_id];
   auto linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
-    auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
-    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
+    auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
+    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
     if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
     } else {
...
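SegmentInputs and ConcatOutputs treat axis 0 of the external LoDTensor as the time axis: step j reads or writes rows [j, j+1), and the comment added in SegmentInputs notes that Slice still yields a plain per-step Tensor view. A rough sketch of that row slicing over a flat row-major buffer (plain pointers model the zero-copy view; Paddle's Slice additionally carries dims and shares the underlying allocation):

// Sketch of the per-step slicing in SegmentInputs: a [seq_len x width]
// sequence tensor is viewed one time step at a time as rows [j, j+1).
#include <cstdio>
#include <vector>

int main() {
  const size_t seq_len = 3, width = 2;
  std::vector<float> seq = {0, 1, 10, 11, 20, 21};  // shape [3 x 2], row-major

  for (size_t j = 0; j < seq_len; ++j) {
    const float* step = seq.data() + j * width;  // view of rows [j, j+1)
    std::printf("step %zu: %.0f %.0f\n", j, step[0], step[1]);
  }
  return 0;
}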